From 57dabba0a3a41ab2059503cb3417beffbee6516f Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 4 Apr 2022 18:43:20 +0700 Subject: [PATCH] First steps in applying context receivers to operator extension functions --- buildSrc/gradle.properties | 2 +- .../space/kscience/kmath/complex/Complex.kt | 3 +- .../kscience/kmath/complex/Quaternion.kt | 2 +- .../FunctionalExpressionAlgebra.kt | 4 +- .../kscience/kmath/expressions/MstAlgebra.kt | 15 ++---- .../kmath/expressions/SimpleAutoDiff.kt | 4 +- .../kscience/kmath/linear/LinearSpace.kt | 5 +- .../kscience/kmath/nd/BufferAlgebraND.kt | 2 +- .../space/kscience/kmath/nd/DoubleFieldND.kt | 4 +- .../kscience/kmath/operations/Algebra.kt | 50 ++++++++++--------- .../space/kscience/kmath/operations/BigInt.kt | 2 +- .../kmath/operations/BufferAlgebra.kt | 4 +- .../kmath/operations/DoubleBufferOps.kt | 2 +- .../kscience/kmath/operations/numbers.kt | 30 ++++++----- .../kscience/kmath/operations/BigNumbers.kt | 6 +-- .../kscience/kmath/functions/Polynomial.kt | 6 +-- .../kmath/geometry/Euclidean2DSpace.kt | 2 +- .../kscience/kmath/jafama/KMathJafama.kt | 2 +- .../kmath/multik/MultikTensorAlgebra.kt | 3 +- .../kscience/kmath/stat/SamplerAlgebra.kt | 2 +- .../kmath/tensorflow/TensorFlowAlgebra.kt | 2 +- .../kmath/tensors/api/TensorAlgebra.kt | 14 +++--- 22 files changed, 81 insertions(+), 85 deletions(-) diff --git a/buildSrc/gradle.properties b/buildSrc/gradle.properties index a0b05e812..906db76f9 100644 --- a/buildSrc/gradle.properties +++ b/buildSrc/gradle.properties @@ -4,4 +4,4 @@ # kotlin.code.style=official -toolsVersion=0.11.2-kotlin-1.6.10 +toolsVersion=0.11.2-kotlin-1.6.20 diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt index 77fe782a9..8f73005cf 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt @@ -71,14 +71,13 @@ public object ComplexField : */ public val i: Complex by lazy { Complex(0.0, 1.0) } - override fun Complex.unaryMinus(): Complex = Complex(-re, -im) override fun number(value: Number): Complex = Complex(value.toDouble(), 0.0) override fun scale(a: Complex, value: Double): Complex = Complex(a.re * value, a.im * value) override fun add(left: Complex, right: Complex): Complex = Complex(left.re + right.re, left.im + right.im) -// override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble()) + override fun negate(arg: Complex): Complex = Complex(-arg.re, -arg.im) override fun multiply(left: Complex, right: Complex): Complex = Complex(left.re * right.re - left.im * right.im, left.re * right.im + left.im * right.re) diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt index 3ef3428c6..545f3041b 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt @@ -169,7 +169,7 @@ public object QuaternionField : Field, Norm, override operator fun Number.times(arg: Quaternion): Quaternion = Quaternion(toDouble() * arg.w, toDouble() * arg.x, toDouble() * arg.y, toDouble() * arg.z) - override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z) + override fun negate(arg: Quaternion): Quaternion = Quaternion(-arg.w, -arg.x, -arg.y, -arg.z) override fun norm(arg: Quaternion): Quaternion = sqrt(arg.conjugate * arg) override fun bindSymbolOrNull(value: String): Quaternion? = when (value) { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 68cc8e791..86be1b3dc 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -51,8 +51,8 @@ public open class FunctionalExpressionGroup>( ) : FunctionalExpressionAlgebra(algebra), Group> { override val zero: Expression get() = const(algebra.zero) - override fun Expression.unaryMinus(): Expression = - unaryOperation(GroupOps.MINUS_OPERATION, this) + override fun negate(arg: Expression): Expression = + unaryOperation(GroupOps.MINUS_OPERATION, arg) /** * Builds an Expression of addition of two another expressions. diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt index 4bd2a6c53..a1f65e1f1 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt @@ -32,11 +32,9 @@ public object MstGroup : Group, NumericAlgebra, ScaleOperations { override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value) override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) override fun add(left: MST, right: MST): MST.Binary = binaryOperationFunction(GroupOps.PLUS_OPERATION)(left, right) - override operator fun MST.unaryPlus(): MST.Unary = - unaryOperationFunction(GroupOps.PLUS_OPERATION)(this) - override operator fun MST.unaryMinus(): MST.Unary = - unaryOperationFunction(GroupOps.MINUS_OPERATION)(this) + override fun negate(arg: MST): MST.Unary = + unaryOperationFunction(GroupOps.MINUS_OPERATION)(arg) override operator fun MST.minus(arg: MST): MST.Binary = binaryOperationFunction(GroupOps.MINUS_OPERATION)(this, arg) @@ -70,8 +68,7 @@ public object MstRing : Ring, NumbersAddOps, ScaleOperations { override fun multiply(left: MST, right: MST): MST.Binary = binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right) - override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus } - override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus } + override fun negate(arg: MST): MST.Unary = MstGroup.negate(arg) override operator fun MST.minus(arg: MST): MST.Binary = MstGroup { this@minus - arg } override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = @@ -101,8 +98,7 @@ public object MstField : Field, NumbersAddOps, ScaleOperations { override fun divide(left: MST, right: MST): MST.Binary = binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right) - override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus } - override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus } + override fun negate(arg: MST): MST.Unary = MstRing.negate(arg) override operator fun MST.minus(arg: MST): MST.Binary = MstRing { this@minus - arg } override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = @@ -142,8 +138,7 @@ public object MstExtendedField : ExtendedField, NumericAlgebra { override fun multiply(left: MST, right: MST): MST.Binary = MstField.multiply(left, right) override fun divide(left: MST, right: MST): MST.Binary = MstField.divide(left, right) - override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus } - override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus } + override fun negate(arg: MST): MST.Unary = MstField.negate(arg) override operator fun MST.minus(arg: MST): MST.Binary = MstField { this@minus - arg } override fun power(arg: MST, pow: Number): MST.Binary = diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt index ac8c44446..15cfac1c5 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -163,8 +163,8 @@ public open class SimpleAutoDiffField>( // derive(const { this@minus.value - one * b.toDouble() }) { z -> d += z.d } - override fun AutoDiffValue.unaryMinus(): AutoDiffValue = - derive(const { -value }) { z -> d -= z.d } + override fun negate(arg: AutoDiffValue): AutoDiffValue = + derive(const { -arg.value }) { z -> arg.d -= z.d } // Basic math (+, -, *, /) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt index 715fad07b..34a1ef036 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt @@ -10,10 +10,7 @@ import space.kscience.kmath.nd.MutableStructure2D import space.kscience.kmath.nd.Structure2D import space.kscience.kmath.nd.StructureFeature import space.kscience.kmath.nd.as1D -import space.kscience.kmath.operations.BufferRingOps -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.Ring -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.DoubleBuffer diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt index b09344d12..80c4c28ac 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt @@ -101,7 +101,7 @@ public open class BufferedGroupNDOps>( override val bufferAlgebra: BufferAlgebra, override val indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder, ) : GroupOpsND, BufferAlgebraND { - override fun StructureND.unaryMinus(): StructureND = map { -it } + override fun negate(arg: StructureND): StructureND = arg.map { -it } } public open class BufferedRingOpsND>( diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt index d01a8ee95..5dd1eb101 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt @@ -79,7 +79,7 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND(D override fun multiply(left: StructureND, right: StructureND): DoubleBufferND = zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l * r } - override fun StructureND.unaryMinus(): DoubleBufferND = mapInline(toBufferND()) { -it } + override fun negate(arg:StructureND): DoubleBufferND = mapInline(arg.toBufferND()) { -it } override fun StructureND.div(arg: StructureND): DoubleBufferND = zipInline(toBufferND(), arg.toBufferND()) { l, r -> l / r } @@ -93,8 +93,6 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND(D override fun Double.div(arg: StructureND): DoubleBufferND = mapInline(arg.toBufferND()) { this / it } - override fun StructureND.unaryPlus(): DoubleBufferND = toBufferND() - override fun StructureND.plus(arg: StructureND): DoubleBufferND = zipInline(toBufferND(), arg.toBufferND()) { l: Double, r: Double -> l + r } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt index 45ba32c13..b66f67228 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt @@ -9,12 +9,6 @@ import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.Ring.Companion.optimizedPower -/** - * Stub for DSL the [Algebra] is. - */ -@DslMarker -public annotation class KMathContext - /** * Represents an algebraic structure. * @@ -137,23 +131,13 @@ public interface GroupOps : Algebra { */ public fun add(left: T, right: T): T - // Operations to be performed in this context. Could be moved to extensions in case of KEEP-176. - /** * The negation of this element. * - * @receiver this value. + * @param arg the element. * @return the additive inverse of this value. */ - public operator fun T.unaryMinus(): T - - /** - * Returns this value. - * - * @receiver this value. - * @return this value. - */ - public operator fun T.unaryPlus(): T = this + public fun negate(arg: T): T /** * Addition of two elements. @@ -173,10 +157,9 @@ public interface GroupOps : Algebra { */ public operator fun T.minus(arg: T): T = add(this, -arg) - // Dynamic dispatch of operations override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) { PLUS_OPERATION -> { arg -> +arg } - MINUS_OPERATION -> { arg -> -arg } + MINUS_OPERATION -> ::negate else -> super.unaryOperationFunction(operation) } @@ -199,6 +182,24 @@ public interface GroupOps : Algebra { } } +/** + * The negation of this element. + * + * @receiver the element. + * @return the additive inverse of this value. + */ +context(GroupOps) +public operator fun T.unaryMinus(): T = negate(this) + +/** + * Returns this value. + * + * @receiver this value. + * @return this value. + */ +context(GroupOps) +public operator fun T.unaryPlus(): T = this + /** * Represents group i.e., algebraic structure with associative, binary operation [add]. * @@ -264,7 +265,7 @@ public interface Ring : Group, RingOps { */ public fun power(arg: T, pow: UInt): T = optimizedPower(arg, pow) - public companion object{ + public companion object { /** * Raises [arg] to the non-negative integer power [exponent]. * @@ -345,7 +346,7 @@ public interface Field : Ring, FieldOps, ScaleOperations, NumericAlg public fun power(arg: T, pow: Int): T = optimizedPower(arg, pow) - public companion object{ + public companion object { /** * Raises [arg] to the integer power [exponent]. * @@ -358,7 +359,10 @@ public interface Field : Ring, FieldOps, ScaleOperations, NumericAlg * @author Iaroslav Postovalov, Evgeniy Zhelenskiy */ private fun Field.optimizedPower(arg: T, exponent: Int): T = when { - exponent < 0 -> one / (this as Ring).optimizedPower(arg, if (exponent == Int.MIN_VALUE) Int.MAX_VALUE.toUInt().inc() else (-exponent).toUInt()) + exponent < 0 -> one / (this as Ring).optimizedPower( + arg, + if (exponent == Int.MIN_VALUE) Int.MAX_VALUE.toUInt().inc() else (-exponent).toUInt() + ) else -> (this as Ring).optimizedPower(arg, exponent.toUInt()) } } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt index 99268348b..b4b39dfc0 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt @@ -33,7 +33,7 @@ public object BigIntField : Field, NumbersAddOps, ScaleOperation override fun number(value: Number): BigInt = value.toLong().toBigInt() @Suppress("EXTENSION_SHADOWED_BY_MEMBER") - override fun BigInt.unaryMinus(): BigInt = -this + override fun negate(arg: BigInt): BigInt = -arg override fun add(left: BigInt, right: BigInt): BigInt = left.plus(right) override fun scale(a: BigInt, value: Double): BigInt = a.times(number(value)) override fun multiply(left: BigInt, right: BigInt): BigInt = left.times(right) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt index 653552044..3764484ef 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt @@ -137,7 +137,7 @@ public open class BufferRingOps>( override fun add(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l + r } override fun multiply(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l * r } - override fun Buffer.unaryMinus(): Buffer = map { -it } + override fun negate(arg: Buffer): Buffer = arg.map { negate(it) } override fun unaryOperationFunction(operation: String): (arg: Buffer) -> Buffer = super.unaryOperationFunction(operation) @@ -159,7 +159,7 @@ public open class BufferFieldOps>( override fun divide(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l / r } override fun scale(a: Buffer, value: Double): Buffer = a.map { scale(it, value) } - override fun Buffer.unaryMinus(): Buffer = map { -it } + override fun negate(arg: Buffer): Buffer = arg.map { -it } override fun binaryOperationFunction(operation: String): (left: Buffer, right: Buffer) -> Buffer = super.binaryOperationFunction(operation) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt index 0ee591acc..c1bd1daa3 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt @@ -32,7 +32,7 @@ public abstract class DoubleBufferOps : BufferAlgebra, Exte override fun binaryOperationFunction(operation: String): (left: Buffer, right: Buffer) -> Buffer = super.binaryOperationFunction(operation) - override fun Buffer.unaryMinus(): DoubleBuffer = mapInline { -it } + override fun negate(arg: Buffer): DoubleBuffer = arg.mapInline { -it } override fun add(left: Buffer, right: Buffer): DoubleBuffer { require(right.size == left.size) { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt index 07a137415..3376a166b 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt @@ -68,7 +68,7 @@ public object DoubleField : ExtendedField, Norm, ScaleOp override inline val zero: Double get() = 0.0 override inline val one: Double get() = 1.0 - override inline fun number(value: Number): Double = value.toDouble() + override fun number(value: Number): Double = value.toDouble() override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double = when (operation) { @@ -77,6 +77,7 @@ public object DoubleField : ExtendedField, Norm, ScaleOp } override inline fun add(left: Double, right: Double): Double = left + right + override inline fun negate(arg: Double): Double = -arg override inline fun multiply(left: Double, right: Double): Double = left * right override inline fun divide(left: Double, right: Double): Double = left / right @@ -109,7 +110,6 @@ public object DoubleField : ExtendedField, Norm, ScaleOp override inline fun norm(arg: Double): Double = abs(arg) - override inline fun Double.unaryMinus(): Double = -this override inline fun Double.plus(arg: Double): Double = this + arg override inline fun Double.minus(arg: Double): Double = this - arg override inline fun Double.times(arg: Double): Double = this * arg @@ -135,7 +135,9 @@ public object FloatField : ExtendedField, Norm { } override inline fun add(left: Float, right: Float): Float = left + right - override fun scale(a: Float, value: Double): Float = a * value.toFloat() + override inline fun negate(arg: Float): Float = -arg + + override inline fun scale(a: Float, value: Double): Float = a * value.toFloat() override inline fun multiply(left: Float, right: Float): Float = left * right @@ -163,7 +165,6 @@ public object FloatField : ExtendedField, Norm { override inline fun norm(arg: Float): Float = abs(arg) - override inline fun Float.unaryMinus(): Float = -this override inline fun Float.plus(arg: Float): Float = this + arg override inline fun Float.minus(arg: Float): Float = this - arg override inline fun Float.times(arg: Float): Float = this * arg @@ -185,10 +186,11 @@ public object IntRing : Ring, Norm, NumericAlgebra { override fun number(value: Number): Int = value.toInt() override inline fun add(left: Int, right: Int): Int = left + right + override inline fun negate(arg: Int): Int = -arg + override inline fun multiply(left: Int, right: Int): Int = left * right override inline fun norm(arg: Int): Int = abs(arg) - override inline fun Int.unaryMinus(): Int = -this override inline fun Int.plus(arg: Int): Int = this + arg override inline fun Int.minus(arg: Int): Int = this - arg override inline fun Int.times(arg: Int): Int = this * arg @@ -209,10 +211,11 @@ public object ShortRing : Ring, Norm, NumericAlgebra override fun number(value: Number): Short = value.toShort() override inline fun add(left: Short, right: Short): Short = (left + right).toShort() - override inline fun multiply(left: Short, right: Short): Short = (left * right).toShort() - override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() + override inline fun negate(arg: Short): Short = (-arg).toShort() + + override inline fun multiply(left: Short, right: Short): Short = (left * right).toShort() + override inline fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() - override inline fun Short.unaryMinus(): Short = (-this).toShort() override inline fun Short.plus(arg: Short): Short = (this + arg).toShort() override inline fun Short.minus(arg: Short): Short = (this - arg).toShort() override inline fun Short.times(arg: Short): Short = (this * arg).toShort() @@ -233,10 +236,10 @@ public object ByteRing : Ring, Norm, NumericAlgebra { override fun number(value: Number): Byte = value.toByte() override inline fun add(left: Byte, right: Byte): Byte = (left + right).toByte() + override inline fun negate(arg: Byte): Byte = (-arg).toByte() override inline fun multiply(left: Byte, right: Byte): Byte = (left * right).toByte() - override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() + override inline fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() - override inline fun Byte.unaryMinus(): Byte = (-this).toByte() override inline fun Byte.plus(arg: Byte): Byte = (this + arg).toByte() override inline fun Byte.minus(arg: Byte): Byte = (this - arg).toByte() override inline fun Byte.times(arg: Byte): Byte = (this * arg).toByte() @@ -257,10 +260,11 @@ public object LongRing : Ring, Norm, NumericAlgebra { override fun number(value: Number): Long = value.toLong() override inline fun add(left: Long, right: Long): Long = left + right - override inline fun multiply(left: Long, right: Long): Long = left * right - override fun norm(arg: Long): Long = abs(arg) + override inline fun negate(arg: Long): Long = (-arg) + + override inline fun multiply(left: Long, right: Long): Long = left * right + override inline fun norm(arg: Long): Long = abs(arg) - override inline fun Long.unaryMinus(): Long = (-this) override inline fun Long.plus(arg: Long): Long = (this + arg) override inline fun Long.minus(arg: Long): Long = (this - arg) override inline fun Long.times(arg: Long): Long = (this * arg) diff --git a/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt b/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt index 6e22c2381..8a59adcef 100644 --- a/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt +++ b/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt @@ -19,10 +19,10 @@ public object JBigIntegerField : Ring, NumericAlgebra { override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong()) override fun add(left: BigInteger, right: BigInteger): BigInteger = left.add(right) + override fun negate(arg: BigInteger): BigInteger = arg.negate() + override operator fun BigInteger.minus(arg: BigInteger): BigInteger = subtract(arg) override fun multiply(left: BigInteger, right: BigInteger): BigInteger = left.multiply(right) - - override operator fun BigInteger.unaryMinus(): BigInteger = negate() } /** @@ -40,6 +40,7 @@ public abstract class JBigDecimalFieldBase internal constructor( get() = BigDecimal.ONE override fun add(left: BigDecimal, right: BigDecimal): BigDecimal = left.add(right) + override fun negate(arg: BigDecimal): BigDecimal = arg.negate(mathContext) override operator fun BigDecimal.minus(arg: BigDecimal): BigDecimal = subtract(arg) override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble()) @@ -50,7 +51,6 @@ public abstract class JBigDecimalFieldBase internal constructor( override fun divide(left: BigDecimal, right: BigDecimal): BigDecimal = left.divide(right, mathContext) override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext) override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext) - override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) } /** diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt index a36d36f52..f782496f3 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt @@ -71,7 +71,7 @@ public fun Polynomial.integrate( ): Polynomial where A : Field, A : NumericAlgebra = algebra { val integratedCoefficients = buildList(coefficients.size + 1) { add(zero) - coefficients.forEachIndexed{ index, t -> add(t / (number(index) + one)) } + coefficients.forEachIndexed { index, t -> add(t / (number(index) + one)) } } Polynomial(integratedCoefficients) } @@ -100,8 +100,8 @@ public class PolynomialSpace( ) : Group>, ScaleOperations> where C : Ring, C : ScaleOperations { override val zero: Polynomial = Polynomial(emptyList()) - override fun Polynomial.unaryMinus(): Polynomial = ring { - Polynomial(coefficients.map { -it }) + override fun negate(arg: Polynomial): Polynomial = ring { + Polynomial(arg.coefficients.map { -it }) } override fun add(left: Polynomial, right: Polynomial): Polynomial { diff --git a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt index d00575bcc..afe1d0ea5 100644 --- a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt +++ b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt @@ -44,7 +44,7 @@ public object Euclidean2DSpace : GeometrySpace, ScaleOperations, Norm, Norm> : TensorAlgebra } } - override fun StructureND.unaryMinus(): MultikTensor = - asMultik().array.unaryMinus().wrap() + override fun negate(arg: StructureND): MultikTensor = arg.asMultik().array.unaryMinus().wrap() override fun Tensor.get(i: Int): MultikTensor = asMultik().array.mutableView(i).wrap() diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt index 1f442c09b..ddf0ea622 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt @@ -51,5 +51,5 @@ public class SamplerSpace(public val algebra: S) : Group.unaryMinus(): Sampler = scale(this, -1.0) + override fun negate(arg: Sampler): Sampler = scale(arg, -1.0) } diff --git a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt index b40739ee0..18f4c7472 100644 --- a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt +++ b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt @@ -179,7 +179,7 @@ public abstract class TensorFlowAlgebra> internal c override fun Tensor.timesAssign(arg: StructureND): Unit = operateInPlace(arg, ops.math::mul) - override fun StructureND.unaryMinus(): TensorFlowOutput = operate(ops.math::neg) + override fun negate(arg: StructureND): TensorFlowOutput = arg.operate(ops.math::neg) override fun Tensor.get(i: Int): Tensor = operate { TODO("Not yet implemented") diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt index 86d4eaa4e..26c21c7a9 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt @@ -152,13 +152,6 @@ public interface TensorAlgebra> : RingOpsND { */ public operator fun Tensor.timesAssign(arg: StructureND) - /** - * Numerical negative, element-wise. - * - * @return tensor negation of the original tensor. - */ - override operator fun StructureND.unaryMinus(): Tensor - /** * Returns the tensor at index i * For more information: https://pytorch.org/cppdocs/notes/tensor_indexing.html @@ -328,5 +321,12 @@ public interface TensorAlgebra> : RingOpsND { override fun add(left: StructureND, right: StructureND): Tensor = left + right + /** + * Numerical negative, element-wise. + * + * @return tensor negation of the original tensor. + */ + override fun negate(arg: StructureND): Tensor + override fun multiply(left: StructureND, right: StructureND): Tensor = left * right }