First steps in applying context receivers to operator extension functions

This commit is contained in:
Iaroslav Postovalov 2022-04-04 18:43:20 +07:00
parent 5988b9ad30
commit 57dabba0a3
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
22 changed files with 81 additions and 85 deletions

View File

@ -4,4 +4,4 @@
#
kotlin.code.style=official
toolsVersion=0.11.2-kotlin-1.6.10
toolsVersion=0.11.2-kotlin-1.6.20

View File

@ -71,14 +71,13 @@ public object ComplexField :
*/
public val i: Complex by lazy { Complex(0.0, 1.0) }
override fun Complex.unaryMinus(): Complex = Complex(-re, -im)
override fun number(value: Number): Complex = Complex(value.toDouble(), 0.0)
override fun scale(a: Complex, value: Double): Complex = Complex(a.re * value, a.im * value)
override fun add(left: Complex, right: Complex): Complex = Complex(left.re + right.re, left.im + right.im)
// override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble())
override fun negate(arg: Complex): Complex = Complex(-arg.re, -arg.im)
override fun multiply(left: Complex, right: Complex): Complex =
Complex(left.re * right.re - left.im * right.im, left.re * right.im + left.im * right.re)

View File

@ -169,7 +169,7 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
override operator fun Number.times(arg: Quaternion): Quaternion =
Quaternion(toDouble() * arg.w, toDouble() * arg.x, toDouble() * arg.y, toDouble() * arg.z)
override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z)
override fun negate(arg: Quaternion): Quaternion = Quaternion(-arg.w, -arg.x, -arg.y, -arg.z)
override fun norm(arg: Quaternion): Quaternion = sqrt(arg.conjugate * arg)
override fun bindSymbolOrNull(value: String): Quaternion? = when (value) {

View File

@ -51,8 +51,8 @@ public open class FunctionalExpressionGroup<T, out A : Group<T>>(
) : FunctionalExpressionAlgebra<T, A>(algebra), Group<Expression<T>> {
override val zero: Expression<T> get() = const(algebra.zero)
override fun Expression<T>.unaryMinus(): Expression<T> =
unaryOperation(GroupOps.MINUS_OPERATION, this)
override fun negate(arg: Expression<T>): Expression<T> =
unaryOperation(GroupOps.MINUS_OPERATION, arg)
/**
* Builds an Expression of addition of two another expressions.

View File

@ -32,11 +32,9 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value)
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
override fun add(left: MST, right: MST): MST.Binary = binaryOperationFunction(GroupOps.PLUS_OPERATION)(left, right)
override operator fun MST.unaryPlus(): MST.Unary =
unaryOperationFunction(GroupOps.PLUS_OPERATION)(this)
override operator fun MST.unaryMinus(): MST.Unary =
unaryOperationFunction(GroupOps.MINUS_OPERATION)(this)
override fun negate(arg: MST): MST.Unary =
unaryOperationFunction(GroupOps.MINUS_OPERATION)(arg)
override operator fun MST.minus(arg: MST): MST.Binary =
binaryOperationFunction(GroupOps.MINUS_OPERATION)(this, arg)
@ -70,8 +68,7 @@ public object MstRing : Ring<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
override fun multiply(left: MST, right: MST): MST.Binary =
binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right)
override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus }
override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus }
override fun negate(arg: MST): MST.Unary = MstGroup.negate(arg)
override operator fun MST.minus(arg: MST): MST.Binary = MstGroup { this@minus - arg }
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
@ -101,8 +98,7 @@ public object MstField : Field<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
override fun divide(left: MST, right: MST): MST.Binary =
binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right)
override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus }
override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus }
override fun negate(arg: MST): MST.Unary = MstRing.negate(arg)
override operator fun MST.minus(arg: MST): MST.Binary = MstRing { this@minus - arg }
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
@ -142,8 +138,7 @@ public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
override fun multiply(left: MST, right: MST): MST.Binary = MstField.multiply(left, right)
override fun divide(left: MST, right: MST): MST.Binary = MstField.divide(left, right)
override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus }
override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus }
override fun negate(arg: MST): MST.Unary = MstField.negate(arg)
override operator fun MST.minus(arg: MST): MST.Binary = MstField { this@minus - arg }
override fun power(arg: MST, pow: Number): MST.Binary =

View File

@ -163,8 +163,8 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
// derive(const { this@minus.value - one * b.toDouble() }) { z -> d += z.d }
override fun AutoDiffValue<T>.unaryMinus(): AutoDiffValue<T> =
derive(const { -value }) { z -> d -= z.d }
override fun negate(arg: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { -arg.value }) { z -> arg.d -= z.d }
// Basic math (+, -, *, /)

View File

@ -10,10 +10,7 @@ import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.nd.StructureFeature
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.operations.BufferRingOps
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.DoubleBuffer

View File

@ -101,7 +101,7 @@ public open class BufferedGroupNDOps<T, out A : Group<T>>(
override val bufferAlgebra: BufferAlgebra<T, A>,
override val indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder,
) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
override fun negate(arg: StructureND<T>): StructureND<T> = arg.map { -it }
}
public open class BufferedRingOpsND<T, out A : Ring<T>>(

View File

@ -79,7 +79,7 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(D
override fun multiply(left: StructureND<Double>, right: StructureND<Double>): DoubleBufferND =
zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l * r }
override fun StructureND<Double>.unaryMinus(): DoubleBufferND = mapInline(toBufferND()) { -it }
override fun negate(arg:StructureND<Double>): DoubleBufferND = mapInline(arg.toBufferND()) { -it }
override fun StructureND<Double>.div(arg: StructureND<Double>): DoubleBufferND =
zipInline(toBufferND(), arg.toBufferND()) { l, r -> l / r }
@ -93,8 +93,6 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(D
override fun Double.div(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { this / it }
override fun StructureND<Double>.unaryPlus(): DoubleBufferND = toBufferND()
override fun StructureND<Double>.plus(arg: StructureND<Double>): DoubleBufferND =
zipInline(toBufferND(), arg.toBufferND()) { l: Double, r: Double -> l + r }

View File

@ -9,12 +9,6 @@ import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Ring.Companion.optimizedPower
/**
* Stub for DSL the [Algebra] is.
*/
@DslMarker
public annotation class KMathContext
/**
* Represents an algebraic structure.
*
@ -137,23 +131,13 @@ public interface GroupOps<T> : Algebra<T> {
*/
public fun add(left: T, right: T): T
// Operations to be performed in this context. Could be moved to extensions in case of KEEP-176.
/**
* The negation of this element.
*
* @receiver this value.
* @param arg the element.
* @return the additive inverse of this value.
*/
public operator fun T.unaryMinus(): T
/**
* Returns this value.
*
* @receiver this value.
* @return this value.
*/
public operator fun T.unaryPlus(): T = this
public fun negate(arg: T): T
/**
* Addition of two elements.
@ -173,10 +157,9 @@ public interface GroupOps<T> : Algebra<T> {
*/
public operator fun T.minus(arg: T): T = add(this, -arg)
// Dynamic dispatch of operations
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
PLUS_OPERATION -> { arg -> +arg }
MINUS_OPERATION -> { arg -> -arg }
MINUS_OPERATION -> ::negate
else -> super.unaryOperationFunction(operation)
}
@ -199,6 +182,24 @@ public interface GroupOps<T> : Algebra<T> {
}
}
/**
* The negation of this element.
*
* @receiver the element.
* @return the additive inverse of this value.
*/
context(GroupOps<T>)
public operator fun <T> T.unaryMinus(): T = negate(this)
/**
* Returns this value.
*
* @receiver this value.
* @return this value.
*/
context(GroupOps<T>)
public operator fun <T> T.unaryPlus(): T = this
/**
* Represents group i.e., algebraic structure with associative, binary operation [add].
*
@ -264,7 +265,7 @@ public interface Ring<T> : Group<T>, RingOps<T> {
*/
public fun power(arg: T, pow: UInt): T = optimizedPower(arg, pow)
public companion object{
public companion object {
/**
* Raises [arg] to the non-negative integer power [exponent].
*
@ -345,7 +346,7 @@ public interface Field<T> : Ring<T>, FieldOps<T>, ScaleOperations<T>, NumericAlg
public fun power(arg: T, pow: Int): T = optimizedPower(arg, pow)
public companion object{
public companion object {
/**
* Raises [arg] to the integer power [exponent].
*
@ -358,7 +359,10 @@ public interface Field<T> : Ring<T>, FieldOps<T>, ScaleOperations<T>, NumericAlg
* @author Iaroslav Postovalov, Evgeniy Zhelenskiy
*/
private fun <T> Field<T>.optimizedPower(arg: T, exponent: Int): T = when {
exponent < 0 -> one / (this as Ring<T>).optimizedPower(arg, if (exponent == Int.MIN_VALUE) Int.MAX_VALUE.toUInt().inc() else (-exponent).toUInt())
exponent < 0 -> one / (this as Ring<T>).optimizedPower(
arg,
if (exponent == Int.MIN_VALUE) Int.MAX_VALUE.toUInt().inc() else (-exponent).toUInt()
)
else -> (this as Ring<T>).optimizedPower(arg, exponent.toUInt())
}
}

View File

@ -33,7 +33,7 @@ public object BigIntField : Field<BigInt>, NumbersAddOps<BigInt>, ScaleOperation
override fun number(value: Number): BigInt = value.toLong().toBigInt()
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
override fun BigInt.unaryMinus(): BigInt = -this
override fun negate(arg: BigInt): BigInt = -arg
override fun add(left: BigInt, right: BigInt): BigInt = left.plus(right)
override fun scale(a: BigInt, value: Double): BigInt = a.times(number(value))
override fun multiply(left: BigInt, right: BigInt): BigInt = left.times(right)

View File

@ -137,7 +137,7 @@ public open class BufferRingOps<T, A: Ring<T>>(
override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r }
override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r }
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
override fun negate(arg: Buffer<T>): Buffer<T> = arg.map { negate(it) }
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> =
super<BufferAlgebra>.unaryOperationFunction(operation)
@ -159,7 +159,7 @@ public open class BufferFieldOps<T, A : Field<T>>(
override fun divide(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l / r }
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = a.map { scale(it, value) }
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
override fun negate(arg: Buffer<T>): Buffer<T> = arg.map { -it }
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> =
super<BufferRingOps>.binaryOperationFunction(operation)

View File

@ -32,7 +32,7 @@ public abstract class DoubleBufferOps : BufferAlgebra<Double, DoubleField>, Exte
override fun binaryOperationFunction(operation: String): (left: Buffer<Double>, right: Buffer<Double>) -> Buffer<Double> =
super<ExtendedFieldOps>.binaryOperationFunction(operation)
override fun Buffer<Double>.unaryMinus(): DoubleBuffer = mapInline { -it }
override fun negate(arg: Buffer<Double>): DoubleBuffer = arg.mapInline { -it }
override fun add(left: Buffer<Double>, right: Buffer<Double>): DoubleBuffer {
require(right.size == left.size) {

View File

@ -68,7 +68,7 @@ public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOp
override inline val zero: Double get() = 0.0
override inline val one: Double get() = 1.0
override inline fun number(value: Number): Double = value.toDouble()
override fun number(value: Number): Double = value.toDouble()
override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double =
when (operation) {
@ -77,6 +77,7 @@ public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOp
}
override inline fun add(left: Double, right: Double): Double = left + right
override inline fun negate(arg: Double): Double = -arg
override inline fun multiply(left: Double, right: Double): Double = left * right
override inline fun divide(left: Double, right: Double): Double = left / right
@ -109,7 +110,6 @@ public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOp
override inline fun norm(arg: Double): Double = abs(arg)
override inline fun Double.unaryMinus(): Double = -this
override inline fun Double.plus(arg: Double): Double = this + arg
override inline fun Double.minus(arg: Double): Double = this - arg
override inline fun Double.times(arg: Double): Double = this * arg
@ -135,7 +135,9 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
}
override inline fun add(left: Float, right: Float): Float = left + right
override fun scale(a: Float, value: Double): Float = a * value.toFloat()
override inline fun negate(arg: Float): Float = -arg
override inline fun scale(a: Float, value: Double): Float = a * value.toFloat()
override inline fun multiply(left: Float, right: Float): Float = left * right
@ -163,7 +165,6 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override inline fun norm(arg: Float): Float = abs(arg)
override inline fun Float.unaryMinus(): Float = -this
override inline fun Float.plus(arg: Float): Float = this + arg
override inline fun Float.minus(arg: Float): Float = this - arg
override inline fun Float.times(arg: Float): Float = this * arg
@ -185,10 +186,11 @@ public object IntRing : Ring<Int>, Norm<Int, Int>, NumericAlgebra<Int> {
override fun number(value: Number): Int = value.toInt()
override inline fun add(left: Int, right: Int): Int = left + right
override inline fun negate(arg: Int): Int = -arg
override inline fun multiply(left: Int, right: Int): Int = left * right
override inline fun norm(arg: Int): Int = abs(arg)
override inline fun Int.unaryMinus(): Int = -this
override inline fun Int.plus(arg: Int): Int = this + arg
override inline fun Int.minus(arg: Int): Int = this - arg
override inline fun Int.times(arg: Int): Int = this * arg
@ -209,10 +211,11 @@ public object ShortRing : Ring<Short>, Norm<Short, Short>, NumericAlgebra<Short>
override fun number(value: Number): Short = value.toShort()
override inline fun add(left: Short, right: Short): Short = (left + right).toShort()
override inline fun multiply(left: Short, right: Short): Short = (left * right).toShort()
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
override inline fun negate(arg: Short): Short = (-arg).toShort()
override inline fun multiply(left: Short, right: Short): Short = (left * right).toShort()
override inline fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
override inline fun Short.unaryMinus(): Short = (-this).toShort()
override inline fun Short.plus(arg: Short): Short = (this + arg).toShort()
override inline fun Short.minus(arg: Short): Short = (this - arg).toShort()
override inline fun Short.times(arg: Short): Short = (this * arg).toShort()
@ -233,10 +236,10 @@ public object ByteRing : Ring<Byte>, Norm<Byte, Byte>, NumericAlgebra<Byte> {
override fun number(value: Number): Byte = value.toByte()
override inline fun add(left: Byte, right: Byte): Byte = (left + right).toByte()
override inline fun negate(arg: Byte): Byte = (-arg).toByte()
override inline fun multiply(left: Byte, right: Byte): Byte = (left * right).toByte()
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
override inline fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
override inline fun Byte.plus(arg: Byte): Byte = (this + arg).toByte()
override inline fun Byte.minus(arg: Byte): Byte = (this - arg).toByte()
override inline fun Byte.times(arg: Byte): Byte = (this * arg).toByte()
@ -257,10 +260,11 @@ public object LongRing : Ring<Long>, Norm<Long, Long>, NumericAlgebra<Long> {
override fun number(value: Number): Long = value.toLong()
override inline fun add(left: Long, right: Long): Long = left + right
override inline fun multiply(left: Long, right: Long): Long = left * right
override fun norm(arg: Long): Long = abs(arg)
override inline fun negate(arg: Long): Long = (-arg)
override inline fun multiply(left: Long, right: Long): Long = left * right
override inline fun norm(arg: Long): Long = abs(arg)
override inline fun Long.unaryMinus(): Long = (-this)
override inline fun Long.plus(arg: Long): Long = (this + arg)
override inline fun Long.minus(arg: Long): Long = (this - arg)
override inline fun Long.times(arg: Long): Long = (this * arg)

View File

@ -19,10 +19,10 @@ public object JBigIntegerField : Ring<BigInteger>, NumericAlgebra<BigInteger> {
override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong())
override fun add(left: BigInteger, right: BigInteger): BigInteger = left.add(right)
override fun negate(arg: BigInteger): BigInteger = arg.negate()
override operator fun BigInteger.minus(arg: BigInteger): BigInteger = subtract(arg)
override fun multiply(left: BigInteger, right: BigInteger): BigInteger = left.multiply(right)
override operator fun BigInteger.unaryMinus(): BigInteger = negate()
}
/**
@ -40,6 +40,7 @@ public abstract class JBigDecimalFieldBase internal constructor(
get() = BigDecimal.ONE
override fun add(left: BigDecimal, right: BigDecimal): BigDecimal = left.add(right)
override fun negate(arg: BigDecimal): BigDecimal = arg.negate(mathContext)
override operator fun BigDecimal.minus(arg: BigDecimal): BigDecimal = subtract(arg)
override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble())
@ -50,7 +51,6 @@ public abstract class JBigDecimalFieldBase internal constructor(
override fun divide(left: BigDecimal, right: BigDecimal): BigDecimal = left.divide(right, mathContext)
override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext)
override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext)
override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext)
}
/**

View File

@ -71,7 +71,7 @@ public fun <T, A> Polynomial<T>.integrate(
): Polynomial<T> where A : Field<T>, A : NumericAlgebra<T> = algebra {
val integratedCoefficients = buildList(coefficients.size + 1) {
add(zero)
coefficients.forEachIndexed{ index, t -> add(t / (number(index) + one)) }
coefficients.forEachIndexed { index, t -> add(t / (number(index) + one)) }
}
Polynomial(integratedCoefficients)
}
@ -100,8 +100,8 @@ public class PolynomialSpace<T, C>(
) : Group<Polynomial<T>>, ScaleOperations<Polynomial<T>> where C : Ring<T>, C : ScaleOperations<T> {
override val zero: Polynomial<T> = Polynomial(emptyList())
override fun Polynomial<T>.unaryMinus(): Polynomial<T> = ring {
Polynomial(coefficients.map { -it })
override fun negate(arg: Polynomial<T>): Polynomial<T> = ring {
Polynomial(arg.coefficients.map { -it })
}
override fun add(left: Polynomial<T>, right: Polynomial<T>): Polynomial<T> {

View File

@ -44,7 +44,7 @@ public object Euclidean2DSpace : GeometrySpace<Vector2D>, ScaleOperations<Vector
override val zero: Vector2D by lazy { Vector2D(0.0, 0.0) }
public fun Vector2D.norm(): Double = sqrt(x * x + y * y)
override fun Vector2D.unaryMinus(): Vector2D = Vector2D(-x, -y)
override fun negate(arg: Vector2D): Vector2D = Vector2D(-arg.x, -arg.y)
override fun Vector2D.distanceTo(other: Vector2D): Double = (this - other).norm()
override fun add(left: Vector2D, right: Vector2D): Vector2D = Vector2D(left.x + right.x, left.y + right.y)

View File

@ -80,6 +80,7 @@ public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Doub
}
override inline fun add(left: Double, right: Double): Double = left + right
override inline fun negate(arg: Double): Double = -arg
override inline fun multiply(left: Double, right: Double): Double = left * right
override inline fun divide(left: Double, right: Double): Double = left / right
@ -107,7 +108,6 @@ public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Doub
override inline fun norm(arg: Double): Double = StrictFastMath.abs(arg)
override inline fun Double.unaryMinus(): Double = -this
override inline fun Double.plus(arg: Double): Double = this + arg
override inline fun Double.minus(arg: Double): Double = this - arg
override inline fun Double.times(arg: Double): Double = this * arg

View File

@ -212,8 +212,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
}
}
override fun StructureND<T>.unaryMinus(): MultikTensor<T> =
asMultik().array.unaryMinus().wrap()
override fun negate(arg: StructureND<T>): MultikTensor<T> = arg.asMultik().array.unaryMinus().wrap()
override fun Tensor<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap()

View File

@ -51,5 +51,5 @@ public class SamplerSpace<T : Any, out S>(public val algebra: S) : Group<Sampler
}
}
override fun Sampler<T>.unaryMinus(): Sampler<T> = scale(this, -1.0)
override fun negate(arg: Sampler<T>): Sampler<T> = scale(arg, -1.0)
}

View File

@ -179,7 +179,7 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
override fun Tensor<T>.timesAssign(arg: StructureND<T>): Unit = operateInPlace(arg, ops.math::mul)
override fun StructureND<T>.unaryMinus(): TensorFlowOutput<T, TT> = operate(ops.math::neg)
override fun negate(arg: StructureND<T>): TensorFlowOutput<T, TT> = arg.operate(ops.math::neg)
override fun Tensor<T>.get(i: Int): Tensor<T> = operate {
TODO("Not yet implemented")

View File

@ -152,13 +152,6 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
*/
public operator fun Tensor<T>.timesAssign(arg: StructureND<T>)
/**
* Numerical negative, element-wise.
*
* @return tensor negation of the original tensor.
*/
override operator fun StructureND<T>.unaryMinus(): Tensor<T>
/**
* Returns the tensor at index i
* For more information: https://pytorch.org/cppdocs/notes/tensor_indexing.html
@ -328,5 +321,12 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
override fun add(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left + right
/**
* Numerical negative, element-wise.
*
* @return tensor negation of the original tensor.
*/
override fun negate(arg: StructureND<T>): Tensor<T>
override fun multiply(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left * right
}