diff --git a/.idea/copyright/kmath.xml b/.idea/copyright/kmath.xml deleted file mode 100644 index 1070e5d33..000000000 --- a/.idea/copyright/kmath.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - diff --git a/.idea/copyright/profiles_settings.xml b/.idea/copyright/profiles_settings.xml deleted file mode 100644 index b538bdf41..000000000 --- a/.idea/copyright/profiles_settings.xml +++ /dev/null @@ -1,21 +0,0 @@ - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/.idea/scopes/Apply_copyright.xml b/.idea/scopes/Apply_copyright.xml deleted file mode 100644 index 0eb589133..000000000 --- a/.idea/scopes/Apply_copyright.xml +++ /dev/null @@ -1,4 +0,0 @@ - - - diff --git a/CHANGELOG.md b/CHANGELOG.md index 05376d425..bb267744e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,9 @@ - Use `Symbol` factory function instead of `StringSymbol` - New discoverability pattern: `.algebra.` - Adjusted commons-math API for linear solvers to match conventions. +- Buffer algebra does not require size anymore +- Operations -> Ops +- Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes. ### Deprecated - Specialized `DoubleBufferAlgebra` diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt index 76fec05d3..7f7c03412 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt @@ -9,9 +9,10 @@ import kotlinx.benchmark.Benchmark import kotlinx.benchmark.Blackhole import kotlinx.benchmark.Scope import kotlinx.benchmark.State +import space.kscience.kmath.nd.BufferedFieldOpsND import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.nd.autoNdAlgebra import space.kscience.kmath.nd.ndAlgebra +import space.kscience.kmath.nd.one import space.kscience.kmath.nd4j.nd4j import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.structures.Buffer @@ -23,21 +24,21 @@ import space.kscience.kmath.tensors.core.tensorAlgebra internal class NDFieldBenchmark { @Benchmark fun autoFieldAdd(blackhole: Blackhole) = with(autoField) { - var res: StructureND = one - repeat(n) { res += one } + var res: StructureND = one(shape) + repeat(n) { res += 1.0 } blackhole.consume(res) } @Benchmark fun specializedFieldAdd(blackhole: Blackhole) = with(specializedField) { - var res: StructureND = one + var res: StructureND = one(shape) repeat(n) { res += 1.0 } blackhole.consume(res) } @Benchmark fun boxingFieldAdd(blackhole: Blackhole) = with(genericField) { - var res: StructureND = one + var res: StructureND = one(shape) repeat(n) { res += 1.0 } blackhole.consume(res) } @@ -58,7 +59,7 @@ internal class NDFieldBenchmark { // @Benchmark // fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) { -// var res: StructureND = one +// var res: StructureND = one(dim, dim) // repeat(n) { res += 1.0 } // blackhole.consume(res) // } @@ -66,9 +67,10 @@ internal class NDFieldBenchmark { private companion object { private const val dim = 1000 private const val n = 100 - private val autoField = DoubleField.autoNdAlgebra(dim, dim) - private val specializedField = DoubleField.ndAlgebra(dim, dim) - private val genericField = DoubleField.ndAlgebra(Buffer.Companion::boxing, dim, dim) - private val nd4jField = DoubleField.nd4j(dim, dim) + private val shape = intArrayOf(dim, dim) + private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto) + private val specializedField = DoubleField.ndAlgebra + private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing) + private val nd4jField = DoubleField.nd4j } } diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorBenchmark.kt index b97a05a52..6b4d5759b 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorBenchmark.kt @@ -10,18 +10,17 @@ import kotlinx.benchmark.Blackhole import kotlinx.benchmark.Scope import kotlinx.benchmark.State import org.jetbrains.bio.viktor.F64Array -import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.nd.autoNdAlgebra -import space.kscience.kmath.nd.ndAlgebra +import space.kscience.kmath.nd.* import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.viktor.ViktorNDField +import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.viktor.ViktorFieldND @State(Scope.Benchmark) internal class ViktorBenchmark { @Benchmark fun automaticFieldAddition(blackhole: Blackhole) { with(autoField) { - var res: StructureND = one + var res: StructureND = one(shape) repeat(n) { res += 1.0 } blackhole.consume(res) } @@ -30,7 +29,7 @@ internal class ViktorBenchmark { @Benchmark fun realFieldAddition(blackhole: Blackhole) { with(realField) { - var res: StructureND = one + var res: StructureND = one(shape) repeat(n) { res += 1.0 } blackhole.consume(res) } @@ -39,7 +38,7 @@ internal class ViktorBenchmark { @Benchmark fun viktorFieldAddition(blackhole: Blackhole) { with(viktorField) { - var res = one + var res = one(shape) repeat(n) { res += 1.0 } blackhole.consume(res) } @@ -56,10 +55,11 @@ internal class ViktorBenchmark { private companion object { private const val dim = 1000 private const val n = 100 + private val shape = Shape(dim, dim) // automatically build context most suited for given type. - private val autoField = DoubleField.autoNdAlgebra(dim, dim) - private val realField = DoubleField.ndAlgebra(dim, dim) - private val viktorField = ViktorNDField(dim, dim) + private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto) + private val realField = DoubleField.ndAlgebra + private val viktorField = ViktorFieldND(dim, dim) } } diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorLogBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorLogBenchmark.kt index 91e9dcd76..ef2adaad8 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorLogBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorLogBenchmark.kt @@ -10,18 +10,21 @@ import kotlinx.benchmark.Blackhole import kotlinx.benchmark.Scope import kotlinx.benchmark.State import org.jetbrains.bio.viktor.F64Array -import space.kscience.kmath.nd.autoNdAlgebra +import space.kscience.kmath.nd.BufferedFieldOpsND +import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.ndAlgebra +import space.kscience.kmath.nd.one import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.structures.Buffer import space.kscience.kmath.viktor.ViktorFieldND @State(Scope.Benchmark) internal class ViktorLogBenchmark { @Benchmark fun realFieldLog(blackhole: Blackhole) { - with(realNdField) { - val fortyTwo = produce { 42.0 } - var res = one + with(realField) { + val fortyTwo = produce(shape) { 42.0 } + var res = one(shape) repeat(n) { res = ln(fortyTwo) } blackhole.consume(res) } @@ -30,7 +33,7 @@ internal class ViktorLogBenchmark { @Benchmark fun viktorFieldLog(blackhole: Blackhole) { with(viktorField) { - val fortyTwo = produce { 42.0 } + val fortyTwo = produce(shape) { 42.0 } var res = one repeat(n) { res = ln(fortyTwo) } blackhole.consume(res) @@ -48,10 +51,11 @@ internal class ViktorLogBenchmark { private companion object { private const val dim = 1000 private const val n = 100 + private val shape = Shape(dim, dim) // automatically build context most suited for given type. - private val autoField = DoubleField.autoNdAlgebra(dim, dim) - private val realNdField = DoubleField.ndAlgebra(dim, dim) - private val viktorField = ViktorFieldND(intArrayOf(dim, dim)) + private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto) + private val realField = DoubleField.ndAlgebra + private val viktorField = ViktorFieldND(dim, dim) } } diff --git a/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt b/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt index 93b5671fe..609afb47e 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt @@ -9,6 +9,7 @@ import space.kscience.kmath.integration.gaussIntegrator import space.kscience.kmath.integration.integrate import space.kscience.kmath.integration.value import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.nd.produce import space.kscience.kmath.nd.withNdAlgebra import space.kscience.kmath.operations.algebra import space.kscience.kmath.operations.invoke @@ -17,7 +18,7 @@ fun main(): Unit = Double.algebra { withNdAlgebra(2, 2) { //Produce a diagonal StructureND - fun diagonal(v: Double) = produce { (i, j) -> + fun diagonal(v: Double) = produce { (i, j) -> if (i == j) v else 0.0 } diff --git a/examples/src/main/kotlin/space/kscience/kmath/operations/complexDemo.kt b/examples/src/main/kotlin/space/kscience/kmath/operations/complexDemo.kt index 319221bcc..67d83d77c 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/operations/complexDemo.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/operations/complexDemo.kt @@ -11,27 +11,27 @@ import space.kscience.kmath.complex.bufferAlgebra import space.kscience.kmath.complex.ndAlgebra import space.kscience.kmath.nd.BufferND import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.nd.produce fun main() = Complex.algebra { val complex = 2 + 2 * i println(complex * 8 - 5 * i) //flat buffer - val buffer = bufferAlgebra(8).run { - buffer { Complex(it, -it) }.map { Complex(it.im, it.re) } + val buffer = with(bufferAlgebra){ + buffer(8) { Complex(it, -it) }.map { Complex(it.im, it.re) } } println(buffer) - // 2d element - val element: BufferND = ndAlgebra(2, 2).produce { (i, j) -> + val element: BufferND = ndAlgebra.produce(2, 2) { (i, j) -> Complex(i - j, i + j) } println(element) // 1d element operation - val result: StructureND = ndAlgebra(8).run { - val a = produce { (it) -> i * it - it.toDouble() } + val result: StructureND = ndAlgebra{ + val a = produce(8) { (it) -> i * it - it.toDouble() } val b = 3 val c = Complex(1.0, 1.0) diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/ComplexND.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/ComplexND.kt index d4554b3ba..42636fafb 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/ComplexND.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/ComplexND.kt @@ -12,6 +12,7 @@ import space.kscience.kmath.linear.transpose import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.ndAlgebra +import space.kscience.kmath.nd.produce import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.invoke import kotlin.system.measureTimeMillis diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/NDField.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/NDField.kt index 5b0e2eb30..cf0721ce7 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/NDField.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/NDField.kt @@ -8,13 +8,11 @@ package space.kscience.kmath.structures import kotlinx.coroutines.DelicateCoroutinesApi import kotlinx.coroutines.GlobalScope import org.nd4j.linalg.factory.Nd4j -import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.nd.autoNdAlgebra -import space.kscience.kmath.nd.ndAlgebra -import space.kscience.kmath.nd4j.Nd4jArrayField +import space.kscience.kmath.nd.* +import space.kscience.kmath.nd4j.nd4j import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.invoke -import space.kscience.kmath.viktor.ViktorNDField +import space.kscience.kmath.viktor.ViktorFieldND import kotlin.contracts.InvocationKind import kotlin.contracts.contract import kotlin.system.measureTimeMillis @@ -31,37 +29,39 @@ fun main() { Nd4j.zeros(0) val dim = 1000 val n = 1000 + val shape = Shape(dim, dim) + // automatically build context most suited for given type. - val autoField = DoubleField.autoNdAlgebra(dim, dim) + val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto) // specialized nd-field for Double. It works as generic Double field as well. - val realField = DoubleField.ndAlgebra(dim, dim) + val realField = DoubleField.ndAlgebra //A generic boxing field. It should be used for objects, not primitives. - val boxingField = DoubleField.ndAlgebra(Buffer.Companion::boxing, dim, dim) + val boxingField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing) // Nd4j specialized field. - val nd4jField = Nd4jArrayField.real(dim, dim) + val nd4jField = DoubleField.nd4j //viktor field - val viktorField = ViktorNDField(dim, dim) + val viktorField = ViktorFieldND(dim, dim) //parallel processing based on Java Streams val parallelField = DoubleField.ndStreaming(dim, dim) measureAndPrint("Boxing addition") { boxingField { - var res: StructureND = one + var res: StructureND = one(shape) repeat(n) { res += 1.0 } } } measureAndPrint("Specialized addition") { realField { - var res: StructureND = one + var res: StructureND = one(shape) repeat(n) { res += 1.0 } } } measureAndPrint("Nd4j specialized addition") { nd4jField { - var res: StructureND = one + var res: StructureND = one(shape) repeat(n) { res += 1.0 } } } @@ -82,13 +82,13 @@ fun main() { measureAndPrint("Automatic field addition") { autoField { - var res: StructureND = one + var res: StructureND = one(shape) repeat(n) { res += 1.0 } } } measureAndPrint("Lazy addition") { - val res = realField.one.mapAsync(GlobalScope) { + val res = realField.one(shape).mapAsync(GlobalScope) { var c = 0.0 repeat(n) { c += 1.0 diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt index b1248bd0f..dfd06973e 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt @@ -8,7 +8,7 @@ package space.kscience.kmath.structures import space.kscience.kmath.nd.* import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.NumbersAddOperations +import space.kscience.kmath.operations.NumbersAddOps import java.util.* import java.util.stream.IntStream @@ -17,17 +17,17 @@ import java.util.stream.IntStream * execution. */ class StreamDoubleFieldND(override val shape: IntArray) : FieldND, - NumbersAddOperations>, + NumbersAddOps>, ExtendedField> { private val strides = DefaultStrides(shape) - override val elementContext: DoubleField get() = DoubleField - override val zero: BufferND by lazy { produce { zero } } - override val one: BufferND by lazy { produce { one } } + override val elementAlgebra: DoubleField get() = DoubleField + override val zero: BufferND by lazy { produce(shape) { zero } } + override val one: BufferND by lazy { produce(shape) { one } } override fun number(value: Number): BufferND { val d = value.toDouble() // minimize conversions - return produce { d } + return produce(shape) { d } } private val StructureND.buffer: DoubleBuffer @@ -36,11 +36,11 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND this.buffer as DoubleBuffer + this is BufferND && this.indexes == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) } } - override fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND { + override fun produce(shape: Shape, initializer: DoubleField.(IntArray) -> Double): BufferND { val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset -> val index = strides.index(offset) DoubleField.initializer(index) @@ -69,13 +69,13 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND, - b: StructureND, + override fun zip( + left: StructureND, + right: StructureND, transform: DoubleField.(Double, Double) -> Double, ): BufferND { val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset -> - DoubleField.transform(a.buffer.array[offset], b.buffer.array[offset]) + DoubleField.transform(left.buffer.array[offset], right.buffer.array[offset]) }.toArray() return BufferND(strides, array.asBuffer()) } diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/buffers.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/buffers.kt index d78141507..889ea99bd 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/buffers.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/buffers.kt @@ -8,6 +8,7 @@ package space.kscience.kmath.structures import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.buffer import space.kscience.kmath.operations.bufferAlgebra +import space.kscience.kmath.operations.withSize inline fun MutableBuffer.Companion.same( n: Int, @@ -16,7 +17,7 @@ inline fun MutableBuffer.Companion.same( fun main() { - with(DoubleField.bufferAlgebra(5)) { + with(DoubleField.bufferAlgebra.withSize(5)) { println(number(2.0) + buffer(1, 2, 3, 4, 5)) } } diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 05679dc3c..ffed3a254 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.1.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.2-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt index ef6d51c7b..7f2780548 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt @@ -18,10 +18,10 @@ import com.github.h0tk3y.betterParse.parser.ParseResult import com.github.h0tk3y.betterParse.parser.Parser import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.Symbol -import space.kscience.kmath.operations.FieldOperations -import space.kscience.kmath.operations.GroupOperations +import space.kscience.kmath.operations.FieldOps +import space.kscience.kmath.operations.GroupOps import space.kscience.kmath.operations.PowerOperations -import space.kscience.kmath.operations.RingOperations +import space.kscience.kmath.operations.RingOps /** * better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4. @@ -60,7 +60,7 @@ public object ArithmeticsEvaluator : Grammar() { .or(binaryFunction) .or(unaryFunction) .or(singular) - .or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(GroupOperations.MINUS_OPERATION, it) }) + .or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(GroupOps.MINUS_OPERATION, it) }) .or(-lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar) private val powChain: Parser by leftAssociative(term = term, operator = pow) { a, _, b -> @@ -72,9 +72,9 @@ public object ArithmeticsEvaluator : Grammar() { operator = div or mul use TokenMatch::type ) { a, op, b -> if (op == div) - MST.Binary(FieldOperations.DIV_OPERATION, a, b) + MST.Binary(FieldOps.DIV_OPERATION, a, b) else - MST.Binary(RingOperations.TIMES_OPERATION, a, b) + MST.Binary(RingOps.TIMES_OPERATION, a, b) } private val subSumChain: Parser by leftAssociative( @@ -82,9 +82,9 @@ public object ArithmeticsEvaluator : Grammar() { operator = plus or minus use TokenMatch::type ) { a, op, b -> if (op == plus) - MST.Binary(GroupOperations.PLUS_OPERATION, a, b) + MST.Binary(GroupOps.PLUS_OPERATION, a, b) else - MST.Binary(GroupOperations.MINUS_OPERATION, a, b) + MST.Binary(GroupOps.MINUS_OPERATION, a, b) } override val rootParser: Parser by subSumChain diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt index a7a28d87f..8b76b6f19 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt @@ -39,7 +39,7 @@ public val PrintNumeric: RenderFeature = RenderFeature { _, node -> @UnstableKMathAPI private fun printSignedNumberString(s: String): MathSyntax = if (s.startsWith('-')) UnaryMinusSyntax( - operation = GroupOperations.MINUS_OPERATION, + operation = GroupOps.MINUS_OPERATION, operand = OperandSyntax( operand = NumberSyntax(string = s.removePrefix("-")), parentheses = true, @@ -72,7 +72,7 @@ public class PrettyPrintFloats(public val types: Set>) : Rend val exponent = afterE.toDouble().toString().removeSuffix(".0") return MultiplicationSyntax( - operation = RingOperations.TIMES_OPERATION, + operation = RingOps.TIMES_OPERATION, left = OperandSyntax(operand = NumberSyntax(significand), parentheses = true), right = OperandSyntax( operand = SuperscriptSyntax( @@ -91,7 +91,7 @@ public class PrettyPrintFloats(public val types: Set>) : Rend if (toString.startsWith('-')) return UnaryMinusSyntax( - operation = GroupOperations.MINUS_OPERATION, + operation = GroupOps.MINUS_OPERATION, operand = OperandSyntax(operand = infty, parentheses = true), ) @@ -211,9 +211,9 @@ public class BinaryPlus(operations: Collection?) : Binary(operations) { public companion object { /** - * The default instance configured with [GroupOperations.PLUS_OPERATION]. + * The default instance configured with [GroupOps.PLUS_OPERATION]. */ - public val Default: BinaryPlus = BinaryPlus(setOf(GroupOperations.PLUS_OPERATION)) + public val Default: BinaryPlus = BinaryPlus(setOf(GroupOps.PLUS_OPERATION)) } } @@ -233,9 +233,9 @@ public class BinaryMinus(operations: Collection?) : Binary(operations) { public companion object { /** - * The default instance configured with [GroupOperations.MINUS_OPERATION]. + * The default instance configured with [GroupOps.MINUS_OPERATION]. */ - public val Default: BinaryMinus = BinaryMinus(setOf(GroupOperations.MINUS_OPERATION)) + public val Default: BinaryMinus = BinaryMinus(setOf(GroupOps.MINUS_OPERATION)) } } @@ -253,9 +253,9 @@ public class UnaryPlus(operations: Collection?) : Unary(operations) { public companion object { /** - * The default instance configured with [GroupOperations.PLUS_OPERATION]. + * The default instance configured with [GroupOps.PLUS_OPERATION]. */ - public val Default: UnaryPlus = UnaryPlus(setOf(GroupOperations.PLUS_OPERATION)) + public val Default: UnaryPlus = UnaryPlus(setOf(GroupOps.PLUS_OPERATION)) } } @@ -273,9 +273,9 @@ public class UnaryMinus(operations: Collection?) : Unary(operations) { public companion object { /** - * The default instance configured with [GroupOperations.MINUS_OPERATION]. + * The default instance configured with [GroupOps.MINUS_OPERATION]. */ - public val Default: UnaryMinus = UnaryMinus(setOf(GroupOperations.MINUS_OPERATION)) + public val Default: UnaryMinus = UnaryMinus(setOf(GroupOps.MINUS_OPERATION)) } } @@ -295,9 +295,9 @@ public class Fraction(operations: Collection?) : Binary(operations) { public companion object { /** - * The default instance configured with [FieldOperations.DIV_OPERATION]. + * The default instance configured with [FieldOps.DIV_OPERATION]. */ - public val Default: Fraction = Fraction(setOf(FieldOperations.DIV_OPERATION)) + public val Default: Fraction = Fraction(setOf(FieldOps.DIV_OPERATION)) } } @@ -422,9 +422,9 @@ public class Multiplication(operations: Collection?) : Binary(operations public companion object { /** - * The default instance configured with [RingOperations.TIMES_OPERATION]. + * The default instance configured with [RingOps.TIMES_OPERATION]. */ - public val Default: Multiplication = Multiplication(setOf(RingOperations.TIMES_OPERATION)) + public val Default: Multiplication = Multiplication(setOf(RingOps.TIMES_OPERATION)) } } diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/phases.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/phases.kt index 3d05e03d6..ecea2d104 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/phases.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/phases.kt @@ -7,10 +7,10 @@ package space.kscience.kmath.ast.rendering import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess.PostProcessPhase import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.FieldOperations -import space.kscience.kmath.operations.GroupOperations +import space.kscience.kmath.operations.FieldOps +import space.kscience.kmath.operations.GroupOps import space.kscience.kmath.operations.PowerOperations -import space.kscience.kmath.operations.RingOperations +import space.kscience.kmath.operations.RingOps /** * Removes unnecessary times (×) symbols from [MultiplicationSyntax]. @@ -306,10 +306,10 @@ public class SimplifyParentheses(public val precedenceFunction: (MathSyntax) -> is BinarySyntax -> when (it.operation) { PowerOperations.POW_OPERATION -> 1 - RingOperations.TIMES_OPERATION -> 3 - FieldOperations.DIV_OPERATION -> 3 - GroupOperations.MINUS_OPERATION -> 4 - GroupOperations.PLUS_OPERATION -> 4 + RingOps.TIMES_OPERATION -> 3 + FieldOps.DIV_OPERATION -> 3 + GroupOps.MINUS_OPERATION -> 4 + GroupOps.PLUS_OPERATION -> 4 else -> 0 } diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt index d8e432230..aba713c43 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt @@ -7,7 +7,7 @@ package space.kscience.kmath.ast.rendering import space.kscience.kmath.ast.rendering.TestUtils.testLatex import space.kscience.kmath.expressions.MST -import space.kscience.kmath.operations.GroupOperations +import space.kscience.kmath.operations.GroupOps import kotlin.test.Test internal class TestLatex { @@ -36,7 +36,7 @@ internal class TestLatex { fun unaryOperator() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)") @Test - fun unaryPlus() = testLatex(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "+1") + fun unaryPlus() = testLatex(MST.Unary(GroupOps.PLUS_OPERATION, MST.Numeric(1)), "+1") @Test fun unaryMinus() = testLatex("-x", "-x") diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt index a7fcbc75b..658ecd47a 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt @@ -7,7 +7,7 @@ package space.kscience.kmath.ast.rendering import space.kscience.kmath.ast.rendering.TestUtils.testMathML import space.kscience.kmath.expressions.MST -import space.kscience.kmath.operations.GroupOperations +import space.kscience.kmath.operations.GroupOps import kotlin.test.Test internal class TestMathML { @@ -47,7 +47,7 @@ internal class TestMathML { @Test fun unaryPlus() = - testMathML(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "+1") + testMathML(MST.Unary(GroupOps.PLUS_OPERATION, MST.Numeric(1)), "+1") @Test fun unaryMinus() = testMathML("-x", "-x") diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt index 5b6cf65db..b04c4d48f 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt @@ -108,8 +108,8 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder(f64, DoubleF override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.f64.const(mst.value) override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) { - GroupOperations.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value)) - GroupOperations.PLUS_OPERATION -> visit(mst.value) + GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value)) + GroupOps.PLUS_OPERATION -> visit(mst.value) PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value)) TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64) TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64) @@ -129,10 +129,10 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder(f64, DoubleF } override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { - GroupOperations.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right)) - GroupOperations.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right)) - RingOperations.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right)) - FieldOperations.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right)) + GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right)) + GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right)) + RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right)) + FieldOps.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right)) PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), f64) else -> super.visitBinary(mst) } @@ -142,15 +142,15 @@ internal class IntWasmBuilder(target: MST) : WasmBuilder(i32, IntRing, targ override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.i32.const(mst.value) override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) { - GroupOperations.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value)) - GroupOperations.PLUS_OPERATION -> visit(mst.value) + GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value)) + GroupOps.PLUS_OPERATION -> visit(mst.value) else -> super.visitUnary(mst) } override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { - GroupOperations.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right)) - GroupOperations.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right)) - RingOperations.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right)) + GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right)) + GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right)) + RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right)) else -> super.visitBinary(mst) } } diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index bc0119ca2..d42e40d1e 100644 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -9,7 +9,7 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import space.kscience.kmath.expressions.* import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.NumbersAddOperations +import space.kscience.kmath.operations.NumbersAddOps /** * A field over commons-math [DerivativeStructure]. @@ -22,7 +22,7 @@ public class DerivativeStructureField( public val order: Int, bindings: Map, ) : ExtendedField, ExpressionAlgebra, - NumbersAddOperations { + NumbersAddOps { public val numberOfVariables: Int = bindings.size override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) } @@ -70,12 +70,12 @@ public class DerivativeStructureField( override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate() - override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b) + override fun add(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.add(right) override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value) - override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b) - override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b) + override fun multiply(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.multiply(right) + override fun divide(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.divide(right) override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan() diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt index 7d948cb61..879cfe94e 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt @@ -52,7 +52,7 @@ private val PI_DIV_2 = Complex(PI / 2, 0) public object ComplexField : ExtendedField, Norm, - NumbersAddOperations, + NumbersAddOps, ScaleOperations { override val zero: Complex = 0.0.toComplex() @@ -77,33 +77,33 @@ public object ComplexField : override fun scale(a: Complex, value: Double): Complex = Complex(a.re * value, a.im * value) - override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im) + override fun add(left: Complex, right: Complex): Complex = Complex(left.re + right.re, left.im + right.im) // override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble()) - override fun multiply(a: Complex, b: Complex): Complex = - Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re) + override fun multiply(left: Complex, right: Complex): Complex = + Complex(left.re * right.re - left.im * right.im, left.re * right.im + left.im * right.re) - override fun divide(a: Complex, b: Complex): Complex = when { - abs(b.im) < abs(b.re) -> { - val wr = b.im / b.re - val wd = b.re + wr * b.im + override fun divide(left: Complex, right: Complex): Complex = when { + abs(right.im) < abs(right.re) -> { + val wr = right.im / right.re + val wd = right.re + wr * right.im if (wd.isNaN() || wd == 0.0) throw ArithmeticException("Division by zero or infinity") else - Complex((a.re + a.im * wr) / wd, (a.im - a.re * wr) / wd) + Complex((left.re + left.im * wr) / wd, (left.im - left.re * wr) / wd) } - b.im == 0.0 -> throw ArithmeticException("Division by zero") + right.im == 0.0 -> throw ArithmeticException("Division by zero") else -> { - val wr = b.re / b.im - val wd = b.im + wr * b.re + val wr = right.re / right.im + val wd = right.im + wr * right.re if (wd.isNaN() || wd == 0.0) throw ArithmeticException("Division by zero or infinity") else - Complex((a.re * wr + a.im) / wd, (a.im * wr - a.re) / wd) + Complex((left.re * wr + left.im) / wd, (left.im * wr - left.re) / wd) } } @@ -216,7 +216,6 @@ public data class Complex(val re: Double, val im: Double) { public val Complex.Companion.algebra: ComplexField get() = ComplexField - /** * Creates a complex number with real part equal to this real. * diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt index 29e790d16..3951b5de0 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt @@ -6,13 +6,8 @@ package space.kscience.kmath.complex import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.nd.BufferND -import space.kscience.kmath.nd.BufferedFieldND -import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.operations.BufferField -import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.NumbersAddOperations -import space.kscience.kmath.operations.bufferAlgebra +import space.kscience.kmath.nd.* +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -22,100 +17,61 @@ import kotlin.contracts.contract * An optimized nd-field for complex numbers */ @OptIn(UnstableKMathAPI::class) -public class ComplexFieldND( - shape: IntArray, -) : BufferedFieldND(shape, ComplexField, Buffer.Companion::complex), - NumbersAddOperations>, - ExtendedField> { +public sealed class ComplexFieldOpsND : BufferedFieldOpsND(ComplexField.bufferAlgebra), + ScaleOperations>, ExtendedFieldOps> { - override val zero: BufferND by lazy { produce { zero } } - override val one: BufferND by lazy { produce { one } } - - override fun number(value: Number): BufferND { - val d = value.toComplex() // minimize conversions - return produce { d } + override fun StructureND.toBufferND(): BufferND = when (this) { + is BufferND -> this + else -> { + val indexer = indexerBuilder(shape) + BufferND(indexer, Buffer.complex(indexer.linearSize) { offset -> get(indexer.index(offset)) }) + } } -// -// @Suppress("OVERRIDE_BY_INLINE") -// override inline fun map( -// arg: AbstractNDBuffer, -// transform: DoubleField.(Double) -> Double, -// ): RealNDElement { -// check(arg) -// val array = RealBuffer(arg.strides.linearSize) { offset -> DoubleField.transform(arg.buffer[offset]) } -// return BufferedNDFieldElement(this, array) -// } -// -// @Suppress("OVERRIDE_BY_INLINE") -// override inline fun produce(initializer: DoubleField.(IntArray) -> Double): RealNDElement { -// val array = RealBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) } -// return BufferedNDFieldElement(this, array) -// } -// -// @Suppress("OVERRIDE_BY_INLINE") -// override inline fun mapIndexed( -// arg: AbstractNDBuffer, -// transform: DoubleField.(index: IntArray, Double) -> Double, -// ): RealNDElement { -// check(arg) -// return BufferedNDFieldElement( -// this, -// RealBuffer(arg.strides.linearSize) { offset -> -// elementContext.transform( -// arg.strides.index(offset), -// arg.buffer[offset] -// ) -// }) -// } -// -// @Suppress("OVERRIDE_BY_INLINE") -// override inline fun combine( -// a: AbstractNDBuffer, -// b: AbstractNDBuffer, -// transform: DoubleField.(Double, Double) -> Double, -// ): RealNDElement { -// check(a, b) -// val buffer = RealBuffer(strides.linearSize) { offset -> -// elementContext.transform(a.buffer[offset], b.buffer[offset]) -// } -// return BufferedNDFieldElement(this, buffer) -// } + //TODO do specialization - override fun power(arg: StructureND, pow: Number): BufferND = arg.map { power(it, pow) } + override fun scale(a: StructureND, value: Double): BufferND = + mapInline(a.toBufferND()) { it * value } - override fun exp(arg: StructureND): BufferND = arg.map { exp(it) } + override fun power(arg: StructureND, pow: Number): BufferND = + mapInline(arg.toBufferND()) { power(it, pow) } - override fun ln(arg: StructureND): BufferND = arg.map { ln(it) } + override fun exp(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { exp(it) } + override fun ln(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { ln(it) } - override fun sin(arg: StructureND): BufferND = arg.map { sin(it) } - override fun cos(arg: StructureND): BufferND = arg.map { cos(it) } - override fun tan(arg: StructureND): BufferND = arg.map { tan(it) } - override fun asin(arg: StructureND): BufferND = arg.map { asin(it) } - override fun acos(arg: StructureND): BufferND = arg.map { acos(it) } - override fun atan(arg: StructureND): BufferND = arg.map { atan(it) } + override fun sin(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { sin(it) } + override fun cos(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { cos(it) } + override fun tan(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { tan(it) } + override fun asin(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { asin(it) } + override fun acos(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { acos(it) } + override fun atan(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { atan(it) } - override fun sinh(arg: StructureND): BufferND = arg.map { sinh(it) } - override fun cosh(arg: StructureND): BufferND = arg.map { cosh(it) } - override fun tanh(arg: StructureND): BufferND = arg.map { tanh(it) } - override fun asinh(arg: StructureND): BufferND = arg.map { asinh(it) } - override fun acosh(arg: StructureND): BufferND = arg.map { acosh(it) } - override fun atanh(arg: StructureND): BufferND = arg.map { atanh(it) } -} + override fun sinh(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { sinh(it) } + override fun cosh(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { cosh(it) } + override fun tanh(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { tanh(it) } + override fun asinh(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { asinh(it) } + override fun acosh(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { acosh(it) } + override fun atanh(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { atanh(it) } - -/** - * Fast element production using function inlining - */ -public inline fun BufferedFieldND.produceInline(initializer: ComplexField.(Int) -> Complex): BufferND { - contract { callsInPlace(initializer, InvocationKind.EXACTLY_ONCE) } - val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) } - return BufferND(strides, buffer) + public companion object : ComplexFieldOpsND() } @UnstableKMathAPI -public fun ComplexField.bufferAlgebra(size: Int): BufferField = - bufferAlgebra(Buffer.Companion::complex, size) +public val ComplexField.bufferAlgebra: BufferFieldOps + get() = bufferAlgebra(Buffer.Companion::complex) + + +@OptIn(UnstableKMathAPI::class) +public class ComplexFieldND(override val shape: Shape) : + ComplexFieldOpsND(), FieldND, NumbersAddOps> { + + override fun number(value: Number): BufferND { + val d = value.toDouble() // minimize conversions + return produce(shape) { d.toComplex() } + } +} + +public val ComplexField.ndAlgebra: ComplexFieldOpsND get() = ComplexFieldOpsND public fun ComplexField.ndAlgebra(vararg shape: Int): ComplexFieldND = ComplexFieldND(shape) diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt index e5d7ebd1e..9fdd60e1f 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt @@ -44,7 +44,7 @@ public val Quaternion.r: Double */ @OptIn(UnstableKMathAPI::class) public object QuaternionField : Field, Norm, PowerOperations, - ExponentialOperations, NumbersAddOperations, ScaleOperations { + ExponentialOperations, NumbersAddOps, ScaleOperations { override val zero: Quaternion = 0.toQuaternion() override val one: Quaternion = 1.toQuaternion() @@ -63,27 +63,27 @@ public object QuaternionField : Field, Norm, */ public val k: Quaternion = Quaternion(0, 0, 0, 1) - override fun add(a: Quaternion, b: Quaternion): Quaternion = - Quaternion(a.w + b.w, a.x + b.x, a.y + b.y, a.z + b.z) + override fun add(left: Quaternion, right: Quaternion): Quaternion = + Quaternion(left.w + right.w, left.x + right.x, left.y + right.y, left.z + right.z) override fun scale(a: Quaternion, value: Double): Quaternion = Quaternion(a.w * value, a.x * value, a.y * value, a.z * value) - override fun multiply(a: Quaternion, b: Quaternion): Quaternion = Quaternion( - a.w * b.w - a.x * b.x - a.y * b.y - a.z * b.z, - a.w * b.x + a.x * b.w + a.y * b.z - a.z * b.y, - a.w * b.y - a.x * b.z + a.y * b.w + a.z * b.x, - a.w * b.z + a.x * b.y - a.y * b.x + a.z * b.w, + override fun multiply(left: Quaternion, right: Quaternion): Quaternion = Quaternion( + left.w * right.w - left.x * right.x - left.y * right.y - left.z * right.z, + left.w * right.x + left.x * right.w + left.y * right.z - left.z * right.y, + left.w * right.y - left.x * right.z + left.y * right.w + left.z * right.x, + left.w * right.z + left.x * right.y - left.y * right.x + left.z * right.w, ) - override fun divide(a: Quaternion, b: Quaternion): Quaternion { - val s = b.w * b.w + b.x * b.x + b.y * b.y + b.z * b.z + override fun divide(left: Quaternion, right: Quaternion): Quaternion { + val s = right.w * right.w + right.x * right.x + right.y * right.y + right.z * right.z return Quaternion( - (b.w * a.w + b.x * a.x + b.y * a.y + b.z * a.z) / s, - (b.w * a.x - b.x * a.w - b.y * a.z + b.z * a.y) / s, - (b.w * a.y + b.x * a.z - b.y * a.w - b.z * a.x) / s, - (b.w * a.z - b.x * a.y + b.y * a.x - b.z * a.w) / s, + (right.w * left.w + right.x * left.x + right.y * left.y + right.z * left.z) / s, + (right.w * left.x - right.x * left.w - right.y * left.z + right.z * left.y) / s, + (right.w * left.y + right.x * left.z - right.y * left.w - right.z * left.x) / s, + (right.w * left.z - right.x * left.y + right.y * left.x - right.z * left.w) / s, ) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 36ccb96f7..661680565 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -52,13 +52,13 @@ public open class FunctionalExpressionGroup>( override val zero: Expression get() = const(algebra.zero) override fun Expression.unaryMinus(): Expression = - unaryOperation(GroupOperations.MINUS_OPERATION, this) + unaryOperation(GroupOps.MINUS_OPERATION, this) /** * Builds an Expression of addition of two another expressions. */ - override fun add(a: Expression, b: Expression): Expression = - binaryOperation(GroupOperations.PLUS_OPERATION, a, b) + override fun add(left: Expression, right: Expression): Expression = + binaryOperation(GroupOps.PLUS_OPERATION, left, right) // /** // * Builds an Expression of multiplication of expression by number. @@ -88,8 +88,8 @@ public open class FunctionalExpressionRing>( /** * Builds an Expression of multiplication of two expressions. */ - override fun multiply(a: Expression, b: Expression): Expression = - binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b) + override fun multiply(left: Expression, right: Expression): Expression = + binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right) public operator fun Expression.times(arg: T): Expression = this * const(arg) public operator fun T.times(arg: Expression): Expression = arg * this @@ -107,8 +107,8 @@ public open class FunctionalExpressionField>( /** * Builds an Expression of division an expression by another one. */ - override fun divide(a: Expression, b: Expression): Expression = - binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b) + override fun divide(left: Expression, right: Expression): Expression = + binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right) public operator fun Expression.div(arg: T): Expression = this / const(arg) public operator fun T.div(arg: Expression): Expression = arg / this diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt index bbc74005c..dd3c46207 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt @@ -31,18 +31,18 @@ public object MstGroup : Group, NumericAlgebra, ScaleOperations { override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value) override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) - override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(GroupOperations.PLUS_OPERATION)(a, b) + override fun add(left: MST, right: MST): MST.Binary = binaryOperationFunction(GroupOps.PLUS_OPERATION)(left, right) override operator fun MST.unaryPlus(): MST.Unary = - unaryOperationFunction(GroupOperations.PLUS_OPERATION)(this) + unaryOperationFunction(GroupOps.PLUS_OPERATION)(this) override operator fun MST.unaryMinus(): MST.Unary = - unaryOperationFunction(GroupOperations.MINUS_OPERATION)(this) + unaryOperationFunction(GroupOps.MINUS_OPERATION)(this) - override operator fun MST.minus(b: MST): MST.Binary = - binaryOperationFunction(GroupOperations.MINUS_OPERATION)(this, b) + override operator fun MST.minus(other: MST): MST.Binary = + binaryOperationFunction(GroupOps.MINUS_OPERATION)(this, other) override fun scale(a: MST, value: Double): MST.Binary = - binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, number(value)) + binaryOperationFunction(RingOps.TIMES_OPERATION)(a, number(value)) override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = MstNumericAlgebra.binaryOperationFunction(operation) @@ -56,23 +56,23 @@ public object MstGroup : Group, NumericAlgebra, ScaleOperations { */ @Suppress("OVERRIDE_BY_INLINE") @OptIn(UnstableKMathAPI::class) -public object MstRing : Ring, NumbersAddOperations, ScaleOperations { +public object MstRing : Ring, NumbersAddOps, ScaleOperations { override inline val zero: MST.Numeric get() = MstGroup.zero override val one: MST.Numeric = number(1.0) override fun number(value: Number): MST.Numeric = MstGroup.number(value) override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) - override fun add(a: MST, b: MST): MST.Binary = MstGroup.add(a, b) + override fun add(left: MST, right: MST): MST.Binary = MstGroup.add(left, right) override fun scale(a: MST, value: Double): MST.Binary = - MstGroup.binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, MstGroup.number(value)) + MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value)) - override fun multiply(a: MST, b: MST): MST.Binary = - binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b) + override fun multiply(left: MST, right: MST): MST.Binary = + binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right) override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus } override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus } - override operator fun MST.minus(b: MST): MST.Binary = MstGroup { this@minus - b } + override operator fun MST.minus(other: MST): MST.Binary = MstGroup { this@minus - other } override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = MstGroup.binaryOperationFunction(operation) @@ -86,24 +86,24 @@ public object MstRing : Ring, NumbersAddOperations, ScaleOperations, NumbersAddOperations, ScaleOperations { +public object MstField : Field, NumbersAddOps, ScaleOperations { override inline val zero: MST.Numeric get() = MstRing.zero override inline val one: MST.Numeric get() = MstRing.one override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) override fun number(value: Number): MST.Numeric = MstRing.number(value) - override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) + override fun add(left: MST, right: MST): MST.Binary = MstRing.add(left, right) override fun scale(a: MST, value: Double): MST.Binary = - MstGroup.binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, MstGroup.number(value)) + MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value)) - override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) - override fun divide(a: MST, b: MST): MST.Binary = - binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b) + override fun multiply(left: MST, right: MST): MST.Binary = MstRing.multiply(left, right) + override fun divide(left: MST, right: MST): MST.Binary = + binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right) override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus } override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus } - override operator fun MST.minus(b: MST): MST.Binary = MstRing { this@minus - b } + override operator fun MST.minus(other: MST): MST.Binary = MstRing { this@minus - other } override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = MstRing.binaryOperationFunction(operation) @@ -134,17 +134,17 @@ public object MstExtendedField : ExtendedField, NumericAlgebra { override fun asinh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ASINH_OPERATION)(arg) override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ACOSH_OPERATION)(arg) override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ATANH_OPERATION)(arg) - override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) + override fun add(left: MST, right: MST): MST.Binary = MstField.add(left, right) override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg) override fun scale(a: MST, value: Double): MST = - binaryOperation(GroupOperations.PLUS_OPERATION, a, number(value)) + binaryOperation(GroupOps.PLUS_OPERATION, a, number(value)) - override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) - override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) + override fun multiply(left: MST, right: MST): MST.Binary = MstField.multiply(left, right) + override fun divide(left: MST, right: MST): MST.Binary = MstField.divide(left, right) override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus } override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus } - override operator fun MST.minus(b: MST): MST.Binary = MstField { this@minus - b } + override operator fun MST.minus(other: MST): MST.Binary = MstField { this@minus - other } override fun power(arg: MST, pow: Number): MST.Binary = binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow)) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt index d5b80da2c..704c4edd8 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -59,7 +59,7 @@ public fun DerivationResult.grad(vararg variables: Symbol): Point>( public val context: F, bindings: Map, -) : Field>, ExpressionAlgebra>, NumbersAddOperations> { +) : Field>, ExpressionAlgebra>, NumbersAddOps> { override val zero: AutoDiffValue get() = const(context.zero) override val one: AutoDiffValue get() = const(context.one) @@ -168,22 +168,22 @@ public open class SimpleAutoDiffField>( // Basic math (+, -, *, /) - override fun add(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = - derive(const { a.value + b.value }) { z -> - a.d += z.d - b.d += z.d + override fun add(left: AutoDiffValue, right: AutoDiffValue): AutoDiffValue = + derive(const { left.value + right.value }) { z -> + left.d += z.d + right.d += z.d } - override fun multiply(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = - derive(const { a.value * b.value }) { z -> - a.d += z.d * b.value - b.d += z.d * a.value + override fun multiply(left: AutoDiffValue, right: AutoDiffValue): AutoDiffValue = + derive(const { left.value * right.value }) { z -> + left.d += z.d * right.value + right.d += z.d * left.value } - override fun divide(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = - derive(const { a.value / b.value }) { z -> - a.d += z.d / b.value - b.d -= z.d * a.value / (b.value * b.value) + override fun divide(left: AutoDiffValue, right: AutoDiffValue): AutoDiffValue = + derive(const { left.value / right.value }) { z -> + left.d += z.d / right.value + right.d -= z.d * left.value / (right.value * right.value) } override fun scale(a: AutoDiffValue, value: Double): AutoDiffValue = diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/BufferedLinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/BufferedLinearSpace.kt index 3d562f26f..39dbe3a81 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/BufferedLinearSpace.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/BufferedLinearSpace.kt @@ -6,12 +6,10 @@ package space.kscience.kmath.linear import space.kscience.kmath.misc.PerformancePitfall -import space.kscience.kmath.nd.BufferedRingND +import space.kscience.kmath.nd.BufferedRingOpsND import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.asND -import space.kscience.kmath.nd.ndAlgebra -import space.kscience.kmath.operations.Ring -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.VirtualBuffer @@ -19,31 +17,28 @@ import space.kscience.kmath.structures.indices public class BufferedLinearSpace>( - override val elementAlgebra: A, - private val bufferFactory: BufferFactory, + private val bufferAlgebra: BufferAlgebra ) : LinearSpace { + override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra - private fun ndRing( - rows: Int, - cols: Int, - ): BufferedRingND = elementAlgebra.ndAlgebra(bufferFactory, rows, cols) + private val ndAlgebra = BufferedRingOpsND(bufferAlgebra) override fun buildMatrix(rows: Int, columns: Int, initializer: A.(i: Int, j: Int) -> T): Matrix = - ndRing(rows, columns).produce { (i, j) -> elementAlgebra.initializer(i, j) }.as2D() + ndAlgebra.produce(intArrayOf(rows, columns)) { (i, j) -> elementAlgebra.initializer(i, j) }.as2D() override fun buildVector(size: Int, initializer: A.(Int) -> T): Point = - bufferFactory(size) { elementAlgebra.initializer(it) } + bufferAlgebra.buffer(size) { elementAlgebra.initializer(it) } - override fun Matrix.unaryMinus(): Matrix = ndRing(rowNum, colNum).run { + override fun Matrix.unaryMinus(): Matrix = ndAlgebra { asND().map { -it }.as2D() } - override fun Matrix.plus(other: Matrix): Matrix = ndRing(rowNum, colNum).run { + override fun Matrix.plus(other: Matrix): Matrix = ndAlgebra { require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } asND().plus(other.asND()).as2D() } - override fun Matrix.minus(other: Matrix): Matrix = ndRing(rowNum, colNum).run { + override fun Matrix.minus(other: Matrix): Matrix = ndAlgebra { require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } asND().minus(other.asND()).as2D() } @@ -88,11 +83,11 @@ public class BufferedLinearSpace>( } } - override fun Matrix.times(value: T): Matrix = ndRing(rowNum, colNum).run { + override fun Matrix.times(value: T): Matrix = ndAlgebra { asND().map { it * value }.as2D() } } public fun > A.linearSpace(bufferFactory: BufferFactory): BufferedLinearSpace = - BufferedLinearSpace(this, bufferFactory) + BufferedLinearSpace(BufferRingOps(this, bufferFactory)) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/DoubleLinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/DoubleLinearSpace.kt index c2f53939f..ec6040af0 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/DoubleLinearSpace.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/DoubleLinearSpace.kt @@ -6,11 +6,12 @@ package space.kscience.kmath.linear import space.kscience.kmath.misc.PerformancePitfall -import space.kscience.kmath.nd.DoubleFieldND +import space.kscience.kmath.nd.DoubleFieldOpsND import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.asND -import space.kscience.kmath.operations.DoubleBufferOperations +import space.kscience.kmath.operations.DoubleBufferOps import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.invoke import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.DoubleBuffer @@ -18,30 +19,27 @@ public object DoubleLinearSpace : LinearSpace { override val elementAlgebra: DoubleField get() = DoubleField - private fun ndRing( - rows: Int, - cols: Int, - ): DoubleFieldND = DoubleFieldND(intArrayOf(rows, cols)) - override fun buildMatrix( rows: Int, columns: Int, initializer: DoubleField.(i: Int, j: Int) -> Double - ): Matrix = ndRing(rows, columns).produce { (i, j) -> DoubleField.initializer(i, j) }.as2D() + ): Matrix = DoubleFieldOpsND.produce(intArrayOf(rows, columns)) { (i, j) -> + DoubleField.initializer(i, j) + }.as2D() override fun buildVector(size: Int, initializer: DoubleField.(Int) -> Double): DoubleBuffer = DoubleBuffer(size) { DoubleField.initializer(it) } - override fun Matrix.unaryMinus(): Matrix = ndRing(rowNum, colNum).run { + override fun Matrix.unaryMinus(): Matrix = DoubleFieldOpsND { asND().map { -it }.as2D() } - override fun Matrix.plus(other: Matrix): Matrix = ndRing(rowNum, colNum).run { + override fun Matrix.plus(other: Matrix): Matrix = DoubleFieldOpsND { require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } asND().plus(other.asND()).as2D() } - override fun Matrix.minus(other: Matrix): Matrix = ndRing(rowNum, colNum).run { + override fun Matrix.minus(other: Matrix): Matrix = DoubleFieldOpsND { require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } asND().minus(other.asND()).as2D() } @@ -84,23 +82,23 @@ public object DoubleLinearSpace : LinearSpace { } - override fun Matrix.times(value: Double): Matrix = ndRing(rowNum, colNum).run { + override fun Matrix.times(value: Double): Matrix = DoubleFieldOpsND { asND().map { it * value }.as2D() } - public override fun Point.plus(other: Point): DoubleBuffer = DoubleBufferOperations.run { + public override fun Point.plus(other: Point): DoubleBuffer = DoubleBufferOps.run { this@plus + other } - public override fun Point.minus(other: Point): DoubleBuffer = DoubleBufferOperations.run { + public override fun Point.minus(other: Point): DoubleBuffer = DoubleBufferOps.run { this@minus - other } - public override fun Point.times(value: Double): DoubleBuffer = DoubleBufferOperations.run { + public override fun Point.times(value: Double): DoubleBuffer = DoubleBufferOps.run { scale(this@times, value) } - public operator fun Point.div(value: Double): DoubleBuffer = DoubleBufferOperations.run { + public operator fun Point.div(value: Double): DoubleBuffer = DoubleBufferOps.run { scale(this@div, 1.0 / value) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt index 1d8985b59..5349ad864 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt @@ -10,6 +10,7 @@ import space.kscience.kmath.nd.MutableStructure2D import space.kscience.kmath.nd.Structure2D import space.kscience.kmath.nd.StructureFeature import space.kscience.kmath.nd.as1D +import space.kscience.kmath.operations.BufferRingOps import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.invoke @@ -188,7 +189,7 @@ public interface LinearSpace> { public fun > buffered( algebra: A, bufferFactory: BufferFactory = Buffer.Companion::boxing, - ): LinearSpace = BufferedLinearSpace(algebra, bufferFactory) + ): LinearSpace = BufferedLinearSpace(BufferRingOps(algebra, bufferFactory)) @Deprecated("use DoubleField.linearSpace") public val double: LinearSpace = buffered(DoubleField, ::DoubleBuffer) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt index b925c2642..b4e8b7487 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt @@ -7,7 +7,6 @@ package space.kscience.kmath.nd import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.* -import space.kscience.kmath.structures.* import kotlin.reflect.KClass /** @@ -19,6 +18,14 @@ import kotlin.reflect.KClass public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) : RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.") +public typealias Shape = IntArray + +public fun Shape(shapeFirst: Int, vararg shapeRest: Int): Shape = intArrayOf(shapeFirst, *shapeRest) + +public interface WithShape { + public val shape: Shape +} + /** * The base interface for all ND-algebra implementations. * @@ -26,20 +33,15 @@ public class ShapeMismatchException(public val expected: IntArray, public val ac * @param C the type of the element context. */ public interface AlgebraND> { - /** - * The shape of ND-structures this algebra operates on. - */ - public val shape: IntArray - /** * The algebra over elements of ND structure. */ - public val elementContext: C + public val elementAlgebra: C /** * Produces a new NDStructure using given initializer function. */ - public fun produce(initializer: C.(IntArray) -> T): StructureND + public fun produce(shape: Shape, initializer: C.(IntArray) -> T): StructureND /** * Maps elements from one structure to another one by applying [transform] to them. @@ -54,7 +56,7 @@ public interface AlgebraND> { /** * Combines two structures into one. */ - public fun combine(a: StructureND, b: StructureND, transform: C.(T, T) -> T): StructureND + public fun zip(left: StructureND, right: StructureND, transform: C.(T, T) -> T): StructureND /** * Element-wise invocation of function working on [T] on a [StructureND]. @@ -77,7 +79,6 @@ public interface AlgebraND> { public companion object } - /** * Get a feature of the structure in this scope. Structure features take precedence other context features. * @@ -89,46 +90,22 @@ public interface AlgebraND> { public inline fun AlgebraND.getFeature(structure: StructureND): F? = getFeature(structure, F::class) -/** - * Checks if given elements are consistent with this context. - * - * @param structures the structures to check. - * @return the array of valid structures. - */ -internal fun > AlgebraND.checkShape(vararg structures: StructureND): Array> = - structures - .map(StructureND::shape) - .singleOrNull { !shape.contentEquals(it) } - ?.let>> { throw ShapeMismatchException(shape, it) } - ?: structures - -/** - * Checks if given element is consistent with this context. - * - * @param element the structure to check. - * @return the valid structure. - */ -internal fun > AlgebraND.checkShape(element: StructureND): StructureND { - if (!element.shape.contentEquals(shape)) throw ShapeMismatchException(shape, element.shape) - return element -} - /** * Space of [StructureND]. * * @param T the type of the element contained in ND structure. - * @param S the type of group over structure elements. + * @param A the type of group over structure elements. */ -public interface GroupND> : Group>, AlgebraND { +public interface GroupOpsND> : GroupOps>, AlgebraND { /** * Element-wise addition. * - * @param a the augend. - * @param b the addend. + * @param left the augend. + * @param right the addend. * @return the sum. */ - override fun add(a: StructureND, b: StructureND): StructureND = - combine(a, b) { aValue, bValue -> add(aValue, bValue) } + override fun add(left: StructureND, right: StructureND): StructureND = + zip(left, right) { aValue, bValue -> add(aValue, bValue) } // TODO move to extensions after KEEP-176 @@ -157,7 +134,7 @@ public interface GroupND> : Group>, AlgebraND * @param arg the addend. * @return the sum. */ - public operator fun T.plus(arg: StructureND): StructureND = arg.map { value -> add(this@plus, value) } + public operator fun T.plus(arg: StructureND): StructureND = arg + this /** * Subtracts an ND structure from an element of it. @@ -171,22 +148,26 @@ public interface GroupND> : Group>, AlgebraND public companion object } +public interface GroupND> : Group>, GroupOpsND, WithShape { + override val zero: StructureND get() = produce(shape) { elementAlgebra.zero } +} + /** * Ring of [StructureND]. * * @param T the type of the element contained in ND structure. - * @param R the type of ring over structure elements. + * @param A the type of ring over structure elements. */ -public interface RingND> : Ring>, GroupND { +public interface RingOpsND> : RingOps>, GroupOpsND { /** * Element-wise multiplication. * - * @param a the multiplicand. - * @param b the multiplier. + * @param left the multiplicand. + * @param right the multiplier. * @return the product. */ - override fun multiply(a: StructureND, b: StructureND): StructureND = - combine(a, b) { aValue, bValue -> multiply(aValue, bValue) } + override fun multiply(left: StructureND, right: StructureND): StructureND = + zip(left, right) { aValue, bValue -> multiply(aValue, bValue) } //TODO move to extensions after KEEP-176 @@ -211,24 +192,32 @@ public interface RingND> : Ring>, GroupND> : Ring>, RingOpsND, GroupND, WithShape { + override val one: StructureND get() = produce(shape) { elementAlgebra.one } +} + + /** * Field of [StructureND]. * * @param T the type of the element contained in ND structure. - * @param F the type field over structure elements. + * @param A the type field over structure elements. */ -public interface FieldND> : Field>, RingND { +public interface FieldOpsND> : + FieldOps>, + RingOpsND, + ScaleOperations> { /** * Element-wise division. * - * @param a the dividend. - * @param b the divisor. + * @param left the dividend. + * @param right the divisor. * @return the quotient. */ - override fun divide(a: StructureND, b: StructureND): StructureND = - combine(a, b) { aValue, bValue -> divide(aValue, bValue) } + override fun divide(left: StructureND, right: StructureND): StructureND = + zip(left, right) { aValue, bValue -> divide(aValue, bValue) } - //TODO move to extensions after KEEP-176 + //TODO move to extensions after https://github.com/Kotlin/KEEP/blob/master/proposals/context-receivers.md /** * Divides an ND structure by an element of it. * @@ -247,42 +236,9 @@ public interface FieldND> : Field>, RingND): StructureND = arg.map { divide(it, this@div) } - /** - * Element-wise scaling. - * - * @param a the multiplicand. - * @param value the multiplier. - * @return the product. - */ override fun scale(a: StructureND, value: Double): StructureND = a.map { scale(it, value) } - -// @ThreadLocal -// public companion object { -// private val realNDFieldCache: MutableMap = hashMapOf() -// -// /** -// * Create a nd-field for [Double] values or pull it from cache if it was created previously. -// */ -// public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) } -// -// /** -// * Create an ND field with boxing generic buffer. -// */ -// public fun > boxing( -// field: F, -// vararg shape: Int, -// bufferFactory: BufferFactory = Buffer.Companion::boxing, -// ): BufferedNDField = BufferedNDField(shape, field, bufferFactory) -// -// /** -// * Create a most suitable implementation for nd-field using reified class. -// */ -// @Suppress("UNCHECKED_CAST") -// public inline fun > auto(field: F, vararg shape: Int): NDField = -// when { -// T::class == Double::class -> real(*shape) as NDField -// T::class == Complex::class -> complex(*shape) as BufferedNDField -// else -> BoxingNDField(shape, field, Buffer.Companion::auto) -// } -// } } + +public interface FieldND> : Field>, FieldOpsND, RingND, WithShape { + override val one: StructureND get() = produce(shape) { elementAlgebra.one } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt index ae72f3689..c94988eef 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt @@ -3,145 +3,177 @@ * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ +@file:OptIn(UnstableKMathAPI::class) + package space.kscience.kmath.nd import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.* -import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract -import kotlin.jvm.JvmName public interface BufferAlgebraND> : AlgebraND { - public val strides: Strides - public val bufferFactory: BufferFactory + public val indexerBuilder: (IntArray) -> ShapeIndex + public val bufferAlgebra: BufferAlgebra + override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra - override fun produce(initializer: A.(IntArray) -> T): BufferND = BufferND( - strides, - bufferFactory(strides.linearSize) { offset -> - elementContext.initializer(strides.index(offset)) + override fun produce(shape: Shape, initializer: A.(IntArray) -> T): BufferND { + val indexer = indexerBuilder(shape) + return BufferND( + indexer, + bufferAlgebra.buffer(indexer.linearSize) { offset -> + elementAlgebra.initializer(indexer.index(offset)) + } + ) + } + + public fun StructureND.toBufferND(): BufferND = when (this) { + is BufferND -> this + else -> { + val indexer = indexerBuilder(shape) + BufferND(indexer, bufferAlgebra.buffer(indexer.linearSize) { offset -> get(indexer.index(offset)) }) + } + } + + override fun StructureND.map(transform: A.(T) -> T): BufferND = mapInline(toBufferND(), transform) + + override fun StructureND.mapIndexed(transform: A.(index: IntArray, T) -> T): BufferND = + mapIndexedInline(toBufferND(), transform) + + override fun zip(left: StructureND, right: StructureND, transform: A.(T, T) -> T): BufferND = + zipInline(left.toBufferND(), right.toBufferND(), transform) + + public companion object { + public val defaultIndexerBuilder: (IntArray) -> ShapeIndex = DefaultStrides.Companion::invoke + } +} + +public inline fun > BufferAlgebraND.mapInline( + arg: BufferND, + crossinline transform: A.(T) -> T +): BufferND { + val indexes = arg.indexes + return BufferND(indexes, bufferAlgebra.mapInline(arg.buffer, transform)) +} + +internal inline fun > BufferAlgebraND.mapIndexedInline( + arg: BufferND, + crossinline transform: A.(index: IntArray, arg: T) -> T +): BufferND { + val indexes = arg.indexes + return BufferND( + indexes, + bufferAlgebra.mapIndexedInline(arg.buffer) { offset, value -> + transform(indexes.index(offset), value) } ) - - public val StructureND.buffer: Buffer - get() = when { - !shape.contentEquals(this@BufferAlgebraND.shape) -> throw ShapeMismatchException( - this@BufferAlgebraND.shape, - shape - ) - this is BufferND && this.strides == this@BufferAlgebraND.strides -> this.buffer - else -> bufferFactory(strides.linearSize) { offset -> get(strides.index(offset)) } - } - - override fun StructureND.map(transform: A.(T) -> T): BufferND { - val buffer = bufferFactory(strides.linearSize) { offset -> - elementContext.transform(buffer[offset]) - } - return BufferND(strides, buffer) - } - - override fun StructureND.mapIndexed(transform: A.(index: IntArray, T) -> T): BufferND { - val buffer = bufferFactory(strides.linearSize) { offset -> - elementContext.transform( - strides.index(offset), - buffer[offset] - ) - } - return BufferND(strides, buffer) - } - - override fun combine(a: StructureND, b: StructureND, transform: A.(T, T) -> T): BufferND { - val buffer = bufferFactory(strides.linearSize) { offset -> - elementContext.transform(a.buffer[offset], b.buffer[offset]) - } - return BufferND(strides, buffer) - } } -public open class BufferedGroupND>( - final override val shape: IntArray, - final override val elementContext: A, - final override val bufferFactory: BufferFactory, -) : GroupND, BufferAlgebraND { - override val strides: Strides = DefaultStrides(shape) - override val zero: BufferND by lazy { produce { zero } } - override fun StructureND.unaryMinus(): StructureND = produce { -get(it) } +internal inline fun > BufferAlgebraND.zipInline( + l: BufferND, + r: BufferND, + crossinline block: A.(l: T, r: T) -> T +): BufferND { + require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" } + val indexes = l.indexes + return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block)) } -public open class BufferedRingND>( - shape: IntArray, - elementContext: R, - bufferFactory: BufferFactory, -) : BufferedGroupND(shape, elementContext, bufferFactory), RingND { - override val one: BufferND by lazy { produce { one } } +public open class BufferedGroupNDOps>( + override val bufferAlgebra: BufferAlgebra, + override val indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder +) : GroupOpsND, BufferAlgebraND { + override fun StructureND.unaryMinus(): StructureND = map { -it } } -public open class BufferedFieldND>( - shape: IntArray, - elementContext: R, - bufferFactory: BufferFactory, -) : BufferedRingND(shape, elementContext, bufferFactory), FieldND { +public open class BufferedRingOpsND>( + bufferAlgebra: BufferAlgebra, + indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder +) : BufferedGroupNDOps(bufferAlgebra, indexerBuilder), RingOpsND + +public open class BufferedFieldOpsND>( + bufferAlgebra: BufferAlgebra, + indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder +) : BufferedRingOpsND(bufferAlgebra, indexerBuilder), FieldOpsND { + + public constructor( + elementAlgebra: A, + bufferFactory: BufferFactory, + indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder + ) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder) override fun scale(a: StructureND, value: Double): StructureND = a.map { it * value } } -// group factories -public fun > A.ndAlgebra( - bufferFactory: BufferFactory, - vararg shape: Int, -): BufferedGroupND = BufferedGroupND(shape, this, bufferFactory) +public val > BufferAlgebra.nd: BufferedGroupNDOps get() = BufferedGroupNDOps(this) +public val > BufferAlgebra.nd: BufferedRingOpsND get() = BufferedRingOpsND(this) +public val > BufferAlgebra.nd: BufferedFieldOpsND get() = BufferedFieldOpsND(this) -@JvmName("withNdGroup") -public inline fun , R> A.withNdAlgebra( - noinline bufferFactory: BufferFactory, - vararg shape: Int, - action: BufferedGroupND.() -> R, -): R { - contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } - return ndAlgebra(bufferFactory, *shape).run(action) -} -//ring factories -public fun > A.ndAlgebra( - bufferFactory: BufferFactory, +public fun > BufferAlgebraND.produce( vararg shape: Int, -): BufferedRingND = BufferedRingND(shape, this, bufferFactory) + initializer: A.(IntArray) -> T +): BufferND = produce(shape, initializer) -@JvmName("withNdRing") -public inline fun , R> A.withNdAlgebra( - noinline bufferFactory: BufferFactory, - vararg shape: Int, - action: BufferedRingND.() -> R, -): R { - contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } - return ndAlgebra(bufferFactory, *shape).run(action) -} +public fun , A> A.produce( + initializer: EA.(IntArray) -> T +): BufferND where A : BufferAlgebraND, A : WithShape = produce(shape, initializer) -//field factories -public fun > A.ndAlgebra( - bufferFactory: BufferFactory, - vararg shape: Int, -): BufferedFieldND = BufferedFieldND(shape, this, bufferFactory) +//// group factories +//public fun > A.ndAlgebra( +// bufferAlgebra: BufferAlgebra, +// vararg shape: Int, +//): BufferedGroupNDOps = BufferedGroupNDOps(bufferAlgebra) +// +//@JvmName("withNdGroup") +//public inline fun , R> A.withNdAlgebra( +// noinline bufferFactory: BufferFactory, +// vararg shape: Int, +// action: BufferedGroupNDOps.() -> R, +//): R { +// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } +// return ndAlgebra(bufferFactory, *shape).run(action) +//} -/** - * Create a [FieldND] for this [Field] inferring proper buffer factory from the type - */ -@UnstableKMathAPI -@Suppress("UNCHECKED_CAST") -public inline fun > A.autoNdAlgebra( - vararg shape: Int, -): FieldND = when (this) { - DoubleField -> DoubleFieldND(shape) as FieldND - else -> BufferedFieldND(shape, this, Buffer.Companion::auto) -} - -@JvmName("withNdField") -public inline fun , R> A.withNdAlgebra( - noinline bufferFactory: BufferFactory, - vararg shape: Int, - action: BufferedFieldND.() -> R, -): R { - contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } - return ndAlgebra(bufferFactory, *shape).run(action) -} \ No newline at end of file +////ring factories +//public fun > A.ndAlgebra( +// bufferFactory: BufferFactory, +// vararg shape: Int, +//): BufferedRingNDOps = BufferedRingNDOps(shape, this, bufferFactory) +// +//@JvmName("withNdRing") +//public inline fun , R> A.withNdAlgebra( +// noinline bufferFactory: BufferFactory, +// vararg shape: Int, +// action: BufferedRingNDOps.() -> R, +//): R { +// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } +// return ndAlgebra(bufferFactory, *shape).run(action) +//} +// +////field factories +//public fun > A.ndAlgebra( +// bufferFactory: BufferFactory, +// vararg shape: Int, +//): BufferedFieldNDOps = BufferedFieldNDOps(shape, this, bufferFactory) +// +///** +// * Create a [FieldND] for this [Field] inferring proper buffer factory from the type +// */ +//@UnstableKMathAPI +//@Suppress("UNCHECKED_CAST") +//public inline fun > A.autoNdAlgebra( +// vararg shape: Int, +//): FieldND = when (this) { +// DoubleField -> DoubleFieldND(shape) as FieldND +// else -> BufferedFieldNDOps(shape, this, Buffer.Companion::auto) +//} +// +//@JvmName("withNdField") +//public inline fun , R> A.withNdAlgebra( +// noinline bufferFactory: BufferFactory, +// vararg shape: Int, +// action: BufferedFieldNDOps.() -> R, +//): R { +// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } +// return ndAlgebra(bufferFactory, *shape).run(action) +//} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt index 694e0ceae..c17632101 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt @@ -15,26 +15,20 @@ import space.kscience.kmath.structures.MutableBufferFactory * Represents [StructureND] over [Buffer]. * * @param T the type of items. - * @param strides The strides to access elements of [Buffer] by linear indices. + * @param indexes The strides to access elements of [Buffer] by linear indices. * @param buffer The underlying buffer. */ public open class BufferND( - public val strides: Strides, - public val buffer: Buffer, + public val indexes: ShapeIndex, + public open val buffer: Buffer, ) : StructureND { - init { - if (strides.linearSize != buffer.size) { - error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}") - } - } + override operator fun get(index: IntArray): T = buffer[indexes.offset(index)] - override operator fun get(index: IntArray): T = buffer[strides.offset(index)] - - override val shape: IntArray get() = strides.shape + override val shape: IntArray get() = indexes.shape @PerformancePitfall - override fun elements(): Sequence> = strides.indices().map { + override fun elements(): Sequence> = indexes.indices().map { it to this[it] } @@ -49,7 +43,7 @@ public inline fun StructureND.mapToBuffer( crossinline transform: (T) -> R, ): BufferND { return if (this is BufferND) - BufferND(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) }) + BufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) }) else { val strides = DefaultStrides(shape) BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) @@ -61,14 +55,14 @@ public inline fun StructureND.mapToBuffer( * * @param T the type of items. * @param strides The strides to access elements of [MutableBuffer] by linear indices. - * @param mutableBuffer The underlying buffer. + * @param buffer The underlying buffer. */ public class MutableBufferND( - strides: Strides, - public val mutableBuffer: MutableBuffer, -) : MutableStructureND, BufferND(strides, mutableBuffer) { + strides: ShapeIndex, + override val buffer: MutableBuffer, +) : MutableStructureND, BufferND(strides, buffer) { override fun set(index: IntArray, value: T) { - mutableBuffer[strides.offset(index)] = value + buffer[indexes.offset(index)] = value } } @@ -80,7 +74,7 @@ public inline fun MutableStructureND.mapToMutableBuffer( crossinline transform: (T) -> R, ): MutableBufferND { return if (this is MutableBufferND) - MutableBufferND(this.strides, factory.invoke(strides.linearSize) { transform(mutableBuffer[it]) }) + MutableBufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) }) else { val strides = DefaultStrides(shape) MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt index 0c7d4d5d1..1502a6fd0 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt @@ -6,108 +6,158 @@ package space.kscience.kmath.nd import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.NumbersAddOperations -import space.kscience.kmath.operations.ScaleOperations +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.DoubleBuffer import kotlin.contracts.InvocationKind import kotlin.contracts.contract +import kotlin.math.pow + +public class DoubleBufferND( + indexes: ShapeIndex, + override val buffer: DoubleBuffer, +) : BufferND(indexes, buffer) + + +public sealed class DoubleFieldOpsND : BufferedFieldOpsND(DoubleField.bufferAlgebra), + ScaleOperations>, ExtendedFieldOps> { + + override fun StructureND.toBufferND(): DoubleBufferND = when (this) { + is DoubleBufferND -> this + else -> { + val indexer = indexerBuilder(shape) + DoubleBufferND(indexer, DoubleBuffer(indexer.linearSize) { offset -> get(indexer.index(offset)) }) + } + } + + private inline fun mapInline( + arg: DoubleBufferND, + transform: (Double) -> Double + ): DoubleBufferND { + val indexes = arg.indexes + val array = arg.buffer.array + return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { transform(array[it]) }) + } + + private inline fun zipInline( + l: DoubleBufferND, + r: DoubleBufferND, + block: (l: Double, r: Double) -> Double + ): DoubleBufferND { + require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" } + val indexes = l.indexes + val lArray = l.buffer.array + val rArray = r.buffer.array + return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { block(lArray[it], rArray[it]) }) + } + + override fun StructureND.map(transform: DoubleField.(Double) -> Double): BufferND = + mapInline(toBufferND()) { DoubleField.transform(it) } + + + override fun zip( + left: StructureND, + right: StructureND, + transform: DoubleField.(Double, Double) -> Double + ): BufferND = zipInline(left.toBufferND(), right.toBufferND()) { l, r -> DoubleField.transform(l, r) } + + override fun produce(shape: Shape, initializer: DoubleField.(IntArray) -> Double): DoubleBufferND { + val indexer = indexerBuilder(shape) + return DoubleBufferND( + indexer, + DoubleBuffer(indexer.linearSize) { offset -> + elementAlgebra.initializer(indexer.index(offset)) + } + ) + } + + override fun add(left: StructureND, right: StructureND): DoubleBufferND = + zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l + r } + + override fun multiply(left: StructureND, right: StructureND): DoubleBufferND = + zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l * r } + + override fun StructureND.unaryMinus(): DoubleBufferND = mapInline(toBufferND()) { -it } + + override fun StructureND.div(other: StructureND): DoubleBufferND = + zipInline(toBufferND(), other.toBufferND()) { l, r -> l / r } + + override fun StructureND.plus(arg: Double): DoubleBufferND = mapInline(toBufferND()) { it + arg } + + override fun StructureND.minus(arg: Double): StructureND = mapInline(toBufferND()) { it - arg } + + override fun Double.plus(arg: StructureND): StructureND = arg + this + + override fun Double.minus(arg: StructureND): StructureND = mapInline(arg.toBufferND()) { this - it } + + override fun scale(a: StructureND, value: Double): DoubleBufferND = + mapInline(a.toBufferND()) { it * value } + + override fun power(arg: StructureND, pow: Number): DoubleBufferND = + mapInline(arg.toBufferND()) { it.pow(pow.toDouble()) } + + override fun exp(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.exp(it) } + + override fun ln(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.ln(it) } + + override fun sin(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.sin(it) } + + override fun cos(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.cos(it) } + + override fun tan(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.tan(it) } + + override fun asin(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.asin(it) } + + override fun acos(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.acos(it) } + + override fun atan(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.atan(it) } + + override fun sinh(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.sinh(it) } + + override fun cosh(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.cosh(it) } + + override fun tanh(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.tanh(it) } + + override fun asinh(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.asinh(it) } + + override fun acosh(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.acosh(it) } + + override fun atanh(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.atanh(it) } + + public companion object : DoubleFieldOpsND() +} @OptIn(UnstableKMathAPI::class) -public class DoubleFieldND( - shape: IntArray, -) : BufferedFieldND(shape, DoubleField, ::DoubleBuffer), - NumbersAddOperations>, - ScaleOperations>, - ExtendedField> { +public class DoubleFieldND(override val shape: Shape) : + DoubleFieldOpsND(), FieldND, NumbersAddOps> { - override val zero: BufferND by lazy { produce { zero } } - override val one: BufferND by lazy { produce { one } } - - override fun number(value: Number): BufferND { + override fun number(value: Number): DoubleBufferND { val d = value.toDouble() // minimize conversions - return produce { d } + return produce(shape) { d } } - - override val StructureND.buffer: DoubleBuffer - get() = when { - !shape.contentEquals(this@DoubleFieldND.shape) -> throw ShapeMismatchException( - this@DoubleFieldND.shape, - shape - ) - this is BufferND && this.strides == this@DoubleFieldND.strides -> this.buffer as DoubleBuffer - else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) } - } - - @Suppress("OVERRIDE_BY_INLINE") - override inline fun StructureND.map( - transform: DoubleField.(Double) -> Double, - ): BufferND { - val buffer = DoubleBuffer(strides.linearSize) { offset -> DoubleField.transform(buffer.array[offset]) } - return BufferND(strides, buffer) - } - - @Suppress("OVERRIDE_BY_INLINE") - override inline fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND { - val array = DoubleArray(strides.linearSize) { offset -> - val index = strides.index(offset) - DoubleField.initializer(index) - } - return BufferND(strides, DoubleBuffer(array)) - } - - @Suppress("OVERRIDE_BY_INLINE") - override inline fun StructureND.mapIndexed( - transform: DoubleField.(index: IntArray, Double) -> Double, - ): BufferND = BufferND( - strides, - buffer = DoubleBuffer(strides.linearSize) { offset -> - DoubleField.transform( - strides.index(offset), - buffer.array[offset] - ) - }) - - @Suppress("OVERRIDE_BY_INLINE") - override inline fun combine( - a: StructureND, - b: StructureND, - transform: DoubleField.(Double, Double) -> Double, - ): BufferND { - val buffer = DoubleBuffer(strides.linearSize) { offset -> - DoubleField.transform(a.buffer.array[offset], b.buffer.array[offset]) - } - return BufferND(strides, buffer) - } - - override fun scale(a: StructureND, value: Double): StructureND = a.map { it * value } - - override fun power(arg: StructureND, pow: Number): BufferND = arg.map { power(it, pow) } - - override fun exp(arg: StructureND): BufferND = arg.map { exp(it) } - override fun ln(arg: StructureND): BufferND = arg.map { ln(it) } - - override fun sin(arg: StructureND): BufferND = arg.map { sin(it) } - override fun cos(arg: StructureND): BufferND = arg.map { cos(it) } - override fun tan(arg: StructureND): BufferND = arg.map { tan(it) } - override fun asin(arg: StructureND): BufferND = arg.map { asin(it) } - override fun acos(arg: StructureND): BufferND = arg.map { acos(it) } - override fun atan(arg: StructureND): BufferND = arg.map { atan(it) } - - override fun sinh(arg: StructureND): BufferND = arg.map { sinh(it) } - override fun cosh(arg: StructureND): BufferND = arg.map { cosh(it) } - override fun tanh(arg: StructureND): BufferND = arg.map { tanh(it) } - override fun asinh(arg: StructureND): BufferND = arg.map { asinh(it) } - override fun acosh(arg: StructureND): BufferND = arg.map { acosh(it) } - override fun atanh(arg: StructureND): BufferND = arg.map { atanh(it) } } +public val DoubleField.ndAlgebra: DoubleFieldOpsND get() = DoubleFieldOpsND + public fun DoubleField.ndAlgebra(vararg shape: Int): DoubleFieldND = DoubleFieldND(shape) /** * Produce a context for n-dimensional operations inside this real field */ +@UnstableKMathAPI public inline fun DoubleField.withNdAlgebra(vararg shape: Int, action: DoubleFieldND.() -> R): R { contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } return DoubleFieldND(shape).run(action) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeIndex.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeIndex.kt new file mode 100644 index 000000000..bdbae70c2 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeIndex.kt @@ -0,0 +1,120 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.nd + +import kotlin.native.concurrent.ThreadLocal + +/** + * A converter from linear index to multivariate index + */ +public interface ShapeIndex{ + public val shape: Shape + + /** + * Get linear index from multidimensional index + */ + public fun offset(index: IntArray): Int + + /** + * Get multidimensional from linear + */ + public fun index(offset: Int): IntArray + + /** + * The size of linear buffer to accommodate all elements of ND-structure corresponding to strides + */ + public val linearSize: Int + + // TODO introduce a fast way to calculate index of the next element? + + /** + * Iterate over ND indices in a natural order + */ + public fun indices(): Sequence + + override fun equals(other: Any?): Boolean + override fun hashCode(): Int +} + +/** + * Linear transformation of indexes + */ +public abstract class Strides: ShapeIndex { + /** + * Array strides + */ + public abstract val strides: IntArray + + public override fun offset(index: IntArray): Int = index.mapIndexed { i, value -> + if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})") + value * strides[i] + }.sum() + + // TODO introduce a fast way to calculate index of the next element? + + /** + * Iterate over ND indices in a natural order + */ + public override fun indices(): Sequence = (0 until linearSize).asSequence().map(::index) +} + +/** + * Simple implementation of [Strides]. + */ +public class DefaultStrides private constructor(override val shape: IntArray) : Strides() { + override val linearSize: Int get() = strides[shape.size] + + /** + * Strides for memory access + */ + override val strides: IntArray by lazy { + sequence { + var current = 1 + yield(1) + + shape.forEach { + current *= it + yield(current) + } + }.toList().toIntArray() + } + + override fun index(offset: Int): IntArray { + val res = IntArray(shape.size) + var current = offset + var strideIndex = strides.size - 2 + + while (strideIndex >= 0) { + res[strideIndex] = (current / strides[strideIndex]) + current %= strides[strideIndex] + strideIndex-- + } + + return res + } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is DefaultStrides) return false + if (!shape.contentEquals(other.shape)) return false + return true + } + + override fun hashCode(): Int = shape.contentHashCode() + + @ThreadLocal + public companion object { + //private val defaultStridesCache = HashMap() + + /** + * Cached builder for default strides + */ + public operator fun invoke(shape: IntArray): Strides = DefaultStrides(shape) + //defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) } + + //TODO fix cache + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShortRingND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShortRingND.kt index b56bef230..65c1f71b4 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShortRingND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShortRingND.kt @@ -6,34 +6,27 @@ package space.kscience.kmath.nd import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.NumbersAddOperations +import space.kscience.kmath.operations.NumbersAddOps import space.kscience.kmath.operations.ShortRing -import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.ShortBuffer +import space.kscience.kmath.operations.bufferAlgebra import kotlin.contracts.InvocationKind import kotlin.contracts.contract +public sealed class ShortRingOpsND : BufferedRingOpsND(ShortRing.bufferAlgebra) { + public companion object : ShortRingOpsND() +} + @OptIn(UnstableKMathAPI::class) public class ShortRingND( - shape: IntArray, -) : BufferedRingND(shape, ShortRing, Buffer.Companion::auto), - NumbersAddOperations> { - - override val zero: BufferND by lazy { produce { zero } } - override val one: BufferND by lazy { produce { one } } + override val shape: Shape +) : ShortRingOpsND(), RingND, NumbersAddOps> { override fun number(value: Number): BufferND { val d = value.toShort() // minimize conversions - return produce { d } + return produce(shape) { d } } } -/** - * Fast element production using function inlining. - */ -public inline fun BufferedRingND.produceInline(crossinline initializer: ShortRing.(Int) -> Short): BufferND = - BufferND(strides, ShortBuffer(ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) })) - public inline fun ShortRing.withNdAlgebra(vararg shape: Int, action: ShortRingND.() -> R): R { contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } return ShortRingND(shape).run(action) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt index 6123336ba..611d2724f 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt @@ -15,7 +15,6 @@ import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory import kotlin.jvm.JvmName import kotlin.math.abs -import kotlin.native.concurrent.ThreadLocal import kotlin.reflect.KClass public interface StructureFeature : Feature @@ -72,7 +71,7 @@ public interface StructureND : Featured { if (st1 === st2) return true // fast comparison of buffers if possible - if (st1 is BufferND && st2 is BufferND && st1.strides == st2.strides) + if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes) return Buffer.contentEquals(st1.buffer, st2.buffer) //element by element comparison if it could not be avoided @@ -88,7 +87,7 @@ public interface StructureND : Featured { if (st1 === st2) return true // fast comparison of buffers if possible - if (st1 is BufferND && st2 is BufferND && st1.strides == st2.strides) + if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes) return Buffer.contentEquals(st1.buffer, st2.buffer) //element by element comparison if it could not be avoided @@ -187,11 +186,11 @@ public fun > LinearSpace>.contentEquals( * Indicates whether some [StructureND] is equal to another one with [absoluteTolerance]. */ @PerformancePitfall -public fun > GroupND>.contentEquals( +public fun > GroupOpsND>.contentEquals( st1: StructureND, st2: StructureND, absoluteTolerance: T, -): Boolean = st1.elements().all { (index, value) -> elementContext { (value - st2[index]) } < absoluteTolerance } +): Boolean = st1.elements().all { (index, value) -> elementAlgebra { (value - st2[index]) } < absoluteTolerance } /** * Indicates whether some [StructureND] is equal to another one with [absoluteTolerance]. @@ -231,107 +230,10 @@ public interface MutableStructureND : StructureND { * Transform a structure element-by element in place. */ @OptIn(PerformancePitfall::class) -public inline fun MutableStructureND.mapInPlace(action: (IntArray, T) -> T): Unit = +public inline fun MutableStructureND.mapInPlace(action: (index: IntArray, t: T) -> T): Unit = elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) } -/** - * A way to convert ND indices to linear one and back. - */ -public interface Strides { - /** - * Shape of NDStructure - */ - public val shape: IntArray - - /** - * Array strides - */ - public val strides: IntArray - - /** - * Get linear index from multidimensional index - */ - public fun offset(index: IntArray): Int = index.mapIndexed { i, value -> - if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})") - value * strides[i] - }.sum() - - /** - * Get multidimensional from linear - */ - public fun index(offset: Int): IntArray - - /** - * The size of linear buffer to accommodate all elements of ND-structure corresponding to strides - */ - public val linearSize: Int - - // TODO introduce a fast way to calculate index of the next element? - - /** - * Iterate over ND indices in a natural order - */ - public fun indices(): Sequence = (0 until linearSize).asSequence().map(::index) -} - -/** - * Simple implementation of [Strides]. - */ -public class DefaultStrides private constructor(override val shape: IntArray) : Strides { - override val linearSize: Int - get() = strides[shape.size] - - /** - * Strides for memory access - */ - override val strides: IntArray by lazy { - sequence { - var current = 1 - yield(1) - - shape.forEach { - current *= it - yield(current) - } - }.toList().toIntArray() - } - - override fun index(offset: Int): IntArray { - val res = IntArray(shape.size) - var current = offset - var strideIndex = strides.size - 2 - - while (strideIndex >= 0) { - res[strideIndex] = (current / strides[strideIndex]) - current %= strides[strideIndex] - strideIndex-- - } - - return res - } - - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other !is DefaultStrides) return false - if (!shape.contentEquals(other.shape)) return false - return true - } - - override fun hashCode(): Int = shape.contentHashCode() - - @ThreadLocal - public companion object { - private val defaultStridesCache = HashMap() - - /** - * Cached builder for default strides - */ - public operator fun invoke(shape: IntArray): Strides = - defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) } - } -} - -public inline fun StructureND.combine( +public inline fun StructureND.zip( struct: StructureND, crossinline block: (T, T) -> T, ): StructureND { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/algebraNDExtentions.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/algebraNDExtentions.kt new file mode 100644 index 000000000..7bc18a4dd --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/algebraNDExtentions.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.nd + +import space.kscience.kmath.operations.Algebra +import space.kscience.kmath.operations.Group +import space.kscience.kmath.operations.Ring +import kotlin.jvm.JvmName + + +public fun > AlgebraND.produce( + shapeFirst: Int, + vararg shapeRest: Int, + initializer: A.(IntArray) -> T +): StructureND = produce(Shape(shapeFirst, *shapeRest), initializer) + +public fun > AlgebraND.zero(shape: Shape): StructureND = produce(shape) { zero } + +@JvmName("zeroVarArg") +public fun > AlgebraND.zero( + shapeFirst: Int, + vararg shapeRest: Int, +): StructureND = produce(shapeFirst, *shapeRest) { zero } + +public fun > AlgebraND.one(shape: Shape): StructureND = produce(shape) { one } + +@JvmName("oneVarArg") +public fun > AlgebraND.one( + shapeFirst: Int, + vararg shapeRest: Int, +): StructureND = produce(shapeFirst, *shapeRest) { one } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt index daff58d9a..d0b0c0b73 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt @@ -117,15 +117,15 @@ public inline operator fun , R> A.invoke(block: A.() -> R): R = r * * @param T the type of element of this semispace. */ -public interface GroupOperations : Algebra { +public interface GroupOps : Algebra { /** * Addition of two elements. * - * @param a the augend. - * @param b the addend. + * @param left the augend. + * @param right the addend. * @return the sum. */ - public fun add(a: T, b: T): T + public fun add(left: T, right: T): T // Operations to be performed in this context. Could be moved to extensions in case of KEEP-176. @@ -149,20 +149,20 @@ public interface GroupOperations : Algebra { * Addition of two elements. * * @receiver the augend. - * @param b the addend. + * @param other the addend. * @return the sum. */ - public operator fun T.plus(b: T): T = add(this, b) + public operator fun T.plus(other: T): T = add(this, other) /** * Subtraction of two elements. * * @receiver the minuend. - * @param b the subtrahend. + * @param other the subtrahend. * @return the difference. */ - public operator fun T.minus(b: T): T = add(this, -b) - + public operator fun T.minus(other: T): T = add(this, -other) + // Dynamic dispatch of operations override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) { PLUS_OPERATION -> { arg -> +arg } MINUS_OPERATION -> { arg -> -arg } @@ -193,7 +193,7 @@ public interface GroupOperations : Algebra { * * @param T the type of element of this semispace. */ -public interface Group : GroupOperations { +public interface Group : GroupOps { /** * The neutral element of addition. */ @@ -206,22 +206,22 @@ public interface Group : GroupOperations { * * @param T the type of element of this semiring. */ -public interface RingOperations : GroupOperations { +public interface RingOps : GroupOps { /** * Multiplies two elements. * - * @param a the multiplier. - * @param b the multiplicand. + * @param left the multiplier. + * @param right the multiplicand. */ - public fun multiply(a: T, b: T): T + public fun multiply(left: T, right: T): T /** * Multiplies this element by scalar. * * @receiver the multiplier. - * @param b the multiplicand. + * @param other the multiplicand. */ - public operator fun T.times(b: T): T = multiply(this, b) + public operator fun T.times(other: T): T = multiply(this, other) override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { TIMES_OPERATION -> ::multiply @@ -242,7 +242,7 @@ public interface RingOperations : GroupOperations { * * @param T the type of element of this ring. */ -public interface Ring : Group, RingOperations { +public interface Ring : Group, RingOps { /** * The neutral element of multiplication */ @@ -256,24 +256,24 @@ public interface Ring : Group, RingOperations { * * @param T the type of element of this semifield. */ -public interface FieldOperations : RingOperations { +public interface FieldOps : RingOps { /** * Division of two elements. * - * @param a the dividend. - * @param b the divisor. + * @param left the dividend. + * @param right the divisor. * @return the quotient. */ - public fun divide(a: T, b: T): T + public fun divide(left: T, right: T): T /** * Division of two elements. * * @receiver the dividend. - * @param b the divisor. + * @param other the divisor. * @return the quotient. */ - public operator fun T.div(b: T): T = divide(this, b) + public operator fun T.div(other: T): T = divide(this, other) override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { DIV_OPERATION -> ::divide @@ -295,6 +295,6 @@ public interface FieldOperations : RingOperations { * * @param T the type of element of this field. */ -public interface Field : Ring, FieldOperations, ScaleOperations, NumericAlgebra { +public interface Field : Ring, FieldOps, ScaleOperations, NumericAlgebra { override fun number(value: Number): T = scale(one, value.toDouble()) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt index 82754e43d..5a713049e 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt @@ -6,7 +6,7 @@ package space.kscience.kmath.operations import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.nd.BufferedRingND +import space.kscience.kmath.nd.BufferedRingOpsND import space.kscience.kmath.operations.BigInt.Companion.BASE import space.kscience.kmath.operations.BigInt.Companion.BASE_SIZE import space.kscience.kmath.structures.Buffer @@ -26,7 +26,7 @@ private typealias TBase = ULong * @author Peter Klimai */ @OptIn(UnstableKMathAPI::class) -public object BigIntField : Field, NumbersAddOperations, ScaleOperations { +public object BigIntField : Field, NumbersAddOps, ScaleOperations { override val zero: BigInt = BigInt.ZERO override val one: BigInt = BigInt.ONE @@ -34,10 +34,10 @@ public object BigIntField : Field, NumbersAddOperations, ScaleOp @Suppress("EXTENSION_SHADOWED_BY_MEMBER") override fun BigInt.unaryMinus(): BigInt = -this - override fun add(a: BigInt, b: BigInt): BigInt = a.plus(b) + override fun add(left: BigInt, right: BigInt): BigInt = left.plus(right) override fun scale(a: BigInt, value: Double): BigInt = a.times(number(value)) - override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b) - override fun divide(a: BigInt, b: BigInt): BigInt = a.div(b) + override fun multiply(left: BigInt, right: BigInt): BigInt = left.times(right) + override fun divide(left: BigInt, right: BigInt): BigInt = left.div(right) public operator fun String.unaryPlus(): BigInt = this.parseBigInteger() ?: error("Can't parse $this as big integer") public operator fun String.unaryMinus(): BigInt = @@ -542,5 +542,5 @@ public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) - public inline fun BigInt.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer = Buffer.boxing(size, initializer) -public fun BigIntField.nd(vararg shape: Int): BufferedRingND = - BufferedRingND(shape, BigIntField, BigInt::buffer) +public val BigIntField.nd: BufferedRingOpsND + get() = BufferedRingOpsND(BufferRingOps(BigIntField, BigInt::buffer)) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt index e82b62c1b..bc05f3904 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt @@ -5,32 +5,34 @@ package space.kscience.kmath.operations -import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.DoubleBuffer +import space.kscience.kmath.structures.ShortBuffer + +public interface WithSize { + public val size: Int +} /** * An algebra over [Buffer] */ -@UnstableKMathAPI -public interface BufferAlgebra> : Algebra> { - public val bufferFactory: BufferFactory +public interface BufferAlgebra> : Algebra> { public val elementAlgebra: A - public val size: Int + public val bufferFactory: BufferFactory - public fun buffer(vararg elements: T): Buffer { + public fun buffer(size: Int, vararg elements: T): Buffer { require(elements.size == size) { "Expected $size elements but found ${elements.size}" } return bufferFactory(size) { elements[it] } } //TODO move to multi-receiver inline extension - public fun Buffer.map(block: (T) -> T): Buffer = bufferFactory(size) { block(get(it)) } + public fun Buffer.map(block: A.(T) -> T): Buffer = mapInline(this, block) - public fun Buffer.zip(other: Buffer, block: (left: T, right: T) -> T): Buffer { - require(size == other.size) { "Incompatible buffer sizes. left: $size, right: ${other.size}" } - return bufferFactory(size) { block(this[it], other[it]) } - } + public fun Buffer.mapIndexed(block: A.(index: Int, arg: T) -> T): Buffer = mapIndexedInline(this, block) + + public fun Buffer.zip(other: Buffer, block: A.(left: T, right: T) -> T): Buffer = + zipInline(this, other, block) override fun unaryOperationFunction(operation: String): (arg: Buffer) -> Buffer { val operationFunction = elementAlgebra.unaryOperationFunction(operation) @@ -45,112 +47,149 @@ public interface BufferAlgebra> : Algebra> { } } -@UnstableKMathAPI -public fun BufferField.buffer(initializer: (Int) -> T): Buffer { +/** + * Inline map + */ +public inline fun > BufferAlgebra.mapInline( + buffer: Buffer, + crossinline block: A.(T) -> T +): Buffer = bufferFactory(buffer.size) { elementAlgebra.block(buffer[it]) } + +/** + * Inline map + */ +public inline fun > BufferAlgebra.mapIndexedInline( + buffer: Buffer, + crossinline block: A.(index: Int, arg: T) -> T +): Buffer = bufferFactory(buffer.size) { elementAlgebra.block(it, buffer[it]) } + +/** + * Inline zip + */ +public inline fun > BufferAlgebra.zipInline( + l: Buffer, + r: Buffer, + crossinline block: A.(l: T, r: T) -> T +): Buffer { + require(l.size == r.size) { "Incompatible buffer sizes. left: ${l.size}, right: ${r.size}" } + return bufferFactory(l.size) { elementAlgebra.block(l[it], r[it]) } +} + +public fun BufferAlgebra.buffer(size: Int, initializer: (Int) -> T): Buffer { + return bufferFactory(size, initializer) +} + +public fun A.buffer(initializer: (Int) -> T): Buffer where A : BufferAlgebra, A : WithSize { return bufferFactory(size, initializer) } -@UnstableKMathAPI public fun > BufferAlgebra.sin(arg: Buffer): Buffer = - arg.map(elementAlgebra::sin) + mapInline(arg) { sin(it) } -@UnstableKMathAPI public fun > BufferAlgebra.cos(arg: Buffer): Buffer = - arg.map(elementAlgebra::cos) + mapInline(arg) { cos(it) } -@UnstableKMathAPI public fun > BufferAlgebra.tan(arg: Buffer): Buffer = - arg.map(elementAlgebra::tan) + mapInline(arg) { tan(it) } -@UnstableKMathAPI public fun > BufferAlgebra.asin(arg: Buffer): Buffer = - arg.map(elementAlgebra::asin) + mapInline(arg) { asin(it) } -@UnstableKMathAPI public fun > BufferAlgebra.acos(arg: Buffer): Buffer = - arg.map(elementAlgebra::acos) + mapInline(arg) { acos(it) } -@UnstableKMathAPI public fun > BufferAlgebra.atan(arg: Buffer): Buffer = - arg.map(elementAlgebra::atan) + mapInline(arg) { atan(it) } -@UnstableKMathAPI public fun > BufferAlgebra.exp(arg: Buffer): Buffer = - arg.map(elementAlgebra::exp) + mapInline(arg) { exp(it) } -@UnstableKMathAPI public fun > BufferAlgebra.ln(arg: Buffer): Buffer = - arg.map(elementAlgebra::ln) + mapInline(arg) { ln(it) } -@UnstableKMathAPI public fun > BufferAlgebra.sinh(arg: Buffer): Buffer = - arg.map(elementAlgebra::sinh) + mapInline(arg) { sinh(it) } -@UnstableKMathAPI public fun > BufferAlgebra.cosh(arg: Buffer): Buffer = - arg.map(elementAlgebra::cosh) + mapInline(arg) { cosh(it) } -@UnstableKMathAPI public fun > BufferAlgebra.tanh(arg: Buffer): Buffer = - arg.map(elementAlgebra::tanh) + mapInline(arg) { tanh(it) } -@UnstableKMathAPI public fun > BufferAlgebra.asinh(arg: Buffer): Buffer = - arg.map(elementAlgebra::asinh) + mapInline(arg) { asinh(it) } -@UnstableKMathAPI public fun > BufferAlgebra.acosh(arg: Buffer): Buffer = - arg.map(elementAlgebra::acosh) + mapInline(arg) { acosh(it) } -@UnstableKMathAPI public fun > BufferAlgebra.atanh(arg: Buffer): Buffer = - arg.map(elementAlgebra::atanh) + mapInline(arg) { atanh(it) } -@UnstableKMathAPI public fun > BufferAlgebra.pow(arg: Buffer, pow: Number): Buffer = - with(elementAlgebra) { arg.map { power(it, pow) } } + mapInline(arg) { power(it, pow) } -@UnstableKMathAPI -public class BufferField>( - override val bufferFactory: BufferFactory, +public open class BufferRingOps>( override val elementAlgebra: A, + override val bufferFactory: BufferFactory, +) : BufferAlgebra, RingOps>{ + + override fun add(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l + r } + override fun multiply(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l * r } + override fun Buffer.unaryMinus(): Buffer = map { -it } + + override fun unaryOperationFunction(operation: String): (arg: Buffer) -> Buffer = + super.unaryOperationFunction(operation) + + override fun binaryOperationFunction(operation: String): (left: Buffer, right: Buffer) -> Buffer = + super.binaryOperationFunction(operation) +} + +public val ShortRing.bufferAlgebra: BufferRingOps + get() = BufferRingOps(ShortRing, ::ShortBuffer) + +public open class BufferFieldOps>( + elementAlgebra: A, + bufferFactory: BufferFactory, +) : BufferRingOps(elementAlgebra, bufferFactory), BufferAlgebra, FieldOps>, ScaleOperations> { + + override fun add(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l + r } + override fun multiply(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l * r } + override fun divide(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l / r } + + override fun scale(a: Buffer, value: Double): Buffer = a.map { scale(it, value) } + override fun Buffer.unaryMinus(): Buffer = map { -it } + + override fun binaryOperationFunction(operation: String): (left: Buffer, right: Buffer) -> Buffer = + super.binaryOperationFunction(operation) +} + +public class BufferField>( + elementAlgebra: A, + bufferFactory: BufferFactory, override val size: Int -) : BufferAlgebra, Field> { +) : BufferFieldOps(elementAlgebra, bufferFactory), Field>, WithSize { override val zero: Buffer = bufferFactory(size) { elementAlgebra.zero } override val one: Buffer = bufferFactory(size) { elementAlgebra.one } - - - override fun add(a: Buffer, b: Buffer): Buffer = a.zip(b, elementAlgebra::add) - override fun multiply(a: Buffer, b: Buffer): Buffer = a.zip(b, elementAlgebra::multiply) - override fun divide(a: Buffer, b: Buffer): Buffer = a.zip(b, elementAlgebra::divide) - - override fun scale(a: Buffer, value: Double): Buffer = with(elementAlgebra) { a.map { scale(it, value) } } - override fun Buffer.unaryMinus(): Buffer = with(elementAlgebra) { map { -it } } - - override fun unaryOperationFunction(operation: String): (arg: Buffer) -> Buffer { - return super.unaryOperationFunction(operation) - } - - override fun binaryOperationFunction(operation: String): (left: Buffer, right: Buffer) -> Buffer { - return super.binaryOperationFunction(operation) - } } +/** + * Generate full buffer field from given buffer operations + */ +public fun > BufferFieldOps.withSize(size: Int): BufferField = + BufferField(elementAlgebra, bufferFactory, size) + //Double buffer specialization -@UnstableKMathAPI public fun BufferField.buffer(vararg elements: Number): Buffer { require(elements.size == size) { "Expected $size elements but found ${elements.size}" } return bufferFactory(size) { elements[it].toDouble() } } -@UnstableKMathAPI -public fun > A.bufferAlgebra(bufferFactory: BufferFactory, size: Int): BufferField = - BufferField(bufferFactory, this, size) +public fun > A.bufferAlgebra(bufferFactory: BufferFactory): BufferFieldOps = + BufferFieldOps(this, bufferFactory) -@UnstableKMathAPI -public fun DoubleField.bufferAlgebra(size: Int): BufferField = - BufferField(::DoubleBuffer, DoubleField, size) +public val DoubleField.bufferAlgebra: BufferFieldOps + get() = BufferFieldOps(DoubleField, ::DoubleBuffer) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferField.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferField.kt index acc2c2dc0..060ea5a7e 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferField.kt @@ -13,21 +13,21 @@ import space.kscience.kmath.structures.DoubleBuffer * * @property size the size of buffers to operate on. */ -public class DoubleBufferField(public val size: Int) : ExtendedField>, DoubleBufferOperations() { +public class DoubleBufferField(public val size: Int) : ExtendedField>, DoubleBufferOps() { override val zero: Buffer by lazy { DoubleBuffer(size) { 0.0 } } override val one: Buffer by lazy { DoubleBuffer(size) { 1.0 } } - override fun sinh(arg: Buffer): DoubleBuffer = super.sinh(arg) + override fun sinh(arg: Buffer): DoubleBuffer = super.sinh(arg) - override fun cosh(arg: Buffer): DoubleBuffer = super.cosh(arg) + override fun cosh(arg: Buffer): DoubleBuffer = super.cosh(arg) - override fun tanh(arg: Buffer): DoubleBuffer = super.tanh(arg) + override fun tanh(arg: Buffer): DoubleBuffer = super.tanh(arg) - override fun asinh(arg: Buffer): DoubleBuffer = super.asinh(arg) + override fun asinh(arg: Buffer): DoubleBuffer = super.asinh(arg) - override fun acosh(arg: Buffer): DoubleBuffer = super.acosh(arg) + override fun acosh(arg: Buffer): DoubleBuffer = super.acosh(arg) - override fun atanh(arg: Buffer): DoubleBuffer= super.atanh(arg) + override fun atanh(arg: Buffer): DoubleBuffer= super.atanh(arg) // override fun number(value: Number): Buffer = DoubleBuffer(size) { value.toDouble() } // diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOperations.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt similarity index 73% rename from kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOperations.kt rename to kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt index 50b821962..29b25aae8 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOperations.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt @@ -12,39 +12,40 @@ import space.kscience.kmath.structures.DoubleBuffer import kotlin.math.* /** - * [ExtendedFieldOperations] over [DoubleBuffer]. + * [ExtendedFieldOps] over [DoubleBuffer]. */ -public abstract class DoubleBufferOperations : ExtendedFieldOperations>, Norm, Double> { +public abstract class DoubleBufferOps : ExtendedFieldOps>, Norm, Double> { + override fun Buffer.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) { DoubleBuffer(size) { -array[it] } } else { DoubleBuffer(size) { -get(it) } } - override fun add(a: Buffer, b: Buffer): DoubleBuffer { - require(b.size == a.size) { - "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " + override fun add(left: Buffer, right: Buffer): DoubleBuffer { + require(right.size == left.size) { + "The size of the first buffer ${left.size} should be the same as for second one: ${right.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { - val aArray = a.array - val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) - } else DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] }) + return if (left is DoubleBuffer && right is DoubleBuffer) { + val aArray = left.array + val bArray = right.array + DoubleBuffer(DoubleArray(left.size) { aArray[it] + bArray[it] }) + } else DoubleBuffer(DoubleArray(left.size) { left[it] + right[it] }) } - override fun Buffer.plus(b: Buffer): DoubleBuffer = add(this, b) + override fun Buffer.plus(other: Buffer): DoubleBuffer = add(this, other) - override fun Buffer.minus(b: Buffer): DoubleBuffer { - require(b.size == this.size) { - "The size of the first buffer ${this.size} should be the same as for second one: ${b.size} " + override fun Buffer.minus(other: Buffer): DoubleBuffer { + require(other.size == this.size) { + "The size of the first buffer ${this.size} should be the same as for second one: ${other.size} " } - return if (this is DoubleBuffer && b is DoubleBuffer) { + return if (this is DoubleBuffer && other is DoubleBuffer) { val aArray = this.array - val bArray = b.array + val bArray = other.array DoubleBuffer(DoubleArray(this.size) { aArray[it] - bArray[it] }) - } else DoubleBuffer(DoubleArray(this.size) { this[it] - b[it] }) + } else DoubleBuffer(DoubleArray(this.size) { this[it] - other[it] }) } // @@ -66,29 +67,29 @@ public abstract class DoubleBufferOperations : ExtendedFieldOperations, b: Buffer): DoubleBuffer { - require(b.size == a.size) { - "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " + override fun multiply(left: Buffer, right: Buffer): DoubleBuffer { + require(right.size == left.size) { + "The size of the first buffer ${left.size} should be the same as for second one: ${right.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { - val aArray = a.array - val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) + return if (left is DoubleBuffer && right is DoubleBuffer) { + val aArray = left.array + val bArray = right.array + DoubleBuffer(DoubleArray(left.size) { aArray[it] * bArray[it] }) } else - DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] }) + DoubleBuffer(DoubleArray(left.size) { left[it] * right[it] }) } - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { - require(b.size == a.size) { - "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " + override fun divide(left: Buffer, right: Buffer): DoubleBuffer { + require(right.size == left.size) { + "The size of the first buffer ${left.size} should be the same as for second one: ${right.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { - val aArray = a.array - val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) - } else DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] }) + return if (left is DoubleBuffer && right is DoubleBuffer) { + val aArray = left.array + val bArray = right.array + DoubleBuffer(DoubleArray(left.size) { aArray[it] / bArray[it] }) + } else DoubleBuffer(DoubleArray(left.size) { left[it] / right[it] }) } override fun sin(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { @@ -185,7 +186,7 @@ public abstract class DoubleBufferOperations : ExtendedFieldOperations, Double> { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt index a3d8f5ffe..5f6848211 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt @@ -150,7 +150,7 @@ public interface ScaleOperations : Algebra { * TODO to be removed and replaced by extensions after multiple receivers are there */ @UnstableKMathAPI -public interface NumbersAddOperations : Ring, NumericAlgebra { +public interface NumbersAddOps : RingOps, NumericAlgebra { /** * Addition of element and scalar. * diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt index 4c0010bf9..1168dc6ba 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt @@ -10,8 +10,8 @@ import kotlin.math.pow as kpow /** * Advanced Number-like semifield that implements basic operations. */ -public interface ExtendedFieldOperations : - FieldOperations, +public interface ExtendedFieldOps : + FieldOps, TrigonometricOperations, PowerOperations, ExponentialOperations, @@ -35,14 +35,14 @@ public interface ExtendedFieldOperations : ExponentialOperations.ACOSH_OPERATION -> ::acosh ExponentialOperations.ASINH_OPERATION -> ::asinh ExponentialOperations.ATANH_OPERATION -> ::atanh - else -> super.unaryOperationFunction(operation) + else -> super.unaryOperationFunction(operation) } } /** * Advanced Number-like field that implements basic operations. */ -public interface ExtendedField : ExtendedFieldOperations, Field, NumericAlgebra{ +public interface ExtendedField : ExtendedFieldOps, Field, NumericAlgebra{ override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0 override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0 override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) @@ -73,10 +73,10 @@ public object DoubleField : ExtendedField, Norm, ScaleOp else -> super.binaryOperationFunction(operation) } - override inline fun add(a: Double, b: Double): Double = a + b + override inline fun add(left: Double, right: Double): Double = left + right - override inline fun multiply(a: Double, b: Double): Double = a * b - override inline fun divide(a: Double, b: Double): Double = a / b + override inline fun multiply(left: Double, right: Double): Double = left * right + override inline fun divide(left: Double, right: Double): Double = left / right override inline fun scale(a: Double, value: Double): Double = a * value @@ -102,10 +102,10 @@ public object DoubleField : ExtendedField, Norm, ScaleOp override inline fun norm(arg: Double): Double = abs(arg) override inline fun Double.unaryMinus(): Double = -this - override inline fun Double.plus(b: Double): Double = this + b - override inline fun Double.minus(b: Double): Double = this - b - override inline fun Double.times(b: Double): Double = this * b - override inline fun Double.div(b: Double): Double = this / b + override inline fun Double.plus(other: Double): Double = this + other + override inline fun Double.minus(other: Double): Double = this - other + override inline fun Double.times(other: Double): Double = this * other + override inline fun Double.div(other: Double): Double = this / other } public val Double.Companion.algebra: DoubleField get() = DoubleField @@ -126,12 +126,12 @@ public object FloatField : ExtendedField, Norm { else -> super.binaryOperationFunction(operation) } - override inline fun add(a: Float, b: Float): Float = a + b + override inline fun add(left: Float, right: Float): Float = left + right override fun scale(a: Float, value: Double): Float = a * value.toFloat() - override inline fun multiply(a: Float, b: Float): Float = a * b + override inline fun multiply(left: Float, right: Float): Float = left * right - override inline fun divide(a: Float, b: Float): Float = a / b + override inline fun divide(left: Float, right: Float): Float = left / right override inline fun sin(arg: Float): Float = kotlin.math.sin(arg) override inline fun cos(arg: Float): Float = kotlin.math.cos(arg) @@ -155,10 +155,10 @@ public object FloatField : ExtendedField, Norm { override inline fun norm(arg: Float): Float = abs(arg) override inline fun Float.unaryMinus(): Float = -this - override inline fun Float.plus(b: Float): Float = this + b - override inline fun Float.minus(b: Float): Float = this - b - override inline fun Float.times(b: Float): Float = this * b - override inline fun Float.div(b: Float): Float = this / b + override inline fun Float.plus(other: Float): Float = this + other + override inline fun Float.minus(other: Float): Float = this - other + override inline fun Float.times(other: Float): Float = this * other + override inline fun Float.div(other: Float): Float = this / other } public val Float.Companion.algebra: FloatField get() = FloatField @@ -175,14 +175,14 @@ public object IntRing : Ring, Norm, NumericAlgebra { get() = 1 override fun number(value: Number): Int = value.toInt() - override inline fun add(a: Int, b: Int): Int = a + b - override inline fun multiply(a: Int, b: Int): Int = a * b + override inline fun add(left: Int, right: Int): Int = left + right + override inline fun multiply(left: Int, right: Int): Int = left * right override inline fun norm(arg: Int): Int = abs(arg) override inline fun Int.unaryMinus(): Int = -this - override inline fun Int.plus(b: Int): Int = this + b - override inline fun Int.minus(b: Int): Int = this - b - override inline fun Int.times(b: Int): Int = this * b + override inline fun Int.plus(other: Int): Int = this + other + override inline fun Int.minus(other: Int): Int = this - other + override inline fun Int.times(other: Int): Int = this * other } public val Int.Companion.algebra: IntRing get() = IntRing @@ -199,14 +199,14 @@ public object ShortRing : Ring, Norm, NumericAlgebra get() = 1 override fun number(value: Number): Short = value.toShort() - override inline fun add(a: Short, b: Short): Short = (a + b).toShort() - override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort() + override inline fun add(left: Short, right: Short): Short = (left + right).toShort() + override inline fun multiply(left: Short, right: Short): Short = (left * right).toShort() override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() override inline fun Short.unaryMinus(): Short = (-this).toShort() - override inline fun Short.plus(b: Short): Short = (this + b).toShort() - override inline fun Short.minus(b: Short): Short = (this - b).toShort() - override inline fun Short.times(b: Short): Short = (this * b).toShort() + override inline fun Short.plus(other: Short): Short = (this + other).toShort() + override inline fun Short.minus(other: Short): Short = (this - other).toShort() + override inline fun Short.times(other: Short): Short = (this * other).toShort() } public val Short.Companion.algebra: ShortRing get() = ShortRing @@ -223,14 +223,14 @@ public object ByteRing : Ring, Norm, NumericAlgebra { get() = 1 override fun number(value: Number): Byte = value.toByte() - override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() - override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte() + override inline fun add(left: Byte, right: Byte): Byte = (left + right).toByte() + override inline fun multiply(left: Byte, right: Byte): Byte = (left * right).toByte() override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() override inline fun Byte.unaryMinus(): Byte = (-this).toByte() - override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte() - override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte() - override inline fun Byte.times(b: Byte): Byte = (this * b).toByte() + override inline fun Byte.plus(other: Byte): Byte = (this + other).toByte() + override inline fun Byte.minus(other: Byte): Byte = (this - other).toByte() + override inline fun Byte.times(other: Byte): Byte = (this * other).toByte() } public val Byte.Companion.algebra: ByteRing get() = ByteRing @@ -247,14 +247,14 @@ public object LongRing : Ring, Norm, NumericAlgebra { get() = 1L override fun number(value: Number): Long = value.toLong() - override inline fun add(a: Long, b: Long): Long = a + b - override inline fun multiply(a: Long, b: Long): Long = a * b + override inline fun add(left: Long, right: Long): Long = left + right + override inline fun multiply(left: Long, right: Long): Long = left * right override fun norm(arg: Long): Long = abs(arg) override inline fun Long.unaryMinus(): Long = (-this) - override inline fun Long.plus(b: Long): Long = (this + b) - override inline fun Long.minus(b: Long): Long = (this - b) - override inline fun Long.times(b: Long): Long = (this * b) + override inline fun Long.plus(other: Long): Long = (this + other) + override inline fun Long.minus(other: Long): Long = (this - other) + override inline fun Long.times(other: Long): Long = (this * other) } public val Long.Companion.algebra: LongRing get() = LongRing diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NDFieldTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NDFieldTest.kt index 05a67ab09..2009eb64f 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NDFieldTest.kt @@ -7,6 +7,7 @@ package space.kscience.kmath.structures import space.kscience.kmath.nd.get import space.kscience.kmath.nd.ndAlgebra +import space.kscience.kmath.nd.produce import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.invoke import space.kscience.kmath.testutils.FieldVerifier @@ -21,7 +22,7 @@ internal class NDFieldTest { @Test fun testStrides() { - val ndArray = DoubleField.ndAlgebra(10, 10).produce { (it[0] + it[1]).toDouble() } + val ndArray = DoubleField.ndAlgebra.produce(10, 10) { (it[0] + it[1]).toDouble() } assertEquals(ndArray[5, 5], 10.0) } } diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NumberNDFieldTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NumberNDFieldTest.kt index e58976f4a..907301a53 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NumberNDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NumberNDFieldTest.kt @@ -7,10 +7,7 @@ package space.kscience.kmath.structures import space.kscience.kmath.linear.linearSpace import space.kscience.kmath.misc.PerformancePitfall -import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.nd.combine -import space.kscience.kmath.nd.get -import space.kscience.kmath.nd.ndAlgebra +import space.kscience.kmath.nd.* import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.Norm import space.kscience.kmath.operations.algebra @@ -22,9 +19,9 @@ import kotlin.test.assertEquals @Suppress("UNUSED_VARIABLE") class NumberNDFieldTest { - val algebra = DoubleField.ndAlgebra(3, 3) - val array1 = algebra.produce { (i, j) -> (i + j).toDouble() } - val array2 = algebra.produce { (i, j) -> (i - j).toDouble() } + val algebra = DoubleField.ndAlgebra + val array1 = algebra.produce(3, 3) { (i, j) -> (i + j).toDouble() } + val array2 = algebra.produce(3, 3) { (i, j) -> (i - j).toDouble() } @Test fun testSum() { @@ -77,7 +74,7 @@ class NumberNDFieldTest { @Test fun combineTest() { - val division = array1.combine(array2, Double::div) + val division = array1.zip(array2, Double::div) } object L2Norm : Norm, Double> { diff --git a/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt b/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt index 69dd858c4..3a9c242fc 100644 --- a/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt +++ b/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt @@ -18,9 +18,9 @@ public object JBigIntegerField : Ring, NumericAlgebra { override val one: BigInteger get() = BigInteger.ONE override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong()) - override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b) - override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b) - override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b) + override fun add(left: BigInteger, right: BigInteger): BigInteger = left.add(right) + override operator fun BigInteger.minus(other: BigInteger): BigInteger = subtract(other) + override fun multiply(left: BigInteger, right: BigInteger): BigInteger = left.multiply(right) override operator fun BigInteger.unaryMinus(): BigInteger = negate() } @@ -39,15 +39,15 @@ public abstract class JBigDecimalFieldBase internal constructor( override val one: BigDecimal get() = BigDecimal.ONE - override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b) - override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b) + override fun add(left: BigDecimal, right: BigDecimal): BigDecimal = left.add(right) + override operator fun BigDecimal.minus(other: BigDecimal): BigDecimal = subtract(other) override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble()) override fun scale(a: BigDecimal, value: Double): BigDecimal = a.multiply(value.toBigDecimal(mathContext), mathContext) - override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext) - override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext) + override fun multiply(left: BigDecimal, right: BigDecimal): BigDecimal = left.multiply(right, mathContext) + override fun divide(left: BigDecimal, right: BigDecimal): BigDecimal = left.divide(right, mathContext) override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext) override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext) override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/flowExtra.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/flowExtra.kt index 1821daac9..1620f029c 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/flowExtra.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/flowExtra.kt @@ -10,12 +10,12 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.runningReduce import kotlinx.coroutines.flow.scan -import space.kscience.kmath.operations.GroupOperations +import space.kscience.kmath.operations.GroupOps import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.ScaleOperations import space.kscience.kmath.operations.invoke -public fun Flow.cumulativeSum(group: GroupOperations): Flow = +public fun Flow.cumulativeSum(group: GroupOps): Flow = group { runningReduce { sum, element -> sum + element } } @ExperimentalCoroutinesApi diff --git a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/realND.kt b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/realND.kt index b2c3209c2..0edd51be2 100644 --- a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/realND.kt +++ b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/realND.kt @@ -13,8 +13,8 @@ import space.kscience.kmath.structures.DoubleBuffer * Map one [BufferND] using function without indices. */ public inline fun BufferND.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND { - val array = DoubleArray(strides.linearSize) { offset -> DoubleField.transform(buffer[offset]) } - return BufferND(strides, DoubleBuffer(array)) + val array = DoubleArray(indexes.linearSize) { offset -> DoubleField.transform(buffer[offset]) } + return BufferND(indexes, DoubleBuffer(array)) } /** diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt index 54b285a70..e862c0b9d 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt @@ -104,12 +104,12 @@ public class PolynomialSpace( Polynomial(coefficients.map { -it }) } - override fun add(a: Polynomial, b: Polynomial): Polynomial { - val dim = max(a.coefficients.size, b.coefficients.size) + override fun add(left: Polynomial, right: Polynomial): Polynomial { + val dim = max(left.coefficients.size, right.coefficients.size) return ring { Polynomial(List(dim) { index -> - a.coefficients.getOrElse(index) { zero } + b.coefficients.getOrElse(index) { zero } + left.coefficients.getOrElse(index) { zero } + right.coefficients.getOrElse(index) { zero } }) } } diff --git a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt index e8b1ce95b..5e3cbff83 100644 --- a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt +++ b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt @@ -47,7 +47,7 @@ public object Euclidean2DSpace : GeometrySpace, ScaleOperations, ScaleOperations when (i) { 0 -> Double.NEGATIVE_INFINITY - strides.shape[axis] - 1 -> upper[axis] + shape[axis] - 1 -> upper[axis] else -> lower[axis] + (i.toDouble()) * binSize[axis] } }.asBuffer() @@ -60,7 +59,7 @@ public class DoubleHistogramSpace( val upperBoundary = index.mapIndexed { axis, i -> when (i) { 0 -> lower[axis] - strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY + shape[axis] - 1 -> Double.POSITIVE_INFINITY else -> lower[axis] + (i.toDouble() + 1) * binSize[axis] } }.asBuffer() @@ -75,7 +74,7 @@ public class DoubleHistogramSpace( } override fun produce(builder: HistogramBuilder.() -> Unit): IndexedHistogram { - val ndCounter = StructureND.auto(strides) { Counter.real() } + val ndCounter = StructureND.auto(shape) { Counter.real() } val hBuilder = HistogramBuilder { point, value -> val index = getIndex(point) ndCounter[index].add(value.toDouble()) diff --git a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/IndexedHistogramSpace.kt b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/IndexedHistogramSpace.kt index 44f3072d2..a495577c3 100644 --- a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/IndexedHistogramSpace.kt +++ b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/IndexedHistogramSpace.kt @@ -8,8 +8,9 @@ package space.kscience.kmath.histogram import space.kscience.kmath.domains.Domain import space.kscience.kmath.linear.Point import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.nd.DefaultStrides import space.kscience.kmath.nd.FieldND -import space.kscience.kmath.nd.Strides +import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.StructureND import space.kscience.kmath.operations.Group import space.kscience.kmath.operations.ScaleOperations @@ -34,10 +35,10 @@ public class IndexedHistogram, V : Any>( return context.produceBin(index, values[index]) } - override val dimension: Int get() = context.strides.shape.size + override val dimension: Int get() = context.shape.size override val bins: Iterable> - get() = context.strides.indices().map { + get() = DefaultStrides(context.shape).indices().map { context.produceBin(it, values[it]) }.asIterable() @@ -49,7 +50,7 @@ public class IndexedHistogram, V : Any>( public interface IndexedHistogramSpace, V : Any> : Group>, ScaleOperations> { //public val valueSpace: Space - public val strides: Strides + public val shape: Shape public val histogramValueSpace: FieldND //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape), /** @@ -66,10 +67,10 @@ public interface IndexedHistogramSpace, V : Any> public fun produce(builder: HistogramBuilder.() -> Unit): IndexedHistogram - override fun add(a: IndexedHistogram, b: IndexedHistogram): IndexedHistogram { - require(a.context == this) { "Can't operate on a histogram produced by external space" } - require(b.context == this) { "Can't operate on a histogram produced by external space" } - return IndexedHistogram(this, histogramValueSpace { a.values + b.values }) + override fun add(left: IndexedHistogram, right: IndexedHistogram): IndexedHistogram { + require(left.context == this) { "Can't operate on a histogram produced by external space" } + require(right.context == this) { "Can't operate on a histogram produced by external space" } + return IndexedHistogram(this, histogramValueSpace { left.values + right.values }) } override fun scale(a: IndexedHistogram, value: Double): IndexedHistogram { diff --git a/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/MultivariateHistogramTest.kt b/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/MultivariateHistogramTest.kt index c797eb65a..23dd076e1 100644 --- a/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/MultivariateHistogramTest.kt +++ b/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/MultivariateHistogramTest.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.histogram +import space.kscience.kmath.nd.DefaultStrides import space.kscience.kmath.operations.invoke import space.kscience.kmath.real.DoubleVector import kotlin.random.Random @@ -69,7 +70,7 @@ internal class MultivariateHistogramTest { } val res = histogram1 - histogram2 assertTrue { - strides.indices().all { index -> + DefaultStrides(shape).indices().all { index -> res.values[index] <= histogram1.values[index] } } diff --git a/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/TreeHistogramSpace.kt b/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/TreeHistogramSpace.kt index 96f945f6a..cc54d7e1a 100644 --- a/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/TreeHistogramSpace.kt +++ b/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/TreeHistogramSpace.kt @@ -88,20 +88,20 @@ public class TreeHistogramSpace( TreeHistogramBuilder(binFactory).apply(block).build() override fun add( - a: UnivariateHistogram, - b: UnivariateHistogram, + left: UnivariateHistogram, + right: UnivariateHistogram, ): UnivariateHistogram { // require(a.context == this) { "Histogram $a does not belong to this context" } // require(b.context == this) { "Histogram $b does not belong to this context" } val bins = TreeMap().apply { - (a.bins.map { it.domain } union b.bins.map { it.domain }).forEach { def -> + (left.bins.map { it.domain } union right.bins.map { it.domain }).forEach { def -> put( def.center, UnivariateBin( def, - value = (a[def.center]?.value ?: 0.0) + (b[def.center]?.value ?: 0.0), - standardDeviation = (a[def.center]?.standardDeviation - ?: 0.0) + (b[def.center]?.standardDeviation ?: 0.0) + value = (left[def.center]?.value ?: 0.0) + (right[def.center]?.value ?: 0.0), + standardDeviation = (left[def.center]?.standardDeviation + ?: 0.0) + (right[def.center]?.standardDeviation ?: 0.0) ) ) } diff --git a/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt b/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt index 1a6e3325b..645a14e30 100644 --- a/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt +++ b/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt @@ -28,10 +28,10 @@ public object JafamaDoubleField : ExtendedField, Norm, S else -> super.binaryOperationFunction(operation) } - override inline fun add(a: Double, b: Double): Double = a + b + override inline fun add(left: Double, right: Double): Double = left + right - override inline fun multiply(a: Double, b: Double): Double = a * b - override inline fun divide(a: Double, b: Double): Double = a / b + override inline fun multiply(left: Double, right: Double): Double = left * right + override inline fun divide(left: Double, right: Double): Double = left / right override inline fun scale(a: Double, value: Double): Double = a * value @@ -57,10 +57,10 @@ public object JafamaDoubleField : ExtendedField, Norm, S override inline fun norm(arg: Double): Double = FastMath.abs(arg) override inline fun Double.unaryMinus(): Double = -this - override inline fun Double.plus(b: Double): Double = this + b - override inline fun Double.minus(b: Double): Double = this - b - override inline fun Double.times(b: Double): Double = this * b - override inline fun Double.div(b: Double): Double = this / b + override inline fun Double.plus(other: Double): Double = this + other + override inline fun Double.minus(other: Double): Double = this - other + override inline fun Double.times(other: Double): Double = this * other + override inline fun Double.div(other: Double): Double = this / other } /** @@ -79,10 +79,10 @@ public object StrictJafamaDoubleField : ExtendedField, Norm super.binaryOperationFunction(operation) } - override inline fun add(a: Double, b: Double): Double = a + b + override inline fun add(left: Double, right: Double): Double = left + right - override inline fun multiply(a: Double, b: Double): Double = a * b - override inline fun divide(a: Double, b: Double): Double = a / b + override inline fun multiply(left: Double, right: Double): Double = left * right + override inline fun divide(left: Double, right: Double): Double = left / right override inline fun scale(a: Double, value: Double): Double = a * value @@ -108,8 +108,8 @@ public object StrictJafamaDoubleField : ExtendedField, Norm> MST.toSFun(): SFun = when (this) { is Symbol -> toSVar() is MST.Unary -> when (operation) { - GroupOperations.PLUS_OPERATION -> +value.toSFun() - GroupOperations.MINUS_OPERATION -> -value.toSFun() + GroupOps.PLUS_OPERATION -> +value.toSFun() + GroupOps.MINUS_OPERATION -> -value.toSFun() TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun()) TrigonometricOperations.COS_OPERATION -> cos(value.toSFun()) TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun()) @@ -124,10 +124,10 @@ public fun > MST.toSFun(): SFun = when (this) { } is MST.Binary -> when (operation) { - GroupOperations.PLUS_OPERATION -> left.toSFun() + right.toSFun() - GroupOperations.MINUS_OPERATION -> left.toSFun() - right.toSFun() - RingOperations.TIMES_OPERATION -> left.toSFun() * right.toSFun() - FieldOperations.DIV_OPERATION -> left.toSFun() / right.toSFun() + GroupOps.PLUS_OPERATION -> left.toSFun() + right.toSFun() + GroupOps.MINUS_OPERATION -> left.toSFun() - right.toSFun() + RingOps.TIMES_OPERATION -> left.toSFun() * right.toSFun() + FieldOps.DIV_OPERATION -> left.toSFun() / right.toSFun() PowerOperations.POW_OPERATION -> left.toSFun() pow (right as MST.Numeric).toSConst() else -> error("Binary operation $operation not defined in $this") } diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt index b604bf5f2..792890a2d 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt @@ -15,13 +15,6 @@ import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.* import space.kscience.kmath.operations.* -internal fun AlgebraND<*, *>.checkShape(array: INDArray): INDArray { - val arrayShape = array.shape().toIntArray() - if (!shape.contentEquals(arrayShape)) throw ShapeMismatchException(shape, arrayShape) - return array -} - - /** * Represents [AlgebraND] over [Nd4jArrayAlgebra]. * @@ -39,33 +32,34 @@ public sealed interface Nd4jArrayAlgebra> : AlgebraND.ndArray: INDArray - override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure { + override fun produce(shape: Shape, initializer: C.(IntArray) -> T): Nd4jArrayStructure { val struct = Nd4j.create(*shape)!!.wrap() - struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) } + struct.indicesIterator().forEach { struct[it] = elementAlgebra.initializer(it) } return struct } override fun StructureND.map(transform: C.(T) -> T): Nd4jArrayStructure { val newStruct = ndArray.dup().wrap() - newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) } + newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementAlgebra.transform(value) } return newStruct } override fun StructureND.mapIndexed( transform: C.(index: IntArray, T) -> T, ): Nd4jArrayStructure { - val new = Nd4j.create(*this@Nd4jArrayAlgebra.shape).wrap() - new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, this[idx]) } + val new = Nd4j.create(*shape).wrap() + new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(idx, this[idx]) } return new } - override fun combine( - a: StructureND, - b: StructureND, + override fun zip( + left: StructureND, + right: StructureND, transform: C.(T, T) -> T, ): Nd4jArrayStructure { - val new = Nd4j.create(*shape).wrap() - new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) } + require(left.shape.contentEquals(right.shape)) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" } + val new = Nd4j.create(*left.shape).wrap() + new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) } return new } } @@ -76,16 +70,13 @@ public sealed interface Nd4jArrayAlgebra> : AlgebraND> : GroupND, Nd4jArrayAlgebra { +public sealed interface Nd4jArrayGroupOps> : GroupOpsND, Nd4jArrayAlgebra { - override val zero: Nd4jArrayStructure - get() = Nd4j.zeros(*shape).wrap() + override fun add(left: StructureND, right: StructureND): Nd4jArrayStructure = + left.ndArray.add(right.ndArray).wrap() - override fun add(a: StructureND, b: StructureND): Nd4jArrayStructure = - a.ndArray.add(b.ndArray).wrap() - - override operator fun StructureND.minus(b: StructureND): Nd4jArrayStructure = - ndArray.sub(b.ndArray).wrap() + override operator fun StructureND.minus(other: StructureND): Nd4jArrayStructure = + ndArray.sub(other.ndArray).wrap() override operator fun StructureND.unaryMinus(): Nd4jArrayStructure = ndArray.neg().wrap() @@ -101,13 +92,10 @@ public sealed interface Nd4jArrayGroup> : GroupND, Nd4j * @param R the type of ring of structure elements. */ @OptIn(UnstableKMathAPI::class) -public sealed interface Nd4jArrayRing> : RingND, Nd4jArrayGroup { +public sealed interface Nd4jArrayRingOps> : RingOpsND, Nd4jArrayGroupOps { - override val one: Nd4jArrayStructure - get() = Nd4j.ones(*shape).wrap() - - override fun multiply(a: StructureND, b: StructureND): Nd4jArrayStructure = - a.ndArray.mul(b.ndArray).wrap() + override fun multiply(left: StructureND, right: StructureND): Nd4jArrayStructure = + left.ndArray.mul(right.ndArray).wrap() // // override operator fun Nd4jArrayStructure.minus(b: Number): Nd4jArrayStructure { // check(this) @@ -125,21 +113,12 @@ public sealed interface Nd4jArrayRing> : RingND, Nd4jAr // } public companion object { - private val intNd4jArrayRingCache: ThreadLocal> = - ThreadLocal.withInitial(::HashMap) - - /** - * Creates an [RingND] for [Int] values or pull it from cache if it was created previously. - */ - public fun int(vararg shape: Int): Nd4jArrayRing = - intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) } - /** * Creates a most suitable implementation of [RingND] using reified class. */ @Suppress("UNCHECKED_CAST") - public inline fun auto(vararg shape: Int): Nd4jArrayRing> = when { - T::class == Int::class -> int(*shape) as Nd4jArrayRing> + public inline fun auto(vararg shape: Int): Nd4jArrayRingOps> = when { + T::class == Int::class -> IntRing.nd4j as Nd4jArrayRingOps> else -> throw UnsupportedOperationException("This factory method only supports Long type.") } } @@ -151,38 +130,21 @@ public sealed interface Nd4jArrayRing> : RingND, Nd4jAr * @param T the type of the element contained in ND structure. * @param F the type field of structure elements. */ -public sealed interface Nd4jArrayField> : FieldND, Nd4jArrayRing { - override fun divide(a: StructureND, b: StructureND): Nd4jArrayStructure = - a.ndArray.div(b.ndArray).wrap() +public sealed interface Nd4jArrayField> : FieldOpsND, Nd4jArrayRingOps { + + override fun divide(left: StructureND, right: StructureND): Nd4jArrayStructure = + left.ndArray.div(right.ndArray).wrap() public operator fun Number.div(b: StructureND): Nd4jArrayStructure = b.ndArray.rdiv(this).wrap() public companion object { - private val floatNd4jArrayFieldCache: ThreadLocal> = - ThreadLocal.withInitial(::HashMap) - - private val doubleNd4JArrayFieldCache: ThreadLocal> = - ThreadLocal.withInitial(::HashMap) - - /** - * Creates an [FieldND] for [Float] values or pull it from cache if it was created previously. - */ - public fun float(vararg shape: Int): Nd4jArrayRing = - floatNd4jArrayFieldCache.get().getOrPut(shape) { FloatNd4jArrayField(shape) } - - /** - * Creates an [FieldND] for [Double] values or pull it from cache if it was created previously. - */ - public fun real(vararg shape: Int): Nd4jArrayRing = - doubleNd4JArrayFieldCache.get().getOrPut(shape) { DoubleNd4jArrayField(shape) } - /** * Creates a most suitable implementation of [FieldND] using reified class. */ @Suppress("UNCHECKED_CAST") public inline fun auto(vararg shape: Int): Nd4jArrayField> = when { - T::class == Float::class -> float(*shape) as Nd4jArrayField> - T::class == Double::class -> real(*shape) as Nd4jArrayField> + T::class == Float::class -> FloatField.nd4j as Nd4jArrayField> + T::class == Double::class -> DoubleField.nd4j as Nd4jArrayField> else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.") } } @@ -191,8 +153,9 @@ public sealed interface Nd4jArrayField> : FieldND, Nd4 /** * Represents intersection of [ExtendedField] and [Field] over [Nd4jArrayStructure]. */ -public sealed interface Nd4jArrayExtendedField> : ExtendedField>, - Nd4jArrayField { +public sealed interface Nd4jArrayExtendedFieldOps> : + ExtendedFieldOps>, Nd4jArrayField { + override fun sin(arg: StructureND): StructureND = Transforms.sin(arg.ndArray).wrap() override fun cos(arg: StructureND): StructureND = Transforms.cos(arg.ndArray).wrap() override fun asin(arg: StructureND): StructureND = Transforms.asin(arg.ndArray).wrap() @@ -221,63 +184,59 @@ public sealed interface Nd4jArrayExtendedField> : Ex /** * Represents [FieldND] over [Nd4jArrayDoubleStructure]. */ -public class DoubleNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtendedField { - override val elementContext: DoubleField get() = DoubleField +public open class DoubleNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps { + override val elementAlgebra: DoubleField get() = DoubleField - override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(this).asDoubleStructure() + override fun INDArray.wrap(): Nd4jArrayStructure = asDoubleStructure() @OptIn(PerformancePitfall::class) override val StructureND.ndArray: INDArray get() = when (this) { - is Nd4jArrayStructure -> checkShape(ndArray) + is Nd4jArrayStructure -> ndArray else -> Nd4j.zeros(*shape).also { elements().forEach { (idx, value) -> it.putScalar(idx, value) } } } - override fun scale(a: StructureND, value: Double): Nd4jArrayStructure { - return a.ndArray.mul(value).wrap() - } + override fun scale(a: StructureND, value: Double): Nd4jArrayStructure = a.ndArray.mul(value).wrap() - override operator fun StructureND.div(arg: Double): Nd4jArrayStructure { - return ndArray.div(arg).wrap() - } + override operator fun StructureND.div(arg: Double): Nd4jArrayStructure = ndArray.div(arg).wrap() - override operator fun StructureND.plus(arg: Double): Nd4jArrayStructure { - return ndArray.add(arg).wrap() - } + override operator fun StructureND.plus(arg: Double): Nd4jArrayStructure = ndArray.add(arg).wrap() - override operator fun StructureND.minus(arg: Double): Nd4jArrayStructure { - return ndArray.sub(arg).wrap() - } + override operator fun StructureND.minus(arg: Double): Nd4jArrayStructure = ndArray.sub(arg).wrap() - override operator fun StructureND.times(arg: Double): Nd4jArrayStructure { - return ndArray.mul(arg).wrap() - } + override operator fun StructureND.times(arg: Double): Nd4jArrayStructure = ndArray.mul(arg).wrap() - override operator fun Double.div(arg: StructureND): Nd4jArrayStructure { - return arg.ndArray.rdiv(this).wrap() - } + override operator fun Double.div(arg: StructureND): Nd4jArrayStructure = + arg.ndArray.rdiv(this).wrap() - override operator fun Double.minus(arg: StructureND): Nd4jArrayStructure { - return arg.ndArray.rsub(this).wrap() - } + override operator fun Double.minus(arg: StructureND): Nd4jArrayStructure = + arg.ndArray.rsub(this).wrap() + + public companion object : DoubleNd4jArrayFieldOps() } -public fun DoubleField.nd4j(vararg shape: Int): DoubleNd4jArrayField = DoubleNd4jArrayField(intArrayOf(*shape)) +public val DoubleField.nd4j: DoubleNd4jArrayFieldOps get() = DoubleNd4jArrayFieldOps + +public class DoubleNd4jArrayField(override val shape: Shape) : DoubleNd4jArrayFieldOps(), FieldND + +public fun DoubleField.nd4j(shapeFirst: Int, vararg shapeRest: Int): DoubleNd4jArrayField = + DoubleNd4jArrayField(intArrayOf(shapeFirst, * shapeRest)) + /** * Represents [FieldND] over [Nd4jArrayStructure] of [Float]. */ -public class FloatNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtendedField { - override val elementContext: FloatField get() = FloatField +public open class FloatNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps { + override val elementAlgebra: FloatField get() = FloatField - override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(this).asFloatStructure() + override fun INDArray.wrap(): Nd4jArrayStructure = asFloatStructure() @OptIn(PerformancePitfall::class) override val StructureND.ndArray: INDArray get() = when (this) { - is Nd4jArrayStructure -> checkShape(ndArray) + is Nd4jArrayStructure -> ndArray else -> Nd4j.zeros(*shape).also { elements().forEach { (idx, value) -> it.putScalar(idx, value) } } @@ -303,21 +262,29 @@ public class FloatNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtend override operator fun Float.minus(arg: StructureND): Nd4jArrayStructure = arg.ndArray.rsub(this).wrap() + + public companion object : FloatNd4jArrayFieldOps() } +public class FloatNd4jArrayField(override val shape: Shape) : FloatNd4jArrayFieldOps(), RingND + +public val FloatField.nd4j: FloatNd4jArrayFieldOps get() = FloatNd4jArrayFieldOps + +public fun FloatField.nd4j(shapeFirst: Int, vararg shapeRest: Int): FloatNd4jArrayField = + FloatNd4jArrayField(intArrayOf(shapeFirst, * shapeRest)) + /** * Represents [RingND] over [Nd4jArrayIntStructure]. */ -public class IntNd4jArrayRing(override val shape: IntArray) : Nd4jArrayRing { - override val elementContext: IntRing - get() = IntRing +public open class IntNd4jArrayRingOps : Nd4jArrayRingOps { + override val elementAlgebra: IntRing get() = IntRing - override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(this).asIntStructure() + override fun INDArray.wrap(): Nd4jArrayStructure = asIntStructure() @OptIn(PerformancePitfall::class) override val StructureND.ndArray: INDArray get() = when (this) { - is Nd4jArrayStructure -> checkShape(ndArray) + is Nd4jArrayStructure -> ndArray else -> Nd4j.zeros(*shape).also { elements().forEach { (idx, value) -> it.putScalar(idx, value) } } @@ -334,4 +301,13 @@ public class IntNd4jArrayRing(override val shape: IntArray) : Nd4jArrayRing): Nd4jArrayStructure = arg.ndArray.rsub(this).wrap() + + public companion object : IntNd4jArrayRingOps() } + +public val IntRing.nd4j: IntNd4jArrayRingOps get() = IntNd4jArrayRingOps + +public class IntNd4jArrayRing(override val shape: Shape) : IntNd4jArrayRingOps(), RingND + +public fun IntRing.nd4j(shapeFirst: Int, vararg shapeRest: Int): IntNd4jArrayRing = + IntNd4jArrayRing(intArrayOf(shapeFirst, * shapeRest)) \ No newline at end of file diff --git a/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt b/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt index a03a7269e..465937fa9 100644 --- a/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt +++ b/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt @@ -8,6 +8,10 @@ package space.kscience.kmath.nd4j import org.nd4j.linalg.factory.Nd4j import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.nd.one +import space.kscience.kmath.nd.produce +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.IntRing import space.kscience.kmath.operations.invoke import kotlin.math.PI import kotlin.test.Test @@ -19,7 +23,7 @@ import kotlin.test.fail internal class Nd4jArrayAlgebraTest { @Test fun testProduce() { - val res = with(DoubleNd4jArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } } + val res = DoubleField.nd4j.produce(2, 2) { it.sum().toDouble() } val expected = (Nd4j.create(2, 2) ?: fail()).asDoubleStructure() expected[intArrayOf(0, 0)] = 0.0 expected[intArrayOf(0, 1)] = 1.0 @@ -30,7 +34,9 @@ internal class Nd4jArrayAlgebraTest { @Test fun testMap() { - val res = with(IntNd4jArrayRing(intArrayOf(2, 2))) { one.map { it + it * 2 } } + val res = IntRing.nd4j { + one(2, 2).map { it + it * 2 } + } val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure() expected[intArrayOf(0, 0)] = 3 expected[intArrayOf(0, 1)] = 3 @@ -41,7 +47,7 @@ internal class Nd4jArrayAlgebraTest { @Test fun testAdd() { - val res = with(IntNd4jArrayRing(intArrayOf(2, 2))) { one + 25 } + val res = IntRing.nd4j { one(2, 2) + 25 } val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure() expected[intArrayOf(0, 0)] = 26 expected[intArrayOf(0, 1)] = 26 @@ -51,10 +57,10 @@ internal class Nd4jArrayAlgebraTest { } @Test - fun testSin() = DoubleNd4jArrayField(intArrayOf(2, 2)).invoke { - val initial = produce { (i, j) -> if (i == j) PI / 2 else 0.0 } + fun testSin() = DoubleField.nd4j{ + val initial = produce(2, 2) { (i, j) -> if (i == j) PI / 2 else 0.0 } val transformed = sin(initial) - val expected = produce { (i, j) -> if (i == j) 1.0 else 0.0 } + val expected = produce(2, 2) { (i, j) -> if (i == j) 1.0 else 0.0 } println(transformed) assertTrue { StructureND.contentEquals(transformed, expected) } diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt index c1bbace86..e0be72d4b 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt @@ -41,8 +41,8 @@ public class SamplerSpace(public val algebra: S) : Group = ConstantSampler(algebra.zero) - override fun add(a: Sampler, b: Sampler): Sampler = BasicSampler { generator -> - a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> algebra { aValue + bValue } } + override fun add(left: Sampler, right: Sampler): Sampler = BasicSampler { generator -> + left.sample(generator).zip(right.sample(generator)) { aValue, bValue -> algebra { aValue + bValue } } } override fun scale(a: Sampler, value: Double): Sampler = BasicSampler { generator -> diff --git a/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/adapters.kt b/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/adapters.kt index 8def1ae83..a7ca298a9 100644 --- a/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/adapters.kt +++ b/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/adapters.kt @@ -64,8 +64,8 @@ public fun MST.toIExpr(): IExpr = when (this) { } is MST.Unary -> when (operation) { - GroupOperations.PLUS_OPERATION -> value.toIExpr() - GroupOperations.MINUS_OPERATION -> F.Negate(value.toIExpr()) + GroupOps.PLUS_OPERATION -> value.toIExpr() + GroupOps.MINUS_OPERATION -> F.Negate(value.toIExpr()) TrigonometricOperations.SIN_OPERATION -> F.Sin(value.toIExpr()) TrigonometricOperations.COS_OPERATION -> F.Cos(value.toIExpr()) TrigonometricOperations.TAN_OPERATION -> F.Tan(value.toIExpr()) @@ -85,10 +85,10 @@ public fun MST.toIExpr(): IExpr = when (this) { } is MST.Binary -> when (operation) { - GroupOperations.PLUS_OPERATION -> left.toIExpr() + right.toIExpr() - GroupOperations.MINUS_OPERATION -> left.toIExpr() - right.toIExpr() - RingOperations.TIMES_OPERATION -> left.toIExpr() * right.toIExpr() - FieldOperations.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr()) + GroupOps.PLUS_OPERATION -> left.toIExpr() + right.toIExpr() + GroupOps.MINUS_OPERATION -> left.toIExpr() - right.toIExpr() + RingOps.TIMES_OPERATION -> left.toIExpr() * right.toIExpr() + FieldOps.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr()) PowerOperations.POW_OPERATION -> F.Power(left.toIExpr(), F.symjify((right as MST.Numeric).value)) else -> error("Binary operation $operation not defined in $this") } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt index 6076748d9..810ebe777 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt @@ -5,7 +5,7 @@ package space.kscience.kmath.tensors.api -import space.kscience.kmath.operations.Algebra +import space.kscience.kmath.operations.RingOps /** * Algebra over a ring on [Tensor]. @@ -13,7 +13,7 @@ import space.kscience.kmath.operations.Algebra * * @param T the type of items in the tensors. */ -public interface TensorAlgebra : Algebra> { +public interface TensorAlgebra : RingOps> { /** * Returns a single tensor value of unit dimension if tensor shape equals to [1]. * @@ -53,7 +53,7 @@ public interface TensorAlgebra : Algebra> { * @param other tensor to be added. * @return the sum of this tensor and [other]. */ - public operator fun Tensor.plus(other: Tensor): Tensor + override fun Tensor.plus(other: Tensor): Tensor /** * Adds the scalar [value] to each element of this tensor. @@ -93,7 +93,7 @@ public interface TensorAlgebra : Algebra> { * @param other tensor to be subtracted. * @return the difference between this tensor and [other]. */ - public operator fun Tensor.minus(other: Tensor): Tensor + override fun Tensor.minus(other: Tensor): Tensor /** * Subtracts the scalar [value] from each element of this tensor. @@ -134,7 +134,7 @@ public interface TensorAlgebra : Algebra> { * @param other tensor to be multiplied. * @return the product of this tensor and [other]. */ - public operator fun Tensor.times(other: Tensor): Tensor + override fun Tensor.times(other: Tensor): Tensor /** * Multiplies the scalar [value] by each element of this tensor. @@ -155,7 +155,7 @@ public interface TensorAlgebra : Algebra> { * * @return tensor negation of the original tensor. */ - public operator fun Tensor.unaryMinus(): Tensor + override fun Tensor.unaryMinus(): Tensor /** * Returns the tensor at index i @@ -323,4 +323,8 @@ public interface TensorAlgebra : Algebra> { * @return the index of maximum value of each row of the input tensor in the given dimension [dim]. */ public fun Tensor.argMax(dim: Int, keepDim: Boolean): Tensor + + override fun add(left: Tensor, right: Tensor): Tensor = left + right + + override fun multiply(left: Tensor, right: Tensor): Tensor = left * right } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 8e39e6cdd..594070cd2 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -373,8 +373,12 @@ public open class DoubleTensorAlgebra : return resTensor } - override fun diagonalEmbedding(diagonalEntries: Tensor, offset: Int, dim1: Int, dim2: Int): - DoubleTensor { + override fun diagonalEmbedding( + diagonalEntries: Tensor, + offset: Int, + dim1: Int, + dim2: Int + ): DoubleTensor { val n = diagonalEntries.shape.size val d1 = minusIndexFrom(n + 1, dim1) val d2 = minusIndexFrom(n + 1, dim2) diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/TensorLinearStructure.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/TensorLinearStructure.kt index 817ed60d8..57668722a 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/TensorLinearStructure.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/TensorLinearStructure.kt @@ -44,7 +44,7 @@ internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArra * * @param shape the shape of the tensor. */ -internal class TensorLinearStructure(override val shape: IntArray) : Strides { +internal class TensorLinearStructure(override val shape: IntArray) : Strides() { override val strides: IntArray get() = stridesFromShape(shape) @@ -54,4 +54,18 @@ internal class TensorLinearStructure(override val shape: IntArray) : Strides { override val linearSize: Int get() = shape.reduce(Int::times) + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as TensorLinearStructure + + if (!shape.contentEquals(other.shape)) return false + + return true + } + + override fun hashCode(): Int { + return shape.contentHashCode() + } } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt index 29aa02931..1f5778d5e 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt @@ -26,8 +26,11 @@ internal fun Tensor.copyToBufferedTensor(): BufferedTensor = internal fun Tensor.toBufferedTensor(): BufferedTensor = when (this) { is BufferedTensor -> this - is MutableBufferND -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides) - BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor() + is MutableBufferND -> if (this.indexes == TensorLinearStructure(this.shape)) { + BufferedTensor(this.shape, this.buffer, 0) + } else { + this.copyToBufferedTensor() + } else -> this.copyToBufferedTensor() } diff --git a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt new file mode 100644 index 000000000..c72553a64 --- /dev/null +++ b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt @@ -0,0 +1,124 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.viktor + +import org.jetbrains.bio.viktor.F64Array +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.nd.* +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.ExtendedFieldOps +import space.kscience.kmath.operations.NumbersAddOps + +@OptIn(UnstableKMathAPI::class) +@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +public open class ViktorFieldOpsND : + FieldOpsND, + ExtendedFieldOps> { + + public val StructureND.f64Buffer: F64Array + get() = when (this) { + is ViktorStructureND -> this.f64Buffer + else -> produce(shape) { this@f64Buffer[it] }.f64Buffer + } + + override val elementAlgebra: DoubleField get() = DoubleField + + override fun produce(shape: IntArray, initializer: DoubleField.(IntArray) -> Double): ViktorStructureND = + F64Array(*shape).apply { + DefaultStrides(shape).indices().forEach { index -> + set(value = DoubleField.initializer(index), indices = index) + } + }.asStructure() + + override fun StructureND.unaryMinus(): StructureND = -1 * this + + override fun StructureND.map(transform: DoubleField.(Double) -> Double): ViktorStructureND = + F64Array(*shape).apply { + DefaultStrides(shape).indices().forEach { index -> + set(value = DoubleField.transform(this@map[index]), indices = index) + } + }.asStructure() + + override fun StructureND.mapIndexed( + transform: DoubleField.(index: IntArray, Double) -> Double, + ): ViktorStructureND = F64Array(*shape).apply { + DefaultStrides(shape).indices().forEach { index -> + set(value = DoubleField.transform(index, this@mapIndexed[index]), indices = index) + } + }.asStructure() + + override fun zip( + left: StructureND, + right: StructureND, + transform: DoubleField.(Double, Double) -> Double, + ): ViktorStructureND { + require(left.shape.contentEquals(right.shape)) + return F64Array(*left.shape).apply { + DefaultStrides(left.shape).indices().forEach { index -> + set(value = DoubleField.transform(left[index], right[index]), indices = index) + } + }.asStructure() + } + + override fun add(left: StructureND, right: StructureND): ViktorStructureND = + (left.f64Buffer + right.f64Buffer).asStructure() + + override fun scale(a: StructureND, value: Double): ViktorStructureND = + (a.f64Buffer * value).asStructure() + + override fun StructureND.plus(other: StructureND): ViktorStructureND = + (f64Buffer + other.f64Buffer).asStructure() + + override fun StructureND.minus(other: StructureND): ViktorStructureND = + (f64Buffer - other.f64Buffer).asStructure() + + override fun StructureND.times(k: Number): ViktorStructureND = + (f64Buffer * k.toDouble()).asStructure() + + override fun StructureND.plus(arg: Double): ViktorStructureND = + (f64Buffer.plus(arg)).asStructure() + + override fun sin(arg: StructureND): ViktorStructureND = arg.map { sin(it) } + override fun cos(arg: StructureND): ViktorStructureND = arg.map { cos(it) } + override fun tan(arg: StructureND): ViktorStructureND = arg.map { tan(it) } + override fun asin(arg: StructureND): ViktorStructureND = arg.map { asin(it) } + override fun acos(arg: StructureND): ViktorStructureND = arg.map { acos(it) } + override fun atan(arg: StructureND): ViktorStructureND = arg.map { atan(it) } + + override fun power(arg: StructureND, pow: Number): ViktorStructureND = arg.map { it.pow(pow) } + + override fun exp(arg: StructureND): ViktorStructureND = arg.f64Buffer.exp().asStructure() + + override fun ln(arg: StructureND): ViktorStructureND = arg.f64Buffer.log().asStructure() + + override fun sinh(arg: StructureND): ViktorStructureND = arg.map { sinh(it) } + + override fun cosh(arg: StructureND): ViktorStructureND = arg.map { cosh(it) } + + override fun asinh(arg: StructureND): ViktorStructureND = arg.map { asinh(it) } + + override fun acosh(arg: StructureND): ViktorStructureND = arg.map { acosh(it) } + + override fun atanh(arg: StructureND): ViktorStructureND = arg.map { atanh(it) } + + public companion object : ViktorFieldOpsND() +} + +public val DoubleField.viktorAlgebra: ViktorFieldOpsND get() = ViktorFieldOpsND + +public open class ViktorFieldND( + override val shape: Shape +) : ViktorFieldOpsND(), FieldND, NumbersAddOps> { + override val zero: ViktorStructureND by lazy { F64Array.full(init = 0.0, shape = shape).asStructure() } + override val one: ViktorStructureND by lazy { F64Array.full(init = 1.0, shape = shape).asStructure() } + + override fun number(value: Number): ViktorStructureND = + F64Array.full(init = value.toDouble(), shape = shape).asStructure() +} + +public fun DoubleField.viktorAlgebra(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape) + +public fun ViktorFieldND(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape) \ No newline at end of file diff --git a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt index 682123ddd..0d29983f9 100644 --- a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt +++ b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt @@ -7,12 +7,8 @@ package space.kscience.kmath.viktor import org.jetbrains.bio.viktor.F64Array import space.kscience.kmath.misc.PerformancePitfall -import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.nd.* -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.NumbersAddOperations -import space.kscience.kmath.operations.ScaleOperations +import space.kscience.kmath.nd.DefaultStrides +import space.kscience.kmath.nd.MutableStructureND @Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public class ViktorStructureND(public val f64Buffer: F64Array) : MutableStructureND { @@ -31,96 +27,4 @@ public class ViktorStructureND(public val f64Buffer: F64Array) : MutableStructur public fun F64Array.asStructure(): ViktorStructureND = ViktorStructureND(this) -@OptIn(UnstableKMathAPI::class) -@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -public class ViktorFieldND(override val shape: IntArray) : FieldND, - NumbersAddOperations>, ExtendedField>, - ScaleOperations> { - public val StructureND.f64Buffer: F64Array - get() = when { - !shape.contentEquals(this@ViktorFieldND.shape) -> throw ShapeMismatchException( - this@ViktorFieldND.shape, - shape - ) - this is ViktorStructureND && this.f64Buffer.shape.contentEquals(this@ViktorFieldND.shape) -> this.f64Buffer - else -> produce { this@f64Buffer[it] }.f64Buffer - } - - override val zero: ViktorStructureND by lazy { F64Array.full(init = 0.0, shape = shape).asStructure() } - override val one: ViktorStructureND by lazy { F64Array.full(init = 1.0, shape = shape).asStructure() } - - private val strides: Strides = DefaultStrides(shape) - - override val elementContext: DoubleField get() = DoubleField - - override fun produce(initializer: DoubleField.(IntArray) -> Double): ViktorStructureND = - F64Array(*shape).apply { - this@ViktorFieldND.strides.indices().forEach { index -> - set(value = DoubleField.initializer(index), indices = index) - } - }.asStructure() - - override fun StructureND.unaryMinus(): StructureND = -1 * this - - override fun StructureND.map(transform: DoubleField.(Double) -> Double): ViktorStructureND = - F64Array(*this@ViktorFieldND.shape).apply { - this@ViktorFieldND.strides.indices().forEach { index -> - set(value = DoubleField.transform(this@map[index]), indices = index) - } - }.asStructure() - - override fun StructureND.mapIndexed( - transform: DoubleField.(index: IntArray, Double) -> Double, - ): ViktorStructureND = F64Array(*this@ViktorFieldND.shape).apply { - this@ViktorFieldND.strides.indices().forEach { index -> - set(value = DoubleField.transform(index, this@mapIndexed[index]), indices = index) - } - }.asStructure() - - override fun combine( - a: StructureND, - b: StructureND, - transform: DoubleField.(Double, Double) -> Double, - ): ViktorStructureND = F64Array(*shape).apply { - this@ViktorFieldND.strides.indices().forEach { index -> - set(value = DoubleField.transform(a[index], b[index]), indices = index) - } - }.asStructure() - - override fun add(a: StructureND, b: StructureND): ViktorStructureND = - (a.f64Buffer + b.f64Buffer).asStructure() - - override fun scale(a: StructureND, value: Double): ViktorStructureND = - (a.f64Buffer * value).asStructure() - - override inline fun StructureND.plus(b: StructureND): ViktorStructureND = - (f64Buffer + b.f64Buffer).asStructure() - - override inline fun StructureND.minus(b: StructureND): ViktorStructureND = - (f64Buffer - b.f64Buffer).asStructure() - - override inline fun StructureND.times(k: Number): ViktorStructureND = - (f64Buffer * k.toDouble()).asStructure() - - override inline fun StructureND.plus(arg: Double): ViktorStructureND = - (f64Buffer.plus(arg)).asStructure() - - override fun number(value: Number): ViktorStructureND = - F64Array.full(init = value.toDouble(), shape = shape).asStructure() - - override fun sin(arg: StructureND): ViktorStructureND = arg.map { sin(it) } - override fun cos(arg: StructureND): ViktorStructureND = arg.map { cos(it) } - override fun tan(arg: StructureND): ViktorStructureND = arg.map { tan(it) } - override fun asin(arg: StructureND): ViktorStructureND = arg.map { asin(it) } - override fun acos(arg: StructureND): ViktorStructureND = arg.map { acos(it) } - override fun atan(arg: StructureND): ViktorStructureND = arg.map { atan(it) } - - override fun power(arg: StructureND, pow: Number): ViktorStructureND = arg.map { it.pow(pow) } - - override fun exp(arg: StructureND): ViktorStructureND = arg.f64Buffer.exp().asStructure() - - override fun ln(arg: StructureND): ViktorStructureND = arg.f64Buffer.log().asStructure() -} - -public fun ViktorNDField(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape) diff --git a/settings.gradle.kts b/settings.gradle.kts index 528adb336..dc70cbb9e 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -1,16 +1,18 @@ pluginManagement { repositories { - mavenLocal() maven("https://repo.kotlin.link") mavenCentral() gradlePluginPortal() } - val kotlinVersion = "1.6.0-M1" + val kotlinVersion = "1.6.0-RC" + val toolsVersion = "0.10.5" plugins { id("org.jetbrains.kotlinx.benchmark") version "0.3.1" - id("ru.mipt.npm.gradle.project") version "0.10.5" + id("ru.mipt.npm.gradle.project") version toolsVersion + id("ru.mipt.npm.gradle.jvm") version toolsVersion + id("ru.mipt.npm.gradle.mpp") version toolsVersion kotlin("multiplatform") version kotlinVersion kotlin("plugin.allopen") version kotlinVersion }