diff --git a/CHANGELOG.md b/CHANGELOG.md index 6733c1211..857ed060b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ - `BigInt` operation performance improvement and fixes by @zhelenskiy (#328) - Integration between `MST` and Symja `IExpr` - Complex power +- Separate methods for UInt, Int and Number powers. NaN safety. ### Changed - Exponential operations merged with hyperbolic functions diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt index 9d5b1cddd..2eaa17ded 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt @@ -18,7 +18,7 @@ import kotlin.contracts.contract */ @OptIn(UnstableKMathAPI::class) public sealed class ComplexFieldOpsND : BufferedFieldOpsND(ComplexField.bufferAlgebra), - ScaleOperations>, ExtendedFieldOps> { + ScaleOperations>, ExtendedFieldOps>, PowerOperations> { override fun StructureND.toBufferND(): BufferND = when (this) { is BufferND -> this @@ -33,9 +33,6 @@ public sealed class ComplexFieldOpsND : BufferedFieldOpsND, value: Double): BufferND = mapInline(a.toBufferND()) { it * value } - override fun power(arg: StructureND, pow: Number): BufferND = - mapInline(arg.toBufferND()) { power(it, pow) } - override fun exp(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { exp(it) } override fun ln(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { ln(it) } @@ -53,6 +50,9 @@ public sealed class ComplexFieldOpsND : BufferedFieldOpsND): BufferND = mapInline(arg.toBufferND()) { acosh(it) } override fun atanh(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { atanh(it) } + override fun power(arg: StructureND, pow: Number): StructureND = + mapInline(arg.toBufferND()) { power(it,pow) } + public companion object : ComplexFieldOpsND() } @@ -63,7 +63,8 @@ public val ComplexField.bufferAlgebra: BufferFieldOps @OptIn(UnstableKMathAPI::class) public class ComplexFieldND(override val shape: Shape) : - ComplexFieldOpsND(), FieldND, NumbersAddOps> { + ComplexFieldOpsND(), FieldND, + NumbersAddOps> { override fun number(value: Number): BufferND { val d = value.toDouble() // minimize conversions diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt index 704c4edd8..aa9dd01ce 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -252,7 +252,7 @@ public class SimpleAutoDiffExpression>( * Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression] */ public fun > simpleAutoDiff( - field: F + field: F, ): AutoDiffProcessor, SimpleAutoDiffField> = AutoDiffProcessor { function -> SimpleAutoDiffExpression(field, function) @@ -272,8 +272,8 @@ public fun > SimpleAutoDiffField.sqrt(x: Aut public fun > SimpleAutoDiffField.pow( x: AutoDiffValue, y: Double, -): AutoDiffValue = derive(const { power(x.value, y) }) { z -> - x.d += z.d * y * power(x.value, y - 1) +): AutoDiffValue = derive(const { x.value.pow(y)}) { z -> + x.d += z.d * y * x.value.pow(y - 1) } public fun > SimpleAutoDiffField.pow( diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt index 0e094a8c7..d25b455f4 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt @@ -53,21 +53,28 @@ public interface BufferAlgebraND> : AlgebraND { public inline fun > BufferAlgebraND.mapInline( arg: BufferND, - crossinline transform: A.(T) -> T + crossinline transform: A.(T) -> T, ): BufferND { val indexes = arg.indices - return BufferND(indexes, bufferAlgebra.mapInline(arg.buffer, transform)) + val buffer = arg.buffer + return BufferND( + indexes, + bufferAlgebra.run { + bufferFactory(buffer.size) { elementAlgebra.transform(buffer[it]) } + } + ) } internal inline fun > BufferAlgebraND.mapIndexedInline( arg: BufferND, - crossinline transform: A.(index: IntArray, arg: T) -> T + crossinline transform: A.(index: IntArray, arg: T) -> T, ): BufferND { val indexes = arg.indices + val buffer = arg.buffer return BufferND( indexes, - bufferAlgebra.mapIndexedInline(arg.buffer) { offset, value -> - transform(indexes.index(offset), value) + bufferAlgebra.run { + bufferFactory(buffer.size) { elementAlgebra.transform(indexes.index(it), buffer[it]) } } ) } @@ -75,35 +82,42 @@ internal inline fun > BufferAlgebraND.mapIndexedInline( internal inline fun > BufferAlgebraND.zipInline( l: BufferND, r: BufferND, - crossinline block: A.(l: T, r: T) -> T + crossinline block: A.(l: T, r: T) -> T, ): BufferND { require(l.indices == r.indices) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" } val indexes = l.indices - return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block)) + val lbuffer = l.buffer + val rbuffer = r.buffer + return BufferND( + indexes, + bufferAlgebra.run { + bufferFactory(lbuffer.size) { elementAlgebra.block(lbuffer[it], rbuffer[it]) } + } + ) } @OptIn(PerformancePitfall::class) public open class BufferedGroupNDOps>( override val bufferAlgebra: BufferAlgebra, - override val indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder + override val indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder, ) : GroupOpsND, BufferAlgebraND { override fun StructureND.unaryMinus(): StructureND = map { -it } } public open class BufferedRingOpsND>( bufferAlgebra: BufferAlgebra, - indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder + indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder, ) : BufferedGroupNDOps(bufferAlgebra, indexerBuilder), RingOpsND public open class BufferedFieldOpsND>( bufferAlgebra: BufferAlgebra, - indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder + indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder, ) : BufferedRingOpsND(bufferAlgebra, indexerBuilder), FieldOpsND { public constructor( elementAlgebra: A, bufferFactory: BufferFactory, - indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder + indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder, ) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder) @OptIn(PerformancePitfall::class) @@ -117,11 +131,11 @@ public val > BufferAlgebra.nd: BufferedFieldOpsND ge public fun > BufferAlgebraND.structureND( vararg shape: Int, - initializer: A.(IntArray) -> T + initializer: A.(IntArray) -> T, ): BufferND = structureND(shape, initializer) public fun , A> A.structureND( - initializer: EA.(IntArray) -> T + initializer: EA.(IntArray) -> T, ): BufferND where A : BufferAlgebraND, A : WithShape = structureND(shape, initializer) //// group factories diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt index 7285fdb24..8baeac21f 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt @@ -11,7 +11,7 @@ import space.kscience.kmath.operations.* import space.kscience.kmath.structures.DoubleBuffer import kotlin.contracts.InvocationKind import kotlin.contracts.contract -import kotlin.math.pow +import kotlin.math.pow as kpow public class DoubleBufferND( indexes: ShapeIndexer, @@ -30,9 +30,9 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND(D } } - private inline fun mapInline( + protected inline fun mapInline( arg: DoubleBufferND, - transform: (Double) -> Double + transform: (Double) -> Double, ): DoubleBufferND { val indexes = arg.indices val array = arg.buffer.array @@ -42,7 +42,7 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND(D private inline fun zipInline( l: DoubleBufferND, r: DoubleBufferND, - block: (l: Double, r: Double) -> Double + block: (l: Double, r: Double) -> Double, ): DoubleBufferND { require(l.indices == r.indices) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" } val indexes = l.indices @@ -60,7 +60,7 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND(D override fun zip( left: StructureND, right: StructureND, - transform: DoubleField.(Double, Double) -> Double + transform: DoubleField.(Double, Double) -> Double, ): BufferND = zipInline(left.toBufferND(), right.toBufferND()) { l, r -> DoubleField.transform(l, r) } override fun structureND(shape: Shape, initializer: DoubleField.(IntArray) -> Double): DoubleBufferND { @@ -123,9 +123,6 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND(D override fun scale(a: StructureND, value: Double): DoubleBufferND = mapInline(a.toBufferND()) { it * value } - override fun power(arg: StructureND, pow: Number): DoubleBufferND = - mapInline(arg.toBufferND()) { it.pow(pow.toDouble()) } - override fun exp(arg: StructureND): DoubleBufferND = mapInline(arg.toBufferND()) { kotlin.math.exp(it) } @@ -173,7 +170,38 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND(D @OptIn(UnstableKMathAPI::class) public class DoubleFieldND(override val shape: Shape) : - DoubleFieldOpsND(), FieldND, NumbersAddOps> { + DoubleFieldOpsND(), FieldND, NumbersAddOps>, + ExtendedField> { + + override fun power(arg: StructureND, pow: UInt): DoubleBufferND = mapInline(arg.toBufferND()) { + it.kpow(pow.toInt()) + } + + override fun power(arg: StructureND, pow: Int): DoubleBufferND = mapInline(arg.toBufferND()) { + it.kpow(pow) + } + + override fun power(arg: StructureND, pow: Number): DoubleBufferND = if(pow.isInteger()){ + power(arg, pow.toInt()) + } else { + val dpow = pow.toDouble() + mapInline(arg.toBufferND()) { + if (it < 0) throw IllegalArgumentException("Can't raise negative $it to a fractional power") + else it.kpow(dpow) + } + } + + override fun sinh(arg: StructureND): DoubleBufferND = super.sinh(arg) + + override fun cosh(arg: StructureND): DoubleBufferND = super.cosh(arg) + + override fun tanh(arg: StructureND): DoubleBufferND = super.tan(arg) + + override fun asinh(arg: StructureND): DoubleBufferND = super.asinh(arg) + + override fun acosh(arg: StructureND): DoubleBufferND = super.acosh(arg) + + override fun atanh(arg: StructureND): DoubleBufferND = super.atanh(arg) override fun number(value: Number): DoubleBufferND { val d = value.toDouble() // minimize conversions diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt index 992c0e015..244b9fea7 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt @@ -7,6 +7,7 @@ package space.kscience.kmath.operations import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.Ring.Companion.optimizedPower /** * Stub for DSL the [Algebra] is. @@ -257,6 +258,40 @@ public interface Ring : Group, RingOps { * The neutral element of multiplication */ public val one: T + + /** + * Raises [arg] to the integer power [pow]. + */ + public fun power(arg: T, pow: UInt): T = optimizedPower(arg, pow) + + public companion object{ + /** + * Raises [arg] to the non-negative integer power [exponent]. + * + * Special case: 0 ^ 0 is 1. + * + * @receiver the algebra to provide multiplication. + * @param arg the base. + * @param exponent the exponent. + * @return the base raised to the power. + * @author Evgeniy Zhelenskiy + */ + internal fun Ring.optimizedPower(arg: T, exponent: UInt): T = when { + arg == zero && exponent > 0U -> zero + arg == one -> arg + arg == -one -> powWithoutOptimization(arg, exponent % 2U) + else -> powWithoutOptimization(arg, exponent) + } + + private fun Ring.powWithoutOptimization(base: T, exponent: UInt): T = when (exponent) { + 0U -> one + 1U -> base + else -> { + val pre = powWithoutOptimization(base, exponent shr 1).let { it * it } + if (exponent and 1U == 0U) pre else pre * base + } + } + } } /** @@ -307,4 +342,24 @@ public interface FieldOps : RingOps { */ public interface Field : Ring, FieldOps, ScaleOperations, NumericAlgebra { override fun number(value: Number): T = scale(one, value.toDouble()) + + public fun power(arg: T, pow: Int): T = optimizedPower(arg, pow) + + public companion object{ + /** + * Raises [arg] to the integer power [exponent]. + * + * Special case: 0 ^ 0 is 1. + * + * @receiver the algebra to provide multiplication and division. + * @param arg the base. + * @param exponent the exponent. + * @return the base raised to the power. + * @author Iaroslav Postovalov, Evgeniy Zhelenskiy + */ + private fun Field.optimizedPower(arg: T, exponent: Int): T = when { + exponent < 0 -> one / (this as Ring).optimizedPower(arg, if (exponent == Int.MIN_VALUE) Int.MAX_VALUE.toUInt().inc() else (-exponent).toUInt()) + else -> (this as Ring).optimizedPower(arg, exponent.toUInt()) + } + } } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt index bc05f3904..634a115c7 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.operations +import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.DoubleBuffer @@ -34,11 +35,13 @@ public interface BufferAlgebra> : Algebra> { public fun Buffer.zip(other: Buffer, block: A.(left: T, right: T) -> T): Buffer = zipInline(this, other, block) + @UnstableKMathAPI override fun unaryOperationFunction(operation: String): (arg: Buffer) -> Buffer { val operationFunction = elementAlgebra.unaryOperationFunction(operation) return { arg -> bufferFactory(arg.size) { operationFunction(arg[it]) } } } + @UnstableKMathAPI override fun binaryOperationFunction(operation: String): (left: Buffer, right: Buffer) -> Buffer { val operationFunction = elementAlgebra.binaryOperationFunction(operation) return { left, right -> @@ -50,7 +53,7 @@ public interface BufferAlgebra> : Algebra> { /** * Inline map */ -public inline fun > BufferAlgebra.mapInline( +private inline fun > BufferAlgebra.mapInline( buffer: Buffer, crossinline block: A.(T) -> T ): Buffer = bufferFactory(buffer.size) { elementAlgebra.block(buffer[it]) } @@ -58,7 +61,7 @@ public inline fun > BufferAlgebra.mapInline( /** * Inline map */ -public inline fun > BufferAlgebra.mapIndexedInline( +private inline fun > BufferAlgebra.mapIndexedInline( buffer: Buffer, crossinline block: A.(index: Int, arg: T) -> T ): Buffer = bufferFactory(buffer.size) { elementAlgebra.block(it, buffer[it]) } @@ -66,7 +69,7 @@ public inline fun > BufferAlgebra.mapIndexedInline( /** * Inline zip */ -public inline fun > BufferAlgebra.zipInline( +private inline fun > BufferAlgebra.zipInline( l: Buffer, r: Buffer, crossinline block: A.(l: T, r: T) -> T @@ -126,7 +129,7 @@ public fun > BufferAlgebra.atanh(arg: Buff mapInline(arg) { atanh(it) } public fun > BufferAlgebra.pow(arg: Buffer, pow: Number): Buffer = - mapInline(arg) { power(it, pow) } + mapInline(arg) {it.pow(pow) } public open class BufferRingOps>( @@ -138,9 +141,11 @@ public open class BufferRingOps>( override fun multiply(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l * r } override fun Buffer.unaryMinus(): Buffer = map { -it } + @UnstableKMathAPI override fun unaryOperationFunction(operation: String): (arg: Buffer) -> Buffer = super.unaryOperationFunction(operation) + @UnstableKMathAPI override fun binaryOperationFunction(operation: String): (left: Buffer, right: Buffer) -> Buffer = super.binaryOperationFunction(operation) } @@ -160,6 +165,7 @@ public open class BufferFieldOps>( override fun scale(a: Buffer, value: Double): Buffer = a.map { scale(it, value) } override fun Buffer.unaryMinus(): Buffer = map { -it } + @UnstableKMathAPI override fun binaryOperationFunction(operation: String): (left: Buffer, right: Buffer) -> Buffer = super.binaryOperationFunction(operation) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferField.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferField.kt index 060ea5a7e..0deb647a3 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferField.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.operations +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.DoubleField.pow import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.DoubleBuffer @@ -27,7 +29,20 @@ public class DoubleBufferField(public val size: Int) : ExtendedField): DoubleBuffer = super.acosh(arg) - override fun atanh(arg: Buffer): DoubleBuffer= super.atanh(arg) + override fun atanh(arg: Buffer): DoubleBuffer = super.atanh(arg) + + override fun power(arg: Buffer, pow: Number): DoubleBuffer = if (pow.isInteger()) { + arg.mapInline { it.pow(pow.toInt()) } + } else { + arg.mapInline { + if(it<0) throw IllegalArgumentException("Negative argument $it could not be raised to the fractional power") + it.pow(pow.toDouble()) + } + } + + @UnstableKMathAPI + override fun unaryOperationFunction(operation: String): (arg: Buffer) -> Buffer = + super.unaryOperationFunction(operation) // override fun number(value: Number): Buffer = DoubleBuffer(size) { value.toDouble() } // diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt index 3d51b3d32..28238c466 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt @@ -6,21 +6,32 @@ package space.kscience.kmath.operations import space.kscience.kmath.linear.Point +import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.DoubleBuffer +import space.kscience.kmath.structures.asBuffer import kotlin.math.* /** * [ExtendedFieldOps] over [DoubleBuffer]. */ -public abstract class DoubleBufferOps : ExtendedFieldOps>, Norm, Double> { +public abstract class DoubleBufferOps : + BufferAlgebra, ExtendedFieldOps>, Norm, Double> { - override fun Buffer.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) { - DoubleBuffer(size) { -array[it] } - } else { - DoubleBuffer(size) { -get(it) } - } + override val elementAlgebra: DoubleField get() = DoubleField + override val bufferFactory: BufferFactory get() = ::DoubleBuffer + + @UnstableKMathAPI + override fun unaryOperationFunction(operation: String): (arg: Buffer) -> Buffer = + super.unaryOperationFunction(operation) + + @UnstableKMathAPI + override fun binaryOperationFunction(operation: String): (left: Buffer, right: Buffer) -> Buffer = + super.binaryOperationFunction(operation) + + override fun Buffer.unaryMinus(): DoubleBuffer = mapInline { -it } override fun add(left: Buffer, right: Buffer): DoubleBuffer { require(right.size == left.size) { @@ -92,101 +103,46 @@ public abstract class DoubleBufferOps : ExtendedFieldOps>, Norm): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) - } else DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) }) + override fun sin(arg: Buffer): DoubleBuffer = arg.mapInline(::sin) - override fun cos(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) - } else DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) }) + override fun cos(arg: Buffer): DoubleBuffer = arg.mapInline(::cos) - override fun tan(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { tan(array[it]) }) - } else DoubleBuffer(DoubleArray(arg.size) { tan(arg[it]) }) + override fun tan(arg: Buffer): DoubleBuffer = arg.mapInline(::tan) - override fun asin(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { asin(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { asin(arg[it]) }) + override fun asin(arg: Buffer): DoubleBuffer = arg.mapInline(::asin) - override fun acos(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { acos(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { acos(arg[it]) }) + override fun acos(arg: Buffer): DoubleBuffer = arg.mapInline(::acos) - override fun atan(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { atan(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { atan(arg[it]) }) + override fun atan(arg: Buffer): DoubleBuffer = arg.mapInline(::atan) - override fun sinh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { sinh(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { sinh(arg[it]) }) + override fun sinh(arg: Buffer): DoubleBuffer = arg.mapInline(::sinh) - override fun cosh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { cosh(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { cosh(arg[it]) }) + override fun cosh(arg: Buffer): DoubleBuffer = arg.mapInline(::cosh) - override fun tanh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { tanh(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { tanh(arg[it]) }) + override fun tanh(arg: Buffer): DoubleBuffer = arg.mapInline(::tanh) - override fun asinh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { asinh(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { asinh(arg[it]) }) + override fun asinh(arg: Buffer): DoubleBuffer = arg.mapInline(::asinh) - override fun acosh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { acosh(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { acosh(arg[it]) }) + override fun acosh(arg: Buffer): DoubleBuffer = arg.mapInline(::acosh) - override fun atanh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { atanh(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { atanh(arg[it]) }) + override fun atanh(arg: Buffer): DoubleBuffer = arg.mapInline(::atanh) - override fun power(arg: Buffer, pow: Number): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) - } else - DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) + override fun exp(arg: Buffer): DoubleBuffer = arg.mapInline(::exp) - override fun exp(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) - } else DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) }) - - override fun ln(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) }) - } + override fun ln(arg: Buffer): DoubleBuffer = arg.mapInline(::ln) override fun norm(arg: Buffer): Double = DoubleL2Norm.norm(arg) - override fun scale(a: Buffer, value: Double): DoubleBuffer = if (a is DoubleBuffer) { - val aArray = a.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * value }) - } else DoubleBuffer(DoubleArray(a.size) { a[it] * value }) + override fun scale(a: Buffer, value: Double): DoubleBuffer = a.mapInline { it * value } - public companion object : DoubleBufferOps() + public companion object : DoubleBufferOps() { + public inline fun Buffer.mapInline(block: (Double) -> Double): DoubleBuffer = + if (this is DoubleBuffer) { + DoubleArray(size) { block(array[it]) }.asBuffer() + } else { + DoubleArray(size) { block(get(it)) }.asBuffer() + } + } } public object DoubleL2Norm : Norm, Double> { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/OptionalOperations.kt index d32e03533..332617158 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/OptionalOperations.kt @@ -74,14 +74,21 @@ public interface TrigonometricOperations : Algebra { } } +/** + * Check if number is an integer from platform point of view + */ +public expect fun Number.isInteger(): Boolean + /** * A context extension to include power operations based on exponentiation. * * @param T the type of element of this structure. */ -public interface PowerOperations : Algebra { +public interface PowerOperations : FieldOps { + /** - * Raises [arg] to the power [pow]. + * Raises [arg] to a power if possible (negative number could not be raised to a fractional power). + * Throws [IllegalArgumentException] if not possible. */ public fun power(arg: T, pow: Number): T @@ -108,6 +115,7 @@ public interface PowerOperations : Algebra { } } + /** * A container for operations related to `exp` and `ln` functions. * diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/algebraExtensions.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/algebraExtensions.kt index b26ebb2ea..493d90d2f 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/algebraExtensions.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/algebraExtensions.kt @@ -96,46 +96,3 @@ public fun Iterable.averageWith(space: S): T where S : Ring, S : Sc public fun Sequence.averageWith(space: S): T where S : Ring, S : ScaleOperations = space.average(this) -/** - * Raises [arg] to the non-negative integer power [exponent]. - * - * Special case: 0 ^ 0 is 1. - * - * @receiver the algebra to provide multiplication. - * @param arg the base. - * @param exponent the exponent. - * @return the base raised to the power. - * @author Evgeniy Zhelenskiy - */ -public fun Ring.power(arg: T, exponent: UInt): T = when { - arg == zero && exponent > 0U -> zero - arg == one -> arg - arg == -one -> powWithoutOptimization(arg, exponent % 2U) - else -> powWithoutOptimization(arg, exponent) -} - -private fun Ring.powWithoutOptimization(base: T, exponent: UInt): T = when (exponent) { - 0U -> one - 1U -> base - else -> { - val pre = powWithoutOptimization(base, exponent shr 1).let { it * it } - if (exponent and 1U == 0U) pre else pre * base - } -} - - -/** - * Raises [arg] to the integer power [exponent]. - * - * Special case: 0 ^ 0 is 1. - * - * @receiver the algebra to provide multiplication and division. - * @param arg the base. - * @param exponent the exponent. - * @return the base raised to the power. - * @author Iaroslav Postovalov, Evgeniy Zhelenskiy - */ -public fun Field.power(arg: T, exponent: Int): T = when { - exponent < 0 -> one / (this as Ring).power(arg, if (exponent == Int.MIN_VALUE) Int.MAX_VALUE.toUInt().inc() else (-exponent).toUInt()) - else -> (this as Ring).power(arg, exponent.toUInt()) -} diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt index ceb85f3ab..7c8030168 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt @@ -13,9 +13,8 @@ import kotlin.math.pow as kpow public interface ExtendedFieldOps : FieldOps, TrigonometricOperations, - PowerOperations, ExponentialOperations, - ScaleOperations { + ScaleOperations { override fun tan(arg: T): T = sin(arg) / cos(arg) override fun tanh(arg: T): T = sinh(arg) / cosh(arg) @@ -26,7 +25,6 @@ public interface ExtendedFieldOps : TrigonometricOperations.ACOS_OPERATION -> ::acos TrigonometricOperations.ASIN_OPERATION -> ::asin TrigonometricOperations.ATAN_OPERATION -> ::atan - PowerOperations.SQRT_OPERATION -> ::sqrt ExponentialOperations.EXP_OPERATION -> ::exp ExponentialOperations.LN_OPERATION -> ::ln ExponentialOperations.COSH_OPERATION -> ::cosh @@ -42,7 +40,7 @@ public interface ExtendedFieldOps : /** * Advanced Number-like field that implements basic operations. */ -public interface ExtendedField : ExtendedFieldOps, Field, NumericAlgebra{ +public interface ExtendedField : ExtendedFieldOps, Field, PowerOperations, NumericAlgebra { override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0 override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0 override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) @@ -50,6 +48,11 @@ public interface ExtendedField : ExtendedFieldOps, Field, NumericAlgebr override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one))) override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2.0 + override fun unaryOperationFunction(operation: String): (arg: T) -> T { + return if (operation == PowerOperations.SQRT_OPERATION) ::sqrt + else super.unaryOperationFunction(operation) + } + override fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T = when (operation) { PowerOperations.POW_OPERATION -> ::power @@ -69,7 +72,7 @@ public object DoubleField : ExtendedField, Norm, ScaleOp override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double = when (operation) { - PowerOperations.POW_OPERATION -> ::power + PowerOperations.POW_OPERATION -> { l, r -> l.kpow(r) } else -> super.binaryOperationFunction(operation) } @@ -94,8 +97,13 @@ public object DoubleField : ExtendedField, Norm, ScaleOp override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg) override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg) - override inline fun sqrt(arg: Double): Double = kotlin.math.sqrt(arg) - override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble()) + override fun sqrt(arg: Double): Double = kotlin.math.sqrt(arg) + override fun power(arg: Double, pow: Number): Double = when { + pow.isInteger() -> arg.kpow(pow.toInt()) + arg < 0 -> throw IllegalArgumentException("Can't raise negative $arg to a fractional power $pow") + else -> arg.kpow(pow.toDouble()) + } + override inline fun exp(arg: Double): Double = kotlin.math.exp(arg) override inline fun ln(arg: Double): Double = kotlin.math.ln(arg) @@ -122,7 +130,7 @@ public object FloatField : ExtendedField, Norm { override fun binaryOperationFunction(operation: String): (left: Float, right: Float) -> Float = when (operation) { - PowerOperations.POW_OPERATION -> ::power + PowerOperations.POW_OPERATION -> { l, r -> l.kpow(r) } else -> super.binaryOperationFunction(operation) } @@ -149,6 +157,7 @@ public object FloatField : ExtendedField, Norm { override inline fun sqrt(arg: Float): Float = kotlin.math.sqrt(arg) override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat()) + override inline fun exp(arg: Float): Float = kotlin.math.exp(arg) override inline fun ln(arg: Float): Float = kotlin.math.ln(arg) diff --git a/kmath-core/src/jsMain/kotlin/space/kscience/kmath/operations/isInteger.kt b/kmath-core/src/jsMain/kotlin/space/kscience/kmath/operations/isInteger.kt new file mode 100644 index 000000000..c15669145 --- /dev/null +++ b/kmath-core/src/jsMain/kotlin/space/kscience/kmath/operations/isInteger.kt @@ -0,0 +1,6 @@ +package space.kscience.kmath.operations + +/** + * Check if number is an integer + */ +public actual fun Number.isInteger(): Boolean = js("Number").isInteger(this) as Boolean \ No newline at end of file diff --git a/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/isInteger.kt b/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/isInteger.kt new file mode 100644 index 000000000..746d1e530 --- /dev/null +++ b/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/isInteger.kt @@ -0,0 +1,6 @@ +package space.kscience.kmath.operations + +/** + * Check if number is an integer + */ +public actual fun Number.isInteger(): Boolean = (this is Int) || (this is Long) || (this is Short) || (this.toDouble() % 1 == 0.0) \ No newline at end of file diff --git a/kmath-core/src/nativeMain/kotlin/space/kscience/kmath/operations/isInteger.kt b/kmath-core/src/nativeMain/kotlin/space/kscience/kmath/operations/isInteger.kt new file mode 100644 index 000000000..746d1e530 --- /dev/null +++ b/kmath-core/src/nativeMain/kotlin/space/kscience/kmath/operations/isInteger.kt @@ -0,0 +1,6 @@ +package space.kscience.kmath.operations + +/** + * Check if number is an integer + */ +public actual fun Number.isInteger(): Boolean = (this is Int) || (this is Long) || (this is Short) || (this.toDouble() % 1 == 0.0) \ No newline at end of file diff --git a/kmath-kotlingrad/src/test/kotlin/space/kscience/kmath/kotlingrad/AdaptingTests.kt b/kmath-kotlingrad/src/test/kotlin/space/kscience/kmath/kotlingrad/AdaptingTests.kt index 67332a680..4047b9a67 100644 --- a/kmath-kotlingrad/src/test/kotlin/space/kscience/kmath/kotlingrad/AdaptingTests.kt +++ b/kmath-kotlingrad/src/test/kotlin/space/kscience/kmath/kotlingrad/AdaptingTests.kt @@ -62,6 +62,6 @@ internal class AdaptingTests { .parseMath() .compileToExpression(DoubleField) - assertEquals(actualDerivative(x to 0.1), expectedDerivative(x to 0.1)) + assertEquals(actualDerivative(x to -0.1), expectedDerivative(x to -0.1)) } } diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt index b1cc1f834..dd27bc817 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt @@ -155,7 +155,7 @@ public sealed interface Nd4jArrayField> : FieldOpsND, * Represents intersection of [ExtendedField] and [Field] over [Nd4jArrayStructure]. */ public sealed interface Nd4jArrayExtendedFieldOps> : - ExtendedFieldOps>, Nd4jArrayField { + ExtendedFieldOps>, Nd4jArrayField, PowerOperations> { override fun sin(arg: StructureND): StructureND = Transforms.sin(arg.ndArray).wrap() override fun cos(arg: StructureND): StructureND = Transforms.cos(arg.ndArray).wrap() diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt index c756584a4..743105fdf 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt @@ -5,8 +5,8 @@ package space.kscience.kmath.tensors.api -import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.ExtendedFieldOps import space.kscience.kmath.operations.Field @@ -15,7 +15,8 @@ import space.kscience.kmath.operations.Field * * @param T the type of items closed under analytic functions in the tensors. */ -public interface AnalyticTensorAlgebra> : TensorPartialDivisionAlgebra { +public interface AnalyticTensorAlgebra> : + TensorPartialDivisionAlgebra, ExtendedFieldOps> { /** * @return the mean of all elements in the input tensor. @@ -122,7 +123,27 @@ public interface AnalyticTensorAlgebra> : TensorPartialDivisionA //For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor public fun StructureND.floor(): Tensor -} + override fun sin(arg: StructureND): StructureND = arg.sin() -@UnstableKMathAPI -public fun > ATA.exp(arg: StructureND): Tensor = arg.exp() \ No newline at end of file + override fun cos(arg: StructureND): StructureND = arg.cos() + + override fun asin(arg: StructureND): StructureND = arg.asin() + + override fun acos(arg: StructureND): StructureND = arg.acos() + + override fun atan(arg: StructureND): StructureND = arg.atan() + + override fun exp(arg: StructureND): StructureND = arg.exp() + + override fun ln(arg: StructureND): StructureND = arg.ln() + + override fun sinh(arg: StructureND): StructureND = arg.sinh() + + override fun cosh(arg: StructureND): StructureND = arg.cosh() + + override fun asinh(arg: StructureND): StructureND = arg.asinh() + + override fun acosh(arg: StructureND): StructureND = arg.acosh() + + override fun atanh(arg: StructureND): StructureND = arg.atanh() +} \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt index 7353ecab1..10c747777 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.tensors.core +import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.StructureND import space.kscience.kmath.tensors.api.Tensor @@ -17,6 +18,8 @@ import space.kscience.kmath.tensors.core.internal.tensor * Basic linear algebra operations implemented with broadcasting. * For more information: https://pytorch.org/docs/stable/notes/broadcasting.html */ + +@PerformancePitfall public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { override fun StructureND.plus(arg: StructureND): DoubleTensor { @@ -99,5 +102,6 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { * Compute a value using broadcast double tensor algebra */ @UnstableKMathAPI +@PerformancePitfall public fun DoubleTensorAlgebra.withBroadcast(block: BroadcastDoubleTensorAlgebra.() -> R): R = BroadcastDoubleTensorAlgebra.block() \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 5e7ae262f..bae49c037 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -3,8 +3,12 @@ * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ + +@file:OptIn(PerformancePitfall::class) + package space.kscience.kmath.tensors.core +import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.MutableStructure2D import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.as1D @@ -39,6 +43,7 @@ public open class DoubleTensorAlgebra : * @param transform the function to be applied to each element of the tensor. * @return the resulting tensor after applying the function. */ + @PerformancePitfall @Suppress("OVERRIDE_BY_INLINE") final override inline fun StructureND.map(transform: DoubleField.(Double) -> Double): DoubleTensor { val tensor = this.tensor @@ -52,6 +57,7 @@ public open class DoubleTensorAlgebra : ) } + @PerformancePitfall @Suppress("OVERRIDE_BY_INLINE") final override inline fun StructureND.mapIndexed(transform: DoubleField.(index: IntArray, Double) -> Double): DoubleTensor { val tensor = this.tensor @@ -65,6 +71,7 @@ public open class DoubleTensorAlgebra : ) } + @PerformancePitfall override fun zip( left: StructureND, right: StructureND, @@ -377,6 +384,7 @@ public open class DoubleTensorAlgebra : override fun Tensor.viewAs(other: StructureND): DoubleTensor = tensor.view(other.shape) + @PerformancePitfall override infix fun StructureND.dot(other: StructureND): DoubleTensor { if (tensor.shape.size == 1 && other.shape.size == 1) { return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum())) @@ -691,14 +699,19 @@ public open class DoubleTensorAlgebra : return resTensor } + @OptIn(PerformancePitfall::class) override fun StructureND.exp(): DoubleTensor = tensor.map { exp(it) } + @OptIn(PerformancePitfall::class) override fun StructureND.ln(): DoubleTensor = tensor.map { ln(it) } + @OptIn(PerformancePitfall::class) override fun StructureND.sqrt(): DoubleTensor = tensor.map { sqrt(it) } + @OptIn(PerformancePitfall::class) override fun StructureND.cos(): DoubleTensor = tensor.map { cos(it) } + @OptIn(PerformancePitfall::class) override fun StructureND.acos(): DoubleTensor = tensor.map { acos(it) } override fun StructureND.cosh(): DoubleTensor = tensor.map { cosh(it) } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorAlgebraExtensions.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorAlgebraExtensions.kt index 916388ba9..1e6dfd52e 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorAlgebraExtensions.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorAlgebraExtensions.kt @@ -3,8 +3,11 @@ * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ +@file:OptIn(PerformancePitfall::class) + package space.kscience.kmath.tensors.core +import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.Shape import kotlin.jvm.JvmName diff --git a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt index c50404c9c..e43bbbc6f 100644 --- a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt +++ b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt @@ -6,17 +6,20 @@ package space.kscience.kmath.viktor import org.jetbrains.bio.viktor.F64Array +import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.* import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.ExtendedFieldOps import space.kscience.kmath.operations.NumbersAddOps +import space.kscience.kmath.operations.PowerOperations @OptIn(UnstableKMathAPI::class) @Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public open class ViktorFieldOpsND : FieldOpsND, - ExtendedFieldOps> { + ExtendedFieldOps>, + PowerOperations> { public val StructureND.f64Buffer: F64Array get() = when (this) { @@ -35,6 +38,7 @@ public open class ViktorFieldOpsND : override fun StructureND.unaryMinus(): StructureND = -1 * this + @PerformancePitfall override fun StructureND.map(transform: DoubleField.(Double) -> Double): ViktorStructureND = F64Array(*shape).apply { DefaultStrides(shape).asSequence().forEach { index -> @@ -42,6 +46,7 @@ public open class ViktorFieldOpsND : } }.asStructure() + @PerformancePitfall override fun StructureND.mapIndexed( transform: DoubleField.(index: IntArray, Double) -> Double, ): ViktorStructureND = F64Array(*shape).apply { @@ -50,6 +55,7 @@ public open class ViktorFieldOpsND : } }.asStructure() + @PerformancePitfall override fun zip( left: StructureND, right: StructureND, @@ -110,7 +116,7 @@ public open class ViktorFieldOpsND : public val DoubleField.viktorAlgebra: ViktorFieldOpsND get() = ViktorFieldOpsND public open class ViktorFieldND( - override val shape: Shape + override val shape: Shape, ) : ViktorFieldOpsND(), FieldND, NumbersAddOps> { override val zero: ViktorStructureND by lazy { F64Array.full(init = 0.0, shape = shape).asStructure() } override val one: ViktorStructureND by lazy { F64Array.full(init = 1.0, shape = shape).asStructure() }