Shapeless ND and Buffer algebras

This commit is contained in:
Alexander Nozik 2021-10-17 11:12:35 +03:00
parent 8d2770c275
commit d0354da80a
56 changed files with 893 additions and 906 deletions

View File

@ -42,6 +42,9 @@
- Use `Symbol` factory function instead of `StringSymbol` - Use `Symbol` factory function instead of `StringSymbol`
- New discoverability pattern: `<Type>.algebra.<nd/etc>` - New discoverability pattern: `<Type>.algebra.<nd/etc>`
- Adjusted commons-math API for linear solvers to match conventions. - Adjusted commons-math API for linear solvers to match conventions.
- Buffer algebra does not require size anymore
- Operations -> Ops
- Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes.
### Deprecated ### Deprecated
- Specialized `DoubleBufferAlgebra` - Specialized `DoubleBufferAlgebra`

View File

@ -9,9 +9,10 @@ import kotlinx.benchmark.Benchmark
import kotlinx.benchmark.Blackhole import kotlinx.benchmark.Blackhole
import kotlinx.benchmark.Scope import kotlinx.benchmark.Scope
import kotlinx.benchmark.State import kotlinx.benchmark.State
import space.kscience.kmath.nd.BufferedFieldOpsND
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.autoNdAlgebra
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.one
import space.kscience.kmath.nd4j.nd4j import space.kscience.kmath.nd4j.nd4j
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
@ -23,21 +24,21 @@ import space.kscience.kmath.tensors.core.tensorAlgebra
internal class NDFieldBenchmark { internal class NDFieldBenchmark {
@Benchmark @Benchmark
fun autoFieldAdd(blackhole: Blackhole) = with(autoField) { fun autoFieldAdd(blackhole: Blackhole) = with(autoField) {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += one } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
} }
@Benchmark @Benchmark
fun specializedFieldAdd(blackhole: Blackhole) = with(specializedField) { fun specializedFieldAdd(blackhole: Blackhole) = with(specializedField) {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
} }
@Benchmark @Benchmark
fun boxingFieldAdd(blackhole: Blackhole) = with(genericField) { fun boxingFieldAdd(blackhole: Blackhole) = with(genericField) {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
} }
@ -56,19 +57,20 @@ internal class NDFieldBenchmark {
blackhole.consume(res) blackhole.consume(res)
} }
// @Benchmark @Benchmark
// fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) { fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) {
// var res: StructureND<Double> = one var res: StructureND<Double> = one(dim, dim)
// repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
// blackhole.consume(res) blackhole.consume(res)
// } }
private companion object { private companion object {
private const val dim = 1000 private const val dim = 1000
private const val n = 100 private const val n = 100
private val autoField = DoubleField.autoNdAlgebra(dim, dim) private val shape = intArrayOf(dim, dim)
private val specializedField = DoubleField.ndAlgebra(dim, dim) private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
private val genericField = DoubleField.ndAlgebra(Buffer.Companion::boxing, dim, dim) private val specializedField = DoubleField.ndAlgebra
private val nd4jField = DoubleField.nd4j(dim, dim) private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing)
private val nd4jField = DoubleField.nd4j
} }
} }

View File

@ -10,18 +10,17 @@ import kotlinx.benchmark.Blackhole
import kotlinx.benchmark.Scope import kotlinx.benchmark.Scope
import kotlinx.benchmark.State import kotlinx.benchmark.State
import org.jetbrains.bio.viktor.F64Array import org.jetbrains.bio.viktor.F64Array
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.*
import space.kscience.kmath.nd.autoNdAlgebra
import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.viktor.ViktorNDField import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.viktor.ViktorFieldND
@State(Scope.Benchmark) @State(Scope.Benchmark)
internal class ViktorBenchmark { internal class ViktorBenchmark {
@Benchmark @Benchmark
fun automaticFieldAddition(blackhole: Blackhole) { fun automaticFieldAddition(blackhole: Blackhole) {
with(autoField) { with(autoField) {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
} }
@ -30,7 +29,7 @@ internal class ViktorBenchmark {
@Benchmark @Benchmark
fun realFieldAddition(blackhole: Blackhole) { fun realFieldAddition(blackhole: Blackhole) {
with(realField) { with(realField) {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
} }
@ -39,7 +38,7 @@ internal class ViktorBenchmark {
@Benchmark @Benchmark
fun viktorFieldAddition(blackhole: Blackhole) { fun viktorFieldAddition(blackhole: Blackhole) {
with(viktorField) { with(viktorField) {
var res = one var res = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
} }
@ -56,10 +55,11 @@ internal class ViktorBenchmark {
private companion object { private companion object {
private const val dim = 1000 private const val dim = 1000
private const val n = 100 private const val n = 100
private val shape = Shape(dim, dim)
// automatically build context most suited for given type. // automatically build context most suited for given type.
private val autoField = DoubleField.autoNdAlgebra(dim, dim) private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
private val realField = DoubleField.ndAlgebra(dim, dim) private val realField = DoubleField.ndAlgebra
private val viktorField = ViktorNDField(dim, dim) private val viktorField = ViktorFieldND(dim, dim)
} }
} }

View File

@ -10,18 +10,21 @@ import kotlinx.benchmark.Blackhole
import kotlinx.benchmark.Scope import kotlinx.benchmark.Scope
import kotlinx.benchmark.State import kotlinx.benchmark.State
import org.jetbrains.bio.viktor.F64Array import org.jetbrains.bio.viktor.F64Array
import space.kscience.kmath.nd.autoNdAlgebra import space.kscience.kmath.nd.BufferedFieldOpsND
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.one
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.viktor.ViktorFieldND import space.kscience.kmath.viktor.ViktorFieldND
@State(Scope.Benchmark) @State(Scope.Benchmark)
internal class ViktorLogBenchmark { internal class ViktorLogBenchmark {
@Benchmark @Benchmark
fun realFieldLog(blackhole: Blackhole) { fun realFieldLog(blackhole: Blackhole) {
with(realNdField) { with(realField) {
val fortyTwo = produce { 42.0 } val fortyTwo = produce(shape) { 42.0 }
var res = one var res = one(shape)
repeat(n) { res = ln(fortyTwo) } repeat(n) { res = ln(fortyTwo) }
blackhole.consume(res) blackhole.consume(res)
} }
@ -30,7 +33,7 @@ internal class ViktorLogBenchmark {
@Benchmark @Benchmark
fun viktorFieldLog(blackhole: Blackhole) { fun viktorFieldLog(blackhole: Blackhole) {
with(viktorField) { with(viktorField) {
val fortyTwo = produce { 42.0 } val fortyTwo = produce(shape) { 42.0 }
var res = one var res = one
repeat(n) { res = ln(fortyTwo) } repeat(n) { res = ln(fortyTwo) }
blackhole.consume(res) blackhole.consume(res)
@ -48,10 +51,11 @@ internal class ViktorLogBenchmark {
private companion object { private companion object {
private const val dim = 1000 private const val dim = 1000
private const val n = 100 private const val n = 100
private val shape = Shape(dim, dim)
// automatically build context most suited for given type. // automatically build context most suited for given type.
private val autoField = DoubleField.autoNdAlgebra(dim, dim) private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
private val realNdField = DoubleField.ndAlgebra(dim, dim) private val realField = DoubleField.ndAlgebra
private val viktorField = ViktorFieldND(intArrayOf(dim, dim)) private val viktorField = ViktorFieldND(dim, dim)
} }
} }

View File

@ -11,27 +11,27 @@ import space.kscience.kmath.complex.bufferAlgebra
import space.kscience.kmath.complex.ndAlgebra import space.kscience.kmath.complex.ndAlgebra
import space.kscience.kmath.nd.BufferND import space.kscience.kmath.nd.BufferND
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.produce
fun main() = Complex.algebra { fun main() = Complex.algebra {
val complex = 2 + 2 * i val complex = 2 + 2 * i
println(complex * 8 - 5 * i) println(complex * 8 - 5 * i)
//flat buffer //flat buffer
val buffer = bufferAlgebra(8).run { val buffer = with(bufferAlgebra){
buffer { Complex(it, -it) }.map { Complex(it.im, it.re) } buffer(8) { Complex(it, -it) }.map { Complex(it.im, it.re) }
} }
println(buffer) println(buffer)
// 2d element // 2d element
val element: BufferND<Complex> = ndAlgebra(2, 2).produce { (i, j) -> val element: BufferND<Complex> = ndAlgebra.produce(2, 2) { (i, j) ->
Complex(i - j, i + j) Complex(i - j, i + j)
} }
println(element) println(element)
// 1d element operation // 1d element operation
val result: StructureND<Complex> = ndAlgebra(8).run { val result: StructureND<Complex> = ndAlgebra{
val a = produce { (it) -> i * it - it.toDouble() } val a = produce(8) { (it) -> i * it - it.toDouble() }
val b = 3 val b = 3
val c = Complex(1.0, 1.0) val c = Complex(1.0, 1.0)

View File

@ -14,7 +14,7 @@ import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd4j.Nd4jArrayField import space.kscience.kmath.nd4j.Nd4jArrayField
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.viktor.ViktorNDField import space.kscience.kmath.viktor.ViktorFieldND
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
@ -41,7 +41,7 @@ fun main() {
// Nd4j specialized field. // Nd4j specialized field.
val nd4jField = Nd4jArrayField.real(dim, dim) val nd4jField = Nd4jArrayField.real(dim, dim)
//viktor field //viktor field
val viktorField = ViktorNDField(dim, dim) val viktorField = ViktorFieldND(dim, dim)
//parallel processing based on Java Streams //parallel processing based on Java Streams
val parallelField = DoubleField.ndStreaming(dim, dim) val parallelField = DoubleField.ndStreaming(dim, dim)

View File

@ -8,7 +8,7 @@ package space.kscience.kmath.structures
import space.kscience.kmath.nd.* import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.ExtendedField import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations import space.kscience.kmath.operations.NumbersAddOps
import java.util.* import java.util.*
import java.util.stream.IntStream import java.util.stream.IntStream
@ -17,11 +17,11 @@ import java.util.stream.IntStream
* execution. * execution.
*/ */
class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>, class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>,
NumbersAddOperations<StructureND<Double>>, NumbersAddOps<StructureND<Double>>,
ExtendedField<StructureND<Double>> { ExtendedField<StructureND<Double>> {
private val strides = DefaultStrides(shape) private val strides = DefaultStrides(shape)
override val elementContext: DoubleField get() = DoubleField override val elementAlgebra: DoubleField get() = DoubleField
override val zero: BufferND<Double> by lazy { produce { zero } } override val zero: BufferND<Double> by lazy { produce { zero } }
override val one: BufferND<Double> by lazy { produce { one } } override val one: BufferND<Double> by lazy { produce { one } }
@ -36,7 +36,7 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
this@StreamDoubleFieldND.shape, this@StreamDoubleFieldND.shape,
shape shape
) )
this is BufferND && this.strides == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer this is BufferND && this.indexes == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) } else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
} }
@ -69,7 +69,7 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
return BufferND(strides, array.asBuffer()) return BufferND(strides, array.asBuffer())
} }
override fun combine( override fun zip(
a: StructureND<Double>, a: StructureND<Double>,
b: StructureND<Double>, b: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double, transform: DoubleField.(Double, Double) -> Double,

View File

@ -8,6 +8,7 @@ package space.kscience.kmath.structures
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.buffer import space.kscience.kmath.operations.buffer
import space.kscience.kmath.operations.bufferAlgebra import space.kscience.kmath.operations.bufferAlgebra
import space.kscience.kmath.operations.withSize
inline fun <reified R : Any> MutableBuffer.Companion.same( inline fun <reified R : Any> MutableBuffer.Companion.same(
n: Int, n: Int,
@ -16,7 +17,7 @@ inline fun <reified R : Any> MutableBuffer.Companion.same(
fun main() { fun main() {
with(DoubleField.bufferAlgebra(5)) { with(DoubleField.bufferAlgebra.withSize(5)) {
println(number(2.0) + buffer(1, 2, 3, 4, 5)) println(number(2.0) + buffer(1, 2, 3, 4, 5))
} }
} }

View File

@ -1,5 +1,5 @@
distributionBase=GRADLE_USER_HOME distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.1.1-bin.zip distributionUrl=https\://services.gradle.org/distributions/gradle-7.2-bin.zip
zipStoreBase=GRADLE_USER_HOME zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists zipStorePath=wrapper/dists

View File

@ -18,10 +18,10 @@ import com.github.h0tk3y.betterParse.parser.ParseResult
import com.github.h0tk3y.betterParse.parser.Parser import com.github.h0tk3y.betterParse.parser.Parser
import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.operations.FieldOperations import space.kscience.kmath.operations.FieldOps
import space.kscience.kmath.operations.GroupOperations import space.kscience.kmath.operations.GroupOps
import space.kscience.kmath.operations.PowerOperations import space.kscience.kmath.operations.PowerOperations
import space.kscience.kmath.operations.RingOperations import space.kscience.kmath.operations.RingOps
/** /**
* better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4. * better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4.
@ -60,7 +60,7 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
.or(binaryFunction) .or(binaryFunction)
.or(unaryFunction) .or(unaryFunction)
.or(singular) .or(singular)
.or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(GroupOperations.MINUS_OPERATION, it) }) .or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(GroupOps.MINUS_OPERATION, it) })
.or(-lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar) .or(-lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar)
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b -> private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
@ -72,9 +72,9 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
operator = div or mul use TokenMatch::type operator = div or mul use TokenMatch::type
) { a, op, b -> ) { a, op, b ->
if (op == div) if (op == div)
MST.Binary(FieldOperations.DIV_OPERATION, a, b) MST.Binary(FieldOps.DIV_OPERATION, a, b)
else else
MST.Binary(RingOperations.TIMES_OPERATION, a, b) MST.Binary(RingOps.TIMES_OPERATION, a, b)
} }
private val subSumChain: Parser<MST> by leftAssociative( private val subSumChain: Parser<MST> by leftAssociative(
@ -82,9 +82,9 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
operator = plus or minus use TokenMatch::type operator = plus or minus use TokenMatch::type
) { a, op, b -> ) { a, op, b ->
if (op == plus) if (op == plus)
MST.Binary(GroupOperations.PLUS_OPERATION, a, b) MST.Binary(GroupOps.PLUS_OPERATION, a, b)
else else
MST.Binary(GroupOperations.MINUS_OPERATION, a, b) MST.Binary(GroupOps.MINUS_OPERATION, a, b)
} }
override val rootParser: Parser<MST> by subSumChain override val rootParser: Parser<MST> by subSumChain

View File

@ -39,7 +39,7 @@ public val PrintNumeric: RenderFeature = RenderFeature { _, node ->
@UnstableKMathAPI @UnstableKMathAPI
private fun printSignedNumberString(s: String): MathSyntax = if (s.startsWith('-')) private fun printSignedNumberString(s: String): MathSyntax = if (s.startsWith('-'))
UnaryMinusSyntax( UnaryMinusSyntax(
operation = GroupOperations.MINUS_OPERATION, operation = GroupOps.MINUS_OPERATION,
operand = OperandSyntax( operand = OperandSyntax(
operand = NumberSyntax(string = s.removePrefix("-")), operand = NumberSyntax(string = s.removePrefix("-")),
parentheses = true, parentheses = true,
@ -72,7 +72,7 @@ public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : Rend
val exponent = afterE.toDouble().toString().removeSuffix(".0") val exponent = afterE.toDouble().toString().removeSuffix(".0")
return MultiplicationSyntax( return MultiplicationSyntax(
operation = RingOperations.TIMES_OPERATION, operation = RingOps.TIMES_OPERATION,
left = OperandSyntax(operand = NumberSyntax(significand), parentheses = true), left = OperandSyntax(operand = NumberSyntax(significand), parentheses = true),
right = OperandSyntax( right = OperandSyntax(
operand = SuperscriptSyntax( operand = SuperscriptSyntax(
@ -91,7 +91,7 @@ public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : Rend
if (toString.startsWith('-')) if (toString.startsWith('-'))
return UnaryMinusSyntax( return UnaryMinusSyntax(
operation = GroupOperations.MINUS_OPERATION, operation = GroupOps.MINUS_OPERATION,
operand = OperandSyntax(operand = infty, parentheses = true), operand = OperandSyntax(operand = infty, parentheses = true),
) )
@ -211,9 +211,9 @@ public class BinaryPlus(operations: Collection<String>?) : Binary(operations) {
public companion object { public companion object {
/** /**
* The default instance configured with [GroupOperations.PLUS_OPERATION]. * The default instance configured with [GroupOps.PLUS_OPERATION].
*/ */
public val Default: BinaryPlus = BinaryPlus(setOf(GroupOperations.PLUS_OPERATION)) public val Default: BinaryPlus = BinaryPlus(setOf(GroupOps.PLUS_OPERATION))
} }
} }
@ -233,9 +233,9 @@ public class BinaryMinus(operations: Collection<String>?) : Binary(operations) {
public companion object { public companion object {
/** /**
* The default instance configured with [GroupOperations.MINUS_OPERATION]. * The default instance configured with [GroupOps.MINUS_OPERATION].
*/ */
public val Default: BinaryMinus = BinaryMinus(setOf(GroupOperations.MINUS_OPERATION)) public val Default: BinaryMinus = BinaryMinus(setOf(GroupOps.MINUS_OPERATION))
} }
} }
@ -253,9 +253,9 @@ public class UnaryPlus(operations: Collection<String>?) : Unary(operations) {
public companion object { public companion object {
/** /**
* The default instance configured with [GroupOperations.PLUS_OPERATION]. * The default instance configured with [GroupOps.PLUS_OPERATION].
*/ */
public val Default: UnaryPlus = UnaryPlus(setOf(GroupOperations.PLUS_OPERATION)) public val Default: UnaryPlus = UnaryPlus(setOf(GroupOps.PLUS_OPERATION))
} }
} }
@ -273,9 +273,9 @@ public class UnaryMinus(operations: Collection<String>?) : Unary(operations) {
public companion object { public companion object {
/** /**
* The default instance configured with [GroupOperations.MINUS_OPERATION]. * The default instance configured with [GroupOps.MINUS_OPERATION].
*/ */
public val Default: UnaryMinus = UnaryMinus(setOf(GroupOperations.MINUS_OPERATION)) public val Default: UnaryMinus = UnaryMinus(setOf(GroupOps.MINUS_OPERATION))
} }
} }
@ -295,9 +295,9 @@ public class Fraction(operations: Collection<String>?) : Binary(operations) {
public companion object { public companion object {
/** /**
* The default instance configured with [FieldOperations.DIV_OPERATION]. * The default instance configured with [FieldOps.DIV_OPERATION].
*/ */
public val Default: Fraction = Fraction(setOf(FieldOperations.DIV_OPERATION)) public val Default: Fraction = Fraction(setOf(FieldOps.DIV_OPERATION))
} }
} }
@ -422,9 +422,9 @@ public class Multiplication(operations: Collection<String>?) : Binary(operations
public companion object { public companion object {
/** /**
* The default instance configured with [RingOperations.TIMES_OPERATION]. * The default instance configured with [RingOps.TIMES_OPERATION].
*/ */
public val Default: Multiplication = Multiplication(setOf(RingOperations.TIMES_OPERATION)) public val Default: Multiplication = Multiplication(setOf(RingOps.TIMES_OPERATION))
} }
} }

View File

@ -7,10 +7,10 @@ package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess.PostProcessPhase import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess.PostProcessPhase
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.FieldOperations import space.kscience.kmath.operations.FieldOps
import space.kscience.kmath.operations.GroupOperations import space.kscience.kmath.operations.GroupOps
import space.kscience.kmath.operations.PowerOperations import space.kscience.kmath.operations.PowerOperations
import space.kscience.kmath.operations.RingOperations import space.kscience.kmath.operations.RingOps
/** /**
* Removes unnecessary times (&times;) symbols from [MultiplicationSyntax]. * Removes unnecessary times (&times;) symbols from [MultiplicationSyntax].
@ -306,10 +306,10 @@ public class SimplifyParentheses(public val precedenceFunction: (MathSyntax) ->
is BinarySyntax -> when (it.operation) { is BinarySyntax -> when (it.operation) {
PowerOperations.POW_OPERATION -> 1 PowerOperations.POW_OPERATION -> 1
RingOperations.TIMES_OPERATION -> 3 RingOps.TIMES_OPERATION -> 3
FieldOperations.DIV_OPERATION -> 3 FieldOps.DIV_OPERATION -> 3
GroupOperations.MINUS_OPERATION -> 4 GroupOps.MINUS_OPERATION -> 4
GroupOperations.PLUS_OPERATION -> 4 GroupOps.PLUS_OPERATION -> 4
else -> 0 else -> 0
} }

View File

@ -7,7 +7,7 @@ package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.rendering.TestUtils.testLatex import space.kscience.kmath.ast.rendering.TestUtils.testLatex
import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.MST
import space.kscience.kmath.operations.GroupOperations import space.kscience.kmath.operations.GroupOps
import kotlin.test.Test import kotlin.test.Test
internal class TestLatex { internal class TestLatex {
@ -36,7 +36,7 @@ internal class TestLatex {
fun unaryOperator() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)") fun unaryOperator() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)")
@Test @Test
fun unaryPlus() = testLatex(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "+1") fun unaryPlus() = testLatex(MST.Unary(GroupOps.PLUS_OPERATION, MST.Numeric(1)), "+1")
@Test @Test
fun unaryMinus() = testLatex("-x", "-x") fun unaryMinus() = testLatex("-x", "-x")

View File

@ -7,7 +7,7 @@ package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.rendering.TestUtils.testMathML import space.kscience.kmath.ast.rendering.TestUtils.testMathML
import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.MST
import space.kscience.kmath.operations.GroupOperations import space.kscience.kmath.operations.GroupOps
import kotlin.test.Test import kotlin.test.Test
internal class TestMathML { internal class TestMathML {
@ -47,7 +47,7 @@ internal class TestMathML {
@Test @Test
fun unaryPlus() = fun unaryPlus() =
testMathML(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "<mo>+</mo><mn>1</mn>") testMathML(MST.Unary(GroupOps.PLUS_OPERATION, MST.Numeric(1)), "<mo>+</mo><mn>1</mn>")
@Test @Test
fun unaryMinus() = testMathML("-x", "<mo>-</mo><mi>x</mi>") fun unaryMinus() = testMathML("-x", "<mo>-</mo><mi>x</mi>")

View File

@ -108,8 +108,8 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double>(f64, DoubleF
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.f64.const(mst.value) override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.f64.const(mst.value)
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) { override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) {
GroupOperations.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value)) GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value))
GroupOperations.PLUS_OPERATION -> visit(mst.value) GroupOps.PLUS_OPERATION -> visit(mst.value)
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value)) PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value))
TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64) TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64)
TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64) TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64)
@ -129,10 +129,10 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double>(f64, DoubleF
} }
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
GroupOperations.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right)) GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
GroupOperations.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right)) GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
RingOperations.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right)) RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
FieldOperations.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right)) FieldOps.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right))
PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), f64) PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), f64)
else -> super.visitBinary(mst) else -> super.visitBinary(mst)
} }
@ -142,15 +142,15 @@ internal class IntWasmBuilder(target: MST) : WasmBuilder<Int>(i32, IntRing, targ
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.i32.const(mst.value) override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.i32.const(mst.value)
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) { override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) {
GroupOperations.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value)) GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value))
GroupOperations.PLUS_OPERATION -> visit(mst.value) GroupOps.PLUS_OPERATION -> visit(mst.value)
else -> super.visitUnary(mst) else -> super.visitUnary(mst)
} }
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
GroupOperations.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right)) GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
GroupOperations.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right)) GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
RingOperations.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right)) RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
else -> super.visitBinary(mst) else -> super.visitBinary(mst)
} }
} }

View File

@ -9,7 +9,7 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.ExtendedField import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations import space.kscience.kmath.operations.NumbersAddOps
/** /**
* A field over commons-math [DerivativeStructure]. * A field over commons-math [DerivativeStructure].
@ -22,7 +22,7 @@ public class DerivativeStructureField(
public val order: Int, public val order: Int,
bindings: Map<Symbol, Double>, bindings: Map<Symbol, Double>,
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>, ) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>,
NumbersAddOperations<DerivativeStructure> { NumbersAddOps<DerivativeStructure> {
public val numberOfVariables: Int = bindings.size public val numberOfVariables: Int = bindings.size
override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) } override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }

View File

@ -52,7 +52,7 @@ private val PI_DIV_2 = Complex(PI / 2, 0)
public object ComplexField : public object ComplexField :
ExtendedField<Complex>, ExtendedField<Complex>,
Norm<Complex, Complex>, Norm<Complex, Complex>,
NumbersAddOperations<Complex>, NumbersAddOps<Complex>,
ScaleOperations<Complex> { ScaleOperations<Complex> {
override val zero: Complex = 0.0.toComplex() override val zero: Complex = 0.0.toComplex()
@ -216,7 +216,6 @@ public data class Complex(val re: Double, val im: Double) {
public val Complex.Companion.algebra: ComplexField get() = ComplexField public val Complex.Companion.algebra: ComplexField get() = ComplexField
/** /**
* Creates a complex number with real part equal to this real. * Creates a complex number with real part equal to this real.
* *

View File

@ -6,13 +6,8 @@
package space.kscience.kmath.complex package space.kscience.kmath.complex
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.BufferND import space.kscience.kmath.nd.*
import space.kscience.kmath.nd.BufferedFieldND import space.kscience.kmath.operations.*
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.BufferField
import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations
import space.kscience.kmath.operations.bufferAlgebra
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
@ -22,100 +17,61 @@ import kotlin.contracts.contract
* An optimized nd-field for complex numbers * An optimized nd-field for complex numbers
*/ */
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public class ComplexFieldND( public sealed class ComplexFieldOpsND : BufferedFieldOpsND<Complex, ComplexField>(ComplexField.bufferAlgebra),
shape: IntArray, ScaleOperations<StructureND<Complex>>, ExtendedFieldOps<StructureND<Complex>> {
) : BufferedFieldND<Complex, ComplexField>(shape, ComplexField, Buffer.Companion::complex),
NumbersAddOperations<StructureND<Complex>>,
ExtendedField<StructureND<Complex>> {
override val zero: BufferND<Complex> by lazy { produce { zero } } override fun StructureND<Complex>.toBufferND(): BufferND<Complex> = when (this) {
override val one: BufferND<Complex> by lazy { produce { one } } is BufferND -> this
else -> {
override fun number(value: Number): BufferND<Complex> { val indexer = indexerBuilder(shape)
val d = value.toComplex() // minimize conversions BufferND(indexer, Buffer.complex(indexer.linearSize) { offset -> get(indexer.index(offset)) })
return produce { d } }
} }
// //TODO do specialization
// @Suppress("OVERRIDE_BY_INLINE")
// override inline fun map(
// arg: AbstractNDBuffer<Double>,
// transform: DoubleField.(Double) -> Double,
// ): RealNDElement {
// check(arg)
// val array = RealBuffer(arg.strides.linearSize) { offset -> DoubleField.transform(arg.buffer[offset]) }
// return BufferedNDFieldElement(this, array)
// }
//
// @Suppress("OVERRIDE_BY_INLINE")
// override inline fun produce(initializer: DoubleField.(IntArray) -> Double): RealNDElement {
// val array = RealBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }
// return BufferedNDFieldElement(this, array)
// }
//
// @Suppress("OVERRIDE_BY_INLINE")
// override inline fun mapIndexed(
// arg: AbstractNDBuffer<Double>,
// transform: DoubleField.(index: IntArray, Double) -> Double,
// ): RealNDElement {
// check(arg)
// return BufferedNDFieldElement(
// this,
// RealBuffer(arg.strides.linearSize) { offset ->
// elementContext.transform(
// arg.strides.index(offset),
// arg.buffer[offset]
// )
// })
// }
//
// @Suppress("OVERRIDE_BY_INLINE")
// override inline fun combine(
// a: AbstractNDBuffer<Double>,
// b: AbstractNDBuffer<Double>,
// transform: DoubleField.(Double, Double) -> Double,
// ): RealNDElement {
// check(a, b)
// val buffer = RealBuffer(strides.linearSize) { offset ->
// elementContext.transform(a.buffer[offset], b.buffer[offset])
// }
// return BufferedNDFieldElement(this, buffer)
// }
override fun power(arg: StructureND<Complex>, pow: Number): BufferND<Complex> = arg.map { power(it, pow) } override fun scale(a: StructureND<Complex>, value: Double): BufferND<Complex> =
mapInline(a.toBufferND()) { it * value }
override fun exp(arg: StructureND<Complex>): BufferND<Complex> = arg.map { exp(it) } override fun power(arg: StructureND<Complex>, pow: Number): BufferND<Complex> =
mapInline(arg.toBufferND()) { power(it, pow) }
override fun ln(arg: StructureND<Complex>): BufferND<Complex> = arg.map { ln(it) } override fun exp(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { exp(it) }
override fun ln(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { ln(it) }
override fun sin(arg: StructureND<Complex>): BufferND<Complex> = arg.map { sin(it) } override fun sin(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { sin(it) }
override fun cos(arg: StructureND<Complex>): BufferND<Complex> = arg.map { cos(it) } override fun cos(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { cos(it) }
override fun tan(arg: StructureND<Complex>): BufferND<Complex> = arg.map { tan(it) } override fun tan(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { tan(it) }
override fun asin(arg: StructureND<Complex>): BufferND<Complex> = arg.map { asin(it) } override fun asin(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { asin(it) }
override fun acos(arg: StructureND<Complex>): BufferND<Complex> = arg.map { acos(it) } override fun acos(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { acos(it) }
override fun atan(arg: StructureND<Complex>): BufferND<Complex> = arg.map { atan(it) } override fun atan(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { atan(it) }
override fun sinh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { sinh(it) } override fun sinh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { sinh(it) }
override fun cosh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { cosh(it) } override fun cosh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { cosh(it) }
override fun tanh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { tanh(it) } override fun tanh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { tanh(it) }
override fun asinh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { asinh(it) } override fun asinh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { asinh(it) }
override fun acosh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { acosh(it) } override fun acosh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { acosh(it) }
override fun atanh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { atanh(it) } override fun atanh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { atanh(it) }
}
public companion object : ComplexFieldOpsND()
/**
* Fast element production using function inlining
*/
public inline fun BufferedFieldND<Complex, ComplexField>.produceInline(initializer: ComplexField.(Int) -> Complex): BufferND<Complex> {
contract { callsInPlace(initializer, InvocationKind.EXACTLY_ONCE) }
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) }
return BufferND(strides, buffer)
} }
@UnstableKMathAPI @UnstableKMathAPI
public fun ComplexField.bufferAlgebra(size: Int): BufferField<Complex, ComplexField> = public val ComplexField.bufferAlgebra: BufferFieldOps<Complex, ComplexField>
bufferAlgebra(Buffer.Companion::complex, size) get() = bufferAlgebra(Buffer.Companion::complex)
@OptIn(UnstableKMathAPI::class)
public class ComplexFieldND(override val shape: Shape) :
ComplexFieldOpsND(), FieldND<Complex, ComplexField>, NumbersAddOps<StructureND<Complex>> {
override fun number(value: Number): BufferND<Complex> {
val d = value.toDouble() // minimize conversions
return produce(shape) { d.toComplex() }
}
}
public val ComplexField.ndAlgebra: ComplexFieldOpsND get() = ComplexFieldOpsND
public fun ComplexField.ndAlgebra(vararg shape: Int): ComplexFieldND = ComplexFieldND(shape) public fun ComplexField.ndAlgebra(vararg shape: Int): ComplexFieldND = ComplexFieldND(shape)

View File

@ -44,7 +44,7 @@ public val Quaternion.r: Double
*/ */
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>, PowerOperations<Quaternion>, public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>, PowerOperations<Quaternion>,
ExponentialOperations<Quaternion>, NumbersAddOperations<Quaternion>, ScaleOperations<Quaternion> { ExponentialOperations<Quaternion>, NumbersAddOps<Quaternion>, ScaleOperations<Quaternion> {
override val zero: Quaternion = 0.toQuaternion() override val zero: Quaternion = 0.toQuaternion()
override val one: Quaternion = 1.toQuaternion() override val one: Quaternion = 1.toQuaternion()

View File

@ -52,13 +52,13 @@ public open class FunctionalExpressionGroup<T, out A : Group<T>>(
override val zero: Expression<T> get() = const(algebra.zero) override val zero: Expression<T> get() = const(algebra.zero)
override fun Expression<T>.unaryMinus(): Expression<T> = override fun Expression<T>.unaryMinus(): Expression<T> =
unaryOperation(GroupOperations.MINUS_OPERATION, this) unaryOperation(GroupOps.MINUS_OPERATION, this)
/** /**
* Builds an Expression of addition of two another expressions. * Builds an Expression of addition of two another expressions.
*/ */
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(GroupOperations.PLUS_OPERATION, a, b) binaryOperation(GroupOps.PLUS_OPERATION, a, b)
// /** // /**
// * Builds an Expression of multiplication of expression by number. // * Builds an Expression of multiplication of expression by number.
@ -89,7 +89,7 @@ public open class FunctionalExpressionRing<T, out A : Ring<T>>(
* Builds an Expression of multiplication of two expressions. * Builds an Expression of multiplication of two expressions.
*/ */
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b) binaryOperationFunction(RingOps.TIMES_OPERATION)(a, b)
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg) public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
@ -108,7 +108,7 @@ public open class FunctionalExpressionField<T, out A : Field<T>>(
* Builds an Expression of division an expression by another one. * Builds an Expression of division an expression by another one.
*/ */
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b) binaryOperationFunction(FieldOps.DIV_OPERATION)(a, b)
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg) public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this

View File

@ -31,18 +31,18 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value) override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value)
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(GroupOperations.PLUS_OPERATION)(a, b) override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(GroupOps.PLUS_OPERATION)(a, b)
override operator fun MST.unaryPlus(): MST.Unary = override operator fun MST.unaryPlus(): MST.Unary =
unaryOperationFunction(GroupOperations.PLUS_OPERATION)(this) unaryOperationFunction(GroupOps.PLUS_OPERATION)(this)
override operator fun MST.unaryMinus(): MST.Unary = override operator fun MST.unaryMinus(): MST.Unary =
unaryOperationFunction(GroupOperations.MINUS_OPERATION)(this) unaryOperationFunction(GroupOps.MINUS_OPERATION)(this)
override operator fun MST.minus(b: MST): MST.Binary = override operator fun MST.minus(b: MST): MST.Binary =
binaryOperationFunction(GroupOperations.MINUS_OPERATION)(this, b) binaryOperationFunction(GroupOps.MINUS_OPERATION)(this, b)
override fun scale(a: MST, value: Double): MST.Binary = override fun scale(a: MST, value: Double): MST.Binary =
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, number(value)) binaryOperationFunction(RingOps.TIMES_OPERATION)(a, number(value))
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
MstNumericAlgebra.binaryOperationFunction(operation) MstNumericAlgebra.binaryOperationFunction(operation)
@ -56,7 +56,7 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
*/ */
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MST> { public object MstRing : Ring<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
override inline val zero: MST.Numeric get() = MstGroup.zero override inline val zero: MST.Numeric get() = MstGroup.zero
override val one: MST.Numeric = number(1.0) override val one: MST.Numeric = number(1.0)
@ -65,10 +65,10 @@ public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MS
override fun add(a: MST, b: MST): MST.Binary = MstGroup.add(a, b) override fun add(a: MST, b: MST): MST.Binary = MstGroup.add(a, b)
override fun scale(a: MST, value: Double): MST.Binary = override fun scale(a: MST, value: Double): MST.Binary =
MstGroup.binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, MstGroup.number(value)) MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
override fun multiply(a: MST, b: MST): MST.Binary = override fun multiply(a: MST, b: MST): MST.Binary =
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b) binaryOperationFunction(RingOps.TIMES_OPERATION)(a, b)
override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus } override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus }
override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus } override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus }
@ -86,7 +86,7 @@ public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MS
*/ */
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public object MstField : Field<MST>, NumbersAddOperations<MST>, ScaleOperations<MST> { public object MstField : Field<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
override inline val zero: MST.Numeric get() = MstRing.zero override inline val zero: MST.Numeric get() = MstRing.zero
override inline val one: MST.Numeric get() = MstRing.one override inline val one: MST.Numeric get() = MstRing.one
@ -95,11 +95,11 @@ public object MstField : Field<MST>, NumbersAddOperations<MST>, ScaleOperations<
override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
override fun scale(a: MST, value: Double): MST.Binary = override fun scale(a: MST, value: Double): MST.Binary =
MstGroup.binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, MstGroup.number(value)) MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
override fun divide(a: MST, b: MST): MST.Binary = override fun divide(a: MST, b: MST): MST.Binary =
binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b) binaryOperationFunction(FieldOps.DIV_OPERATION)(a, b)
override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus } override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus }
override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus } override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus }
@ -138,7 +138,7 @@ public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg) override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg)
override fun scale(a: MST, value: Double): MST = override fun scale(a: MST, value: Double): MST =
binaryOperation(GroupOperations.PLUS_OPERATION, a, number(value)) binaryOperation(GroupOps.PLUS_OPERATION, a, number(value))
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)

View File

@ -59,7 +59,7 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
public open class SimpleAutoDiffField<T : Any, F : Field<T>>( public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
public val context: F, public val context: F,
bindings: Map<Symbol, T>, bindings: Map<Symbol, T>,
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, NumbersAddOperations<AutoDiffValue<T>> { ) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, NumbersAddOps<AutoDiffValue<T>> {
override val zero: AutoDiffValue<T> get() = const(context.zero) override val zero: AutoDiffValue<T> get() = const(context.zero)
override val one: AutoDiffValue<T> get() = const(context.one) override val one: AutoDiffValue<T> get() = const(context.one)

View File

@ -6,12 +6,10 @@
package space.kscience.kmath.linear package space.kscience.kmath.linear
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.BufferedRingND import space.kscience.kmath.nd.BufferedRingOpsND
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.nd.asND import space.kscience.kmath.nd.asND
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.operations.*
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.VirtualBuffer import space.kscience.kmath.structures.VirtualBuffer
@ -19,31 +17,28 @@ import space.kscience.kmath.structures.indices
public class BufferedLinearSpace<T, out A : Ring<T>>( public class BufferedLinearSpace<T, out A : Ring<T>>(
override val elementAlgebra: A, private val bufferAlgebra: BufferAlgebra<T, A>
private val bufferFactory: BufferFactory<T>,
) : LinearSpace<T, A> { ) : LinearSpace<T, A> {
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
private fun ndRing( private val ndAlgebra = BufferedRingOpsND(bufferAlgebra)
rows: Int,
cols: Int,
): BufferedRingND<T, A> = elementAlgebra.ndAlgebra(bufferFactory, rows, cols)
override fun buildMatrix(rows: Int, columns: Int, initializer: A.(i: Int, j: Int) -> T): Matrix<T> = override fun buildMatrix(rows: Int, columns: Int, initializer: A.(i: Int, j: Int) -> T): Matrix<T> =
ndRing(rows, columns).produce { (i, j) -> elementAlgebra.initializer(i, j) }.as2D() ndAlgebra.produce(intArrayOf(rows, columns)) { (i, j) -> elementAlgebra.initializer(i, j) }.as2D()
override fun buildVector(size: Int, initializer: A.(Int) -> T): Point<T> = override fun buildVector(size: Int, initializer: A.(Int) -> T): Point<T> =
bufferFactory(size) { elementAlgebra.initializer(it) } bufferAlgebra.buffer(size) { elementAlgebra.initializer(it) }
override fun Matrix<T>.unaryMinus(): Matrix<T> = ndRing(rowNum, colNum).run { override fun Matrix<T>.unaryMinus(): Matrix<T> = ndAlgebra {
asND().map { -it }.as2D() asND().map { -it }.as2D()
} }
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndRing(rowNum, colNum).run { override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndAlgebra {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D() asND().plus(other.asND()).as2D()
} }
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndRing(rowNum, colNum).run { override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndAlgebra {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D() asND().minus(other.asND()).as2D()
} }
@ -88,11 +83,11 @@ public class BufferedLinearSpace<T, out A : Ring<T>>(
} }
} }
override fun Matrix<T>.times(value: T): Matrix<T> = ndRing(rowNum, colNum).run { override fun Matrix<T>.times(value: T): Matrix<T> = ndAlgebra {
asND().map { it * value }.as2D() asND().map { it * value }.as2D()
} }
} }
public fun <T, A : Ring<T>> A.linearSpace(bufferFactory: BufferFactory<T>): BufferedLinearSpace<T, A> = public fun <T, A : Ring<T>> A.linearSpace(bufferFactory: BufferFactory<T>): BufferedLinearSpace<T, A> =
BufferedLinearSpace(this, bufferFactory) BufferedLinearSpace(BufferRingOps(this, bufferFactory))

View File

@ -6,11 +6,12 @@
package space.kscience.kmath.linear package space.kscience.kmath.linear
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.DoubleFieldND import space.kscience.kmath.nd.DoubleFieldOpsND
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.nd.asND import space.kscience.kmath.nd.asND
import space.kscience.kmath.operations.DoubleBufferOperations import space.kscience.kmath.operations.DoubleBufferOps
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
@ -18,30 +19,27 @@ public object DoubleLinearSpace : LinearSpace<Double, DoubleField> {
override val elementAlgebra: DoubleField get() = DoubleField override val elementAlgebra: DoubleField get() = DoubleField
private fun ndRing(
rows: Int,
cols: Int,
): DoubleFieldND = DoubleFieldND(intArrayOf(rows, cols))
override fun buildMatrix( override fun buildMatrix(
rows: Int, rows: Int,
columns: Int, columns: Int,
initializer: DoubleField.(i: Int, j: Int) -> Double initializer: DoubleField.(i: Int, j: Int) -> Double
): Matrix<Double> = ndRing(rows, columns).produce { (i, j) -> DoubleField.initializer(i, j) }.as2D() ): Matrix<Double> = DoubleFieldOpsND.produce(intArrayOf(rows, columns)) { (i, j) ->
DoubleField.initializer(i, j)
}.as2D()
override fun buildVector(size: Int, initializer: DoubleField.(Int) -> Double): DoubleBuffer = override fun buildVector(size: Int, initializer: DoubleField.(Int) -> Double): DoubleBuffer =
DoubleBuffer(size) { DoubleField.initializer(it) } DoubleBuffer(size) { DoubleField.initializer(it) }
override fun Matrix<Double>.unaryMinus(): Matrix<Double> = ndRing(rowNum, colNum).run { override fun Matrix<Double>.unaryMinus(): Matrix<Double> = DoubleFieldOpsND {
asND().map { -it }.as2D() asND().map { -it }.as2D()
} }
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = ndRing(rowNum, colNum).run { override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = DoubleFieldOpsND {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D() asND().plus(other.asND()).as2D()
} }
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = ndRing(rowNum, colNum).run { override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = DoubleFieldOpsND {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D() asND().minus(other.asND()).as2D()
} }
@ -84,23 +82,23 @@ public object DoubleLinearSpace : LinearSpace<Double, DoubleField> {
} }
override fun Matrix<Double>.times(value: Double): Matrix<Double> = ndRing(rowNum, colNum).run { override fun Matrix<Double>.times(value: Double): Matrix<Double> = DoubleFieldOpsND {
asND().map { it * value }.as2D() asND().map { it * value }.as2D()
} }
public override fun Point<Double>.plus(other: Point<Double>): DoubleBuffer = DoubleBufferOperations.run { public override fun Point<Double>.plus(other: Point<Double>): DoubleBuffer = DoubleBufferOps.run {
this@plus + other this@plus + other
} }
public override fun Point<Double>.minus(other: Point<Double>): DoubleBuffer = DoubleBufferOperations.run { public override fun Point<Double>.minus(other: Point<Double>): DoubleBuffer = DoubleBufferOps.run {
this@minus - other this@minus - other
} }
public override fun Point<Double>.times(value: Double): DoubleBuffer = DoubleBufferOperations.run { public override fun Point<Double>.times(value: Double): DoubleBuffer = DoubleBufferOps.run {
scale(this@times, value) scale(this@times, value)
} }
public operator fun Point<Double>.div(value: Double): DoubleBuffer = DoubleBufferOperations.run { public operator fun Point<Double>.div(value: Double): DoubleBuffer = DoubleBufferOps.run {
scale(this@div, 1.0 / value) scale(this@div, 1.0 / value)
} }

View File

@ -10,6 +10,7 @@ import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.Structure2D import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.nd.StructureFeature import space.kscience.kmath.nd.StructureFeature
import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as1D
import space.kscience.kmath.operations.BufferRingOps
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
@ -188,7 +189,7 @@ public interface LinearSpace<T, out A : Ring<T>> {
public fun <T : Any, A : Ring<T>> buffered( public fun <T : Any, A : Ring<T>> buffered(
algebra: A, algebra: A,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
): LinearSpace<T, A> = BufferedLinearSpace(algebra, bufferFactory) ): LinearSpace<T, A> = BufferedLinearSpace(BufferRingOps(algebra, bufferFactory))
@Deprecated("use DoubleField.linearSpace") @Deprecated("use DoubleField.linearSpace")
public val double: LinearSpace<Double, DoubleField> = buffered(DoubleField, ::DoubleBuffer) public val double: LinearSpace<Double, DoubleField> = buffered(DoubleField, ::DoubleBuffer)

View File

@ -7,7 +7,6 @@ package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.*
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
@ -19,6 +18,14 @@ import kotlin.reflect.KClass
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) : public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.") RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
public typealias Shape = IntArray
public fun Shape(shapeFirst: Int, vararg shapeRest: Int): Shape = intArrayOf(shapeFirst, *shapeRest)
public interface WithShape {
public val shape: Shape
}
/** /**
* The base interface for all ND-algebra implementations. * The base interface for all ND-algebra implementations.
* *
@ -26,20 +33,15 @@ public class ShapeMismatchException(public val expected: IntArray, public val ac
* @param C the type of the element context. * @param C the type of the element context.
*/ */
public interface AlgebraND<T, out C : Algebra<T>> { public interface AlgebraND<T, out C : Algebra<T>> {
/**
* The shape of ND-structures this algebra operates on.
*/
public val shape: IntArray
/** /**
* The algebra over elements of ND structure. * The algebra over elements of ND structure.
*/ */
public val elementContext: C public val elementAlgebra: C
/** /**
* Produces a new NDStructure using given initializer function. * Produces a new NDStructure using given initializer function.
*/ */
public fun produce(initializer: C.(IntArray) -> T): StructureND<T> public fun produce(shape: Shape, initializer: C.(IntArray) -> T): StructureND<T>
/** /**
* Maps elements from one structure to another one by applying [transform] to them. * Maps elements from one structure to another one by applying [transform] to them.
@ -54,7 +56,7 @@ public interface AlgebraND<T, out C : Algebra<T>> {
/** /**
* Combines two structures into one. * Combines two structures into one.
*/ */
public fun combine(a: StructureND<T>, b: StructureND<T>, transform: C.(T, T) -> T): StructureND<T> public fun zip(left: StructureND<T>, right: StructureND<T>, transform: C.(T, T) -> T): StructureND<T>
/** /**
* Element-wise invocation of function working on [T] on a [StructureND]. * Element-wise invocation of function working on [T] on a [StructureND].
@ -77,7 +79,6 @@ public interface AlgebraND<T, out C : Algebra<T>> {
public companion object public companion object
} }
/** /**
* Get a feature of the structure in this scope. Structure features take precedence other context features. * Get a feature of the structure in this scope. Structure features take precedence other context features.
* *
@ -89,37 +90,13 @@ public interface AlgebraND<T, out C : Algebra<T>> {
public inline fun <T : Any, reified F : StructureFeature> AlgebraND<T, *>.getFeature(structure: StructureND<T>): F? = public inline fun <T : Any, reified F : StructureFeature> AlgebraND<T, *>.getFeature(structure: StructureND<T>): F? =
getFeature(structure, F::class) getFeature(structure, F::class)
/**
* Checks if given elements are consistent with this context.
*
* @param structures the structures to check.
* @return the array of valid structures.
*/
internal fun <T, C : Algebra<T>> AlgebraND<T, C>.checkShape(vararg structures: StructureND<T>): Array<out StructureND<T>> =
structures
.map(StructureND<T>::shape)
.singleOrNull { !shape.contentEquals(it) }
?.let<IntArray, Array<out StructureND<T>>> { throw ShapeMismatchException(shape, it) }
?: structures
/**
* Checks if given element is consistent with this context.
*
* @param element the structure to check.
* @return the valid structure.
*/
internal fun <T, C : Algebra<T>> AlgebraND<T, C>.checkShape(element: StructureND<T>): StructureND<T> {
if (!element.shape.contentEquals(shape)) throw ShapeMismatchException(shape, element.shape)
return element
}
/** /**
* Space of [StructureND]. * Space of [StructureND].
* *
* @param T the type of the element contained in ND structure. * @param T the type of the element contained in ND structure.
* @param S the type of group over structure elements. * @param A the type of group over structure elements.
*/ */
public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND<T, S> { public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>, AlgebraND<T, A> {
/** /**
* Element-wise addition. * Element-wise addition.
* *
@ -128,7 +105,7 @@ public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND
* @return the sum. * @return the sum.
*/ */
override fun add(a: StructureND<T>, b: StructureND<T>): StructureND<T> = override fun add(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
combine(a, b) { aValue, bValue -> add(aValue, bValue) } zip(a, b) { aValue, bValue -> add(aValue, bValue) }
// TODO move to extensions after KEEP-176 // TODO move to extensions after KEEP-176
@ -171,13 +148,17 @@ public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND
public companion object public companion object
} }
public interface GroupND<T, out A : Group<T>> : Group<StructureND<T>>, GroupOpsND<T, A>, WithShape {
override val zero: StructureND<T> get() = produce(shape) { elementAlgebra.zero }
}
/** /**
* Ring of [StructureND]. * Ring of [StructureND].
* *
* @param T the type of the element contained in ND structure. * @param T the type of the element contained in ND structure.
* @param R the type of ring over structure elements. * @param A the type of ring over structure elements.
*/ */
public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R> { public interface RingOpsND<T, out A : RingOps<T>> : RingOps<StructureND<T>>, GroupOpsND<T, A> {
/** /**
* Element-wise multiplication. * Element-wise multiplication.
* *
@ -186,7 +167,7 @@ public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R
* @return the product. * @return the product.
*/ */
override fun multiply(a: StructureND<T>, b: StructureND<T>): StructureND<T> = override fun multiply(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
combine(a, b) { aValue, bValue -> multiply(aValue, bValue) } zip(a, b) { aValue, bValue -> multiply(aValue, bValue) }
//TODO move to extensions after KEEP-176 //TODO move to extensions after KEEP-176
@ -211,13 +192,19 @@ public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R
public companion object public companion object
} }
public interface RingND<T, out A : Ring<T>> : Ring<StructureND<T>>, RingOpsND<T, A>, GroupND<T, A>, WithShape {
override val one: StructureND<T> get() = produce(shape) { elementAlgebra.one }
}
/** /**
* Field of [StructureND]. * Field of [StructureND].
* *
* @param T the type of the element contained in ND structure. * @param T the type of the element contained in ND structure.
* @param F the type field over structure elements. * @param A the type field over structure elements.
*/ */
public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T, F> { public interface FieldOpsND<T, out A : Field<T>> : FieldOps<StructureND<T>>, RingOpsND<T, A>,
ScaleOperations<StructureND<T>> {
/** /**
* Element-wise division. * Element-wise division.
* *
@ -226,9 +213,9 @@ public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T,
* @return the quotient. * @return the quotient.
*/ */
override fun divide(a: StructureND<T>, b: StructureND<T>): StructureND<T> = override fun divide(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
combine(a, b) { aValue, bValue -> divide(aValue, bValue) } zip(a, b) { aValue, bValue -> divide(aValue, bValue) }
//TODO move to extensions after KEEP-176 //TODO move to extensions after https://github.com/Kotlin/KEEP/blob/master/proposals/context-receivers.md
/** /**
* Divides an ND structure by an element of it. * Divides an ND structure by an element of it.
* *
@ -247,42 +234,9 @@ public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T,
*/ */
public operator fun T.div(arg: StructureND<T>): StructureND<T> = arg.map { divide(it, this@div) } public operator fun T.div(arg: StructureND<T>): StructureND<T> = arg.map { divide(it, this@div) }
/**
* Element-wise scaling.
*
* @param a the multiplicand.
* @param value the multiplier.
* @return the product.
*/
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { scale(it, value) } override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { scale(it, value) }
}
// @ThreadLocal
// public companion object { public interface FieldND<T, out A : Field<T>> : Field<StructureND<T>>, FieldOpsND<T, A>, RingND<T, A>, WithShape {
// private val realNDFieldCache: MutableMap<IntArray, RealNDField> = hashMapOf() override val one: StructureND<T> get() = produce(shape) { elementAlgebra.one }
//
// /**
// * Create a nd-field for [Double] values or pull it from cache if it was created previously.
// */
// public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
//
// /**
// * Create an ND field with boxing generic buffer.
// */
// public fun <T : Any, F : Field<T>> boxing(
// field: F,
// vararg shape: Int,
// bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
// ): BufferedNDField<T, F> = BufferedNDField(shape, field, bufferFactory)
//
// /**
// * Create a most suitable implementation for nd-field using reified class.
// */
// @Suppress("UNCHECKED_CAST")
// public inline fun <reified T : Any, F : Field<T>> auto(field: F, vararg shape: Int): NDField<T, F> =
// when {
// T::class == Double::class -> real(*shape) as NDField<T, F>
// T::class == Complex::class -> complex(*shape) as BufferedNDField<T, F>
// else -> BoxingNDField(shape, field, Buffer.Companion::auto)
// }
// }
} }

View File

@ -3,145 +3,173 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/ */
@file:OptIn(UnstableKMathAPI::class)
package space.kscience.kmath.nd package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.BufferFactory
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.jvm.JvmName
public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> { public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
public val strides: Strides public val indexerBuilder: (IntArray) -> ShapeIndex
public val bufferFactory: BufferFactory<T> public val bufferAlgebra: BufferAlgebra<T, A>
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
override fun produce(initializer: A.(IntArray) -> T): BufferND<T> = BufferND( override fun produce(shape: Shape, initializer: A.(IntArray) -> T): BufferND<T> {
strides, val indexer = indexerBuilder(shape)
bufferFactory(strides.linearSize) { offset -> return BufferND(
elementContext.initializer(strides.index(offset)) indexer,
bufferAlgebra.buffer(indexer.linearSize) { offset ->
elementAlgebra.initializer(indexer.index(offset))
} }
) )
public val StructureND<T>.buffer: Buffer<T>
get() = when {
!shape.contentEquals(this@BufferAlgebraND.shape) -> throw ShapeMismatchException(
this@BufferAlgebraND.shape,
shape
)
this is BufferND && this.strides == this@BufferAlgebraND.strides -> this.buffer
else -> bufferFactory(strides.linearSize) { offset -> get(strides.index(offset)) }
} }
override fun StructureND<T>.map(transform: A.(T) -> T): BufferND<T> { public fun StructureND<T>.toBufferND(): BufferND<T> = when (this) {
val buffer = bufferFactory(strides.linearSize) { offset -> is BufferND -> this
elementContext.transform(buffer[offset]) else -> {
val indexer = indexerBuilder(shape)
BufferND(indexer, bufferAlgebra.buffer(indexer.linearSize) { offset -> get(indexer.index(offset)) })
} }
return BufferND(strides, buffer)
} }
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): BufferND<T> { override fun StructureND<T>.map(transform: A.(T) -> T): BufferND<T> = mapInline(toBufferND(), transform)
val buffer = bufferFactory(strides.linearSize) { offset ->
elementContext.transform(
strides.index(offset),
buffer[offset]
)
}
return BufferND(strides, buffer)
}
override fun combine(a: StructureND<T>, b: StructureND<T>, transform: A.(T, T) -> T): BufferND<T> { override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): BufferND<T> =
val buffer = bufferFactory(strides.linearSize) { offset -> mapIndexedInline(toBufferND(), transform)
elementContext.transform(a.buffer[offset], b.buffer[offset])
} override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): BufferND<T> =
return BufferND(strides, buffer) zipInline(left.toBufferND(), right.toBufferND(), transform)
public companion object {
public val defaultIndexerBuilder: (IntArray) -> ShapeIndex = DefaultStrides.Companion::invoke
} }
} }
public open class BufferedGroupND<T, out A : Group<T>>( public inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mapInline(
final override val shape: IntArray, arg: BufferND<T>,
final override val elementContext: A, crossinline transform: A.(T) -> T
final override val bufferFactory: BufferFactory<T>, ): BufferND<T> {
) : GroupND<T, A>, BufferAlgebraND<T, A> { val indexes = arg.indexes
override val strides: Strides = DefaultStrides(shape) return BufferND(indexes, bufferAlgebra.mapInline(arg.buffer, transform))
override val zero: BufferND<T> by lazy { produce { zero } }
override fun StructureND<T>.unaryMinus(): StructureND<T> = produce { -get(it) }
} }
public open class BufferedRingND<T, out R : Ring<T>>( internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mapIndexedInline(
shape: IntArray, arg: BufferND<T>,
elementContext: R, crossinline transform: A.(index: IntArray, arg: T) -> T
): BufferND<T> {
val indexes = arg.indexes
return BufferND(
indexes,
bufferAlgebra.mapIndexedInline(arg.buffer) { offset, value ->
transform(indexes.index(offset), value)
}
)
}
internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.zipInline(
l: BufferND<T>,
r: BufferND<T>,
crossinline block: A.(l: T, r: T) -> T
): BufferND<T> {
require(l.indexes == r.indexes)
val indexes = l.indexes
return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block))
}
public open class BufferedGroupNDOps<T, out A : Group<T>>(
override val bufferAlgebra: BufferAlgebra<T, A>,
override val indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
}
public open class BufferedRingOpsND<T, out A : Ring<T>>(
bufferAlgebra: BufferAlgebra<T, A>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
) : BufferedGroupNDOps<T, A>(bufferAlgebra, indexerBuilder), RingOpsND<T, A>
public open class BufferedFieldOpsND<T, out A : Field<T>>(
bufferAlgebra: BufferAlgebra<T, A>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
) : BufferedRingOpsND<T, A>(bufferAlgebra, indexerBuilder), FieldOpsND<T, A> {
public constructor(
elementAlgebra: A,
bufferFactory: BufferFactory<T>, bufferFactory: BufferFactory<T>,
) : BufferedGroupND<T, R>(shape, elementContext, bufferFactory), RingND<T, R> { indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
override val one: BufferND<T> by lazy { produce { one } } ) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder)
}
public open class BufferedFieldND<T, out R : Field<T>>(
shape: IntArray,
elementContext: R,
bufferFactory: BufferFactory<T>,
) : BufferedRingND<T, R>(shape, elementContext, bufferFactory), FieldND<T, R> {
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value } override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value }
} }
// group factories public val <T, A : Group<T>> BufferAlgebra<T, A>.nd: BufferedGroupNDOps<T, A> get() = BufferedGroupNDOps(this)
public fun <T, A : Group<T>> A.ndAlgebra( public val <T, A : Ring<T>> BufferAlgebra<T, A>.nd: BufferedRingOpsND<T, A> get() = BufferedRingOpsND(this)
bufferFactory: BufferFactory<T>, public val <T, A : Field<T>> BufferAlgebra<T, A>.nd: BufferedFieldOpsND<T, A> get() = BufferedFieldOpsND(this)
vararg shape: Int,
): BufferedGroupND<T, A> = BufferedGroupND(shape, this, bufferFactory)
@JvmName("withNdGroup")
public inline fun <T, A : Group<T>, R> A.withNdAlgebra(
noinline bufferFactory: BufferFactory<T>,
vararg shape: Int,
action: BufferedGroupND<T, A>.() -> R,
): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return ndAlgebra(bufferFactory, *shape).run(action)
}
//ring factories public fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.produce(
public fun <T, A : Ring<T>> A.ndAlgebra(
bufferFactory: BufferFactory<T>,
vararg shape: Int, vararg shape: Int,
): BufferedRingND<T, A> = BufferedRingND(shape, this, bufferFactory) initializer: A.(IntArray) -> T
): BufferND<T> = produce(shape, initializer)
@JvmName("withNdRing") //// group factories
public inline fun <T, A : Ring<T>, R> A.withNdAlgebra( //public fun <T, A : Group<T>> A.ndAlgebra(
noinline bufferFactory: BufferFactory<T>, // bufferAlgebra: BufferAlgebra<T, A>,
vararg shape: Int, // vararg shape: Int,
action: BufferedRingND<T, A>.() -> R, //): BufferedGroupNDOps<T, A> = BufferedGroupNDOps(bufferAlgebra)
): R { //
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } //@JvmName("withNdGroup")
return ndAlgebra(bufferFactory, *shape).run(action) //public inline fun <T, A : Group<T>, R> A.withNdAlgebra(
} // noinline bufferFactory: BufferFactory<T>,
// vararg shape: Int,
// action: BufferedGroupNDOps<T, A>.() -> R,
//): R {
// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
// return ndAlgebra(bufferFactory, *shape).run(action)
//}
//field factories ////ring factories
public fun <T, A : Field<T>> A.ndAlgebra( //public fun <T, A : Ring<T>> A.ndAlgebra(
bufferFactory: BufferFactory<T>, // bufferFactory: BufferFactory<T>,
vararg shape: Int, // vararg shape: Int,
): BufferedFieldND<T, A> = BufferedFieldND(shape, this, bufferFactory) //): BufferedRingNDOps<T, A> = BufferedRingNDOps(shape, this, bufferFactory)
//
/** //@JvmName("withNdRing")
* Create a [FieldND] for this [Field] inferring proper buffer factory from the type //public inline fun <T, A : Ring<T>, R> A.withNdAlgebra(
*/ // noinline bufferFactory: BufferFactory<T>,
@UnstableKMathAPI // vararg shape: Int,
@Suppress("UNCHECKED_CAST") // action: BufferedRingNDOps<T, A>.() -> R,
public inline fun <reified T : Any, A : Field<T>> A.autoNdAlgebra( //): R {
vararg shape: Int, // contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
): FieldND<T, A> = when (this) { // return ndAlgebra(bufferFactory, *shape).run(action)
DoubleField -> DoubleFieldND(shape) as FieldND<T, A> //}
else -> BufferedFieldND(shape, this, Buffer.Companion::auto) //
} ////field factories
//public fun <T, A : Field<T>> A.ndAlgebra(
@JvmName("withNdField") // bufferFactory: BufferFactory<T>,
public inline fun <T, A : Field<T>, R> A.withNdAlgebra( // vararg shape: Int,
noinline bufferFactory: BufferFactory<T>, //): BufferedFieldNDOps<T, A> = BufferedFieldNDOps(shape, this, bufferFactory)
vararg shape: Int, //
action: BufferedFieldND<T, A>.() -> R, ///**
): R { // * Create a [FieldND] for this [Field] inferring proper buffer factory from the type
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } // */
return ndAlgebra(bufferFactory, *shape).run(action) //@UnstableKMathAPI
} //@Suppress("UNCHECKED_CAST")
//public inline fun <reified T : Any, A : Field<T>> A.autoNdAlgebra(
// vararg shape: Int,
//): FieldND<T, A> = when (this) {
// DoubleField -> DoubleFieldND(shape) as FieldND<T, A>
// else -> BufferedFieldNDOps(shape, this, Buffer.Companion::auto)
//}
//
//@JvmName("withNdField")
//public inline fun <T, A : Field<T>, R> A.withNdAlgebra(
// noinline bufferFactory: BufferFactory<T>,
// vararg shape: Int,
// action: BufferedFieldNDOps<T, A>.() -> R,
//): R {
// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
// return ndAlgebra(bufferFactory, *shape).run(action)
//}

View File

@ -15,26 +15,20 @@ import space.kscience.kmath.structures.MutableBufferFactory
* Represents [StructureND] over [Buffer]. * Represents [StructureND] over [Buffer].
* *
* @param T the type of items. * @param T the type of items.
* @param strides The strides to access elements of [Buffer] by linear indices. * @param indexes The strides to access elements of [Buffer] by linear indices.
* @param buffer The underlying buffer. * @param buffer The underlying buffer.
*/ */
public open class BufferND<out T>( public open class BufferND<out T>(
public val strides: Strides, public val indexes: ShapeIndex,
public val buffer: Buffer<T>, public val buffer: Buffer<T>,
) : StructureND<T> { ) : StructureND<T> {
init { override operator fun get(index: IntArray): T = buffer[indexes.offset(index)]
if (strides.linearSize != buffer.size) {
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
}
}
override operator fun get(index: IntArray): T = buffer[strides.offset(index)] override val shape: IntArray get() = indexes.shape
override val shape: IntArray get() = strides.shape
@PerformancePitfall @PerformancePitfall
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { override fun elements(): Sequence<Pair<IntArray, T>> = indexes.indices().map {
it to this[it] it to this[it]
} }
@ -49,7 +43,7 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
crossinline transform: (T) -> R, crossinline transform: (T) -> R,
): BufferND<R> { ): BufferND<R> {
return if (this is BufferND<T>) return if (this is BufferND<T>)
BufferND(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) }) BufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) })
else { else {
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
@ -64,11 +58,11 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
* @param mutableBuffer The underlying buffer. * @param mutableBuffer The underlying buffer.
*/ */
public class MutableBufferND<T>( public class MutableBufferND<T>(
strides: Strides, strides: ShapeIndex,
public val mutableBuffer: MutableBuffer<T>, public val mutableBuffer: MutableBuffer<T>,
) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) { ) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) {
override fun set(index: IntArray, value: T) { override fun set(index: IntArray, value: T) {
mutableBuffer[strides.offset(index)] = value mutableBuffer[indexes.offset(index)] = value
} }
} }
@ -80,7 +74,7 @@ public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
crossinline transform: (T) -> R, crossinline transform: (T) -> R,
): MutableBufferND<R> { ): MutableBufferND<R> {
return if (this is MutableBufferND<T>) return if (this is MutableBufferND<T>)
MutableBufferND(this.strides, factory.invoke(strides.linearSize) { transform(mutableBuffer[it]) }) MutableBufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(mutableBuffer[it]) })
else { else {
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })

View File

@ -6,108 +6,68 @@
package space.kscience.kmath.nd package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.*
import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations
import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
@OptIn(UnstableKMathAPI::class) public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(DoubleField.bufferAlgebra),
public class DoubleFieldND( ScaleOperations<StructureND<Double>>, ExtendedFieldOps<StructureND<Double>> {
shape: IntArray,
) : BufferedFieldND<Double, DoubleField>(shape, DoubleField, ::DoubleBuffer),
NumbersAddOperations<StructureND<Double>>,
ScaleOperations<StructureND<Double>>,
ExtendedField<StructureND<Double>> {
override val zero: BufferND<Double> by lazy { produce { zero } } override fun StructureND<Double>.toBufferND(): BufferND<Double> = when (this) {
override val one: BufferND<Double> by lazy { produce { one } } is BufferND -> this
else -> {
val indexer = indexerBuilder(shape)
BufferND(indexer, DoubleBuffer(indexer.linearSize) { offset -> get(indexer.index(offset)) })
}
}
//TODO do specialization
override fun scale(a: StructureND<Double>, value: Double): BufferND<Double> =
mapInline(a.toBufferND()) { it * value }
override fun power(arg: StructureND<Double>, pow: Number): BufferND<Double> =
mapInline(arg.toBufferND()) { power(it, pow) }
override fun exp(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { exp(it) }
override fun ln(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { ln(it) }
override fun sin(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { sin(it) }
override fun cos(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { cos(it) }
override fun tan(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { tan(it) }
override fun asin(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { asin(it) }
override fun acos(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { acos(it) }
override fun atan(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { atan(it) }
override fun sinh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { sinh(it) }
override fun cosh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { cosh(it) }
override fun tanh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { tanh(it) }
override fun asinh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { asinh(it) }
override fun acosh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { acosh(it) }
override fun atanh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { atanh(it) }
public companion object : DoubleFieldOpsND()
}
@OptIn(UnstableKMathAPI::class)
public class DoubleFieldND(override val shape: Shape) :
DoubleFieldOpsND(), FieldND<Double, DoubleField>, NumbersAddOps<StructureND<Double>> {
override fun number(value: Number): BufferND<Double> { override fun number(value: Number): BufferND<Double> {
val d = value.toDouble() // minimize conversions val d = value.toDouble() // minimize conversions
return produce { d } return produce(shape) { d }
} }
override val StructureND<Double>.buffer: DoubleBuffer
get() = when {
!shape.contentEquals(this@DoubleFieldND.shape) -> throw ShapeMismatchException(
this@DoubleFieldND.shape,
shape
)
this is BufferND && this.strides == this@DoubleFieldND.strides -> this.buffer as DoubleBuffer
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun StructureND<Double>.map(
transform: DoubleField.(Double) -> Double,
): BufferND<Double> {
val buffer = DoubleBuffer(strides.linearSize) { offset -> DoubleField.transform(buffer.array[offset]) }
return BufferND(strides, buffer)
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
val array = DoubleArray(strides.linearSize) { offset ->
val index = strides.index(offset)
DoubleField.initializer(index)
}
return BufferND(strides, DoubleBuffer(array))
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun StructureND<Double>.mapIndexed(
transform: DoubleField.(index: IntArray, Double) -> Double,
): BufferND<Double> = BufferND(
strides,
buffer = DoubleBuffer(strides.linearSize) { offset ->
DoubleField.transform(
strides.index(offset),
buffer.array[offset]
)
})
@Suppress("OVERRIDE_BY_INLINE")
override inline fun combine(
a: StructureND<Double>,
b: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double,
): BufferND<Double> {
val buffer = DoubleBuffer(strides.linearSize) { offset ->
DoubleField.transform(a.buffer.array[offset], b.buffer.array[offset])
}
return BufferND(strides, buffer)
}
override fun scale(a: StructureND<Double>, value: Double): StructureND<Double> = a.map { it * value }
override fun power(arg: StructureND<Double>, pow: Number): BufferND<Double> = arg.map { power(it, pow) }
override fun exp(arg: StructureND<Double>): BufferND<Double> = arg.map { exp(it) }
override fun ln(arg: StructureND<Double>): BufferND<Double> = arg.map { ln(it) }
override fun sin(arg: StructureND<Double>): BufferND<Double> = arg.map { sin(it) }
override fun cos(arg: StructureND<Double>): BufferND<Double> = arg.map { cos(it) }
override fun tan(arg: StructureND<Double>): BufferND<Double> = arg.map { tan(it) }
override fun asin(arg: StructureND<Double>): BufferND<Double> = arg.map { asin(it) }
override fun acos(arg: StructureND<Double>): BufferND<Double> = arg.map { acos(it) }
override fun atan(arg: StructureND<Double>): BufferND<Double> = arg.map { atan(it) }
override fun sinh(arg: StructureND<Double>): BufferND<Double> = arg.map { sinh(it) }
override fun cosh(arg: StructureND<Double>): BufferND<Double> = arg.map { cosh(it) }
override fun tanh(arg: StructureND<Double>): BufferND<Double> = arg.map { tanh(it) }
override fun asinh(arg: StructureND<Double>): BufferND<Double> = arg.map { asinh(it) }
override fun acosh(arg: StructureND<Double>): BufferND<Double> = arg.map { acosh(it) }
override fun atanh(arg: StructureND<Double>): BufferND<Double> = arg.map { atanh(it) }
} }
public val DoubleField.ndAlgebra: DoubleFieldOpsND get() = DoubleFieldOpsND
public fun DoubleField.ndAlgebra(vararg shape: Int): DoubleFieldND = DoubleFieldND(shape) public fun DoubleField.ndAlgebra(vararg shape: Int): DoubleFieldND = DoubleFieldND(shape)
/** /**
* Produce a context for n-dimensional operations inside this real field * Produce a context for n-dimensional operations inside this real field
*/ */
@UnstableKMathAPI
public inline fun <R> DoubleField.withNdAlgebra(vararg shape: Int, action: DoubleFieldND.() -> R): R { public inline fun <R> DoubleField.withNdAlgebra(vararg shape: Int, action: DoubleFieldND.() -> R): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return DoubleFieldND(shape).run(action) return DoubleFieldND(shape).run(action)

View File

@ -0,0 +1,120 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.nd
import kotlin.native.concurrent.ThreadLocal
/**
* A converter from linear index to multivariate index
*/
public interface ShapeIndex{
public val shape: Shape
/**
* Get linear index from multidimensional index
*/
public fun offset(index: IntArray): Int
/**
* Get multidimensional from linear
*/
public fun index(offset: Int): IntArray
/**
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
*/
public val linearSize: Int
// TODO introduce a fast way to calculate index of the next element?
/**
* Iterate over ND indices in a natural order
*/
public fun indices(): Sequence<IntArray>
override fun equals(other: Any?): Boolean
override fun hashCode(): Int
}
/**
* Linear transformation of indexes
*/
public abstract class Strides: ShapeIndex {
/**
* Array strides
*/
public abstract val strides: IntArray
public override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
value * strides[i]
}.sum()
// TODO introduce a fast way to calculate index of the next element?
/**
* Iterate over ND indices in a natural order
*/
public override fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map(::index)
}
/**
* Simple implementation of [Strides].
*/
public class DefaultStrides private constructor(override val shape: IntArray) : Strides() {
override val linearSize: Int get() = strides[shape.size]
/**
* Strides for memory access
*/
override val strides: IntArray by lazy {
sequence {
var current = 1
yield(1)
shape.forEach {
current *= it
yield(current)
}
}.toList().toIntArray()
}
override fun index(offset: Int): IntArray {
val res = IntArray(shape.size)
var current = offset
var strideIndex = strides.size - 2
while (strideIndex >= 0) {
res[strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex]
strideIndex--
}
return res
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is DefaultStrides) return false
if (!shape.contentEquals(other.shape)) return false
return true
}
override fun hashCode(): Int = shape.contentHashCode()
@ThreadLocal
public companion object {
//private val defaultStridesCache = HashMap<IntArray, Strides>()
/**
* Cached builder for default strides
*/
public operator fun invoke(shape: IntArray): Strides = DefaultStrides(shape)
//defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
//TODO fix cache
}
}

View File

@ -6,34 +6,27 @@
package space.kscience.kmath.nd package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.NumbersAddOperations import space.kscience.kmath.operations.NumbersAddOps
import space.kscience.kmath.operations.ShortRing import space.kscience.kmath.operations.ShortRing
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.operations.bufferAlgebra
import space.kscience.kmath.structures.ShortBuffer
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
public sealed class ShortRingOpsND : BufferedRingOpsND<Short, ShortRing>(ShortRing.bufferAlgebra) {
public companion object : ShortRingOpsND()
}
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public class ShortRingND( public class ShortRingND(
shape: IntArray, override val shape: Shape
) : BufferedRingND<Short, ShortRing>(shape, ShortRing, Buffer.Companion::auto), ) : ShortRingOpsND(), RingND<Short, ShortRing>, NumbersAddOps<StructureND<Short>> {
NumbersAddOperations<StructureND<Short>> {
override val zero: BufferND<Short> by lazy { produce { zero } }
override val one: BufferND<Short> by lazy { produce { one } }
override fun number(value: Number): BufferND<Short> { override fun number(value: Number): BufferND<Short> {
val d = value.toShort() // minimize conversions val d = value.toShort() // minimize conversions
return produce { d } return produce(shape) { d }
} }
} }
/**
* Fast element production using function inlining.
*/
public inline fun BufferedRingND<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): BufferND<Short> =
BufferND(strides, ShortBuffer(ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }))
public inline fun <R> ShortRing.withNdAlgebra(vararg shape: Int, action: ShortRingND.() -> R): R { public inline fun <R> ShortRing.withNdAlgebra(vararg shape: Int, action: ShortRingND.() -> R): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return ShortRingND(shape).run(action) return ShortRingND(shape).run(action)

View File

@ -15,7 +15,6 @@ import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.BufferFactory
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
import kotlin.math.abs import kotlin.math.abs
import kotlin.native.concurrent.ThreadLocal
import kotlin.reflect.KClass import kotlin.reflect.KClass
public interface StructureFeature : Feature<StructureFeature> public interface StructureFeature : Feature<StructureFeature>
@ -72,7 +71,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
if (st1 === st2) return true if (st1 === st2) return true
// fast comparison of buffers if possible // fast comparison of buffers if possible
if (st1 is BufferND && st2 is BufferND && st1.strides == st2.strides) if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes)
return Buffer.contentEquals(st1.buffer, st2.buffer) return Buffer.contentEquals(st1.buffer, st2.buffer)
//element by element comparison if it could not be avoided //element by element comparison if it could not be avoided
@ -88,7 +87,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
if (st1 === st2) return true if (st1 === st2) return true
// fast comparison of buffers if possible // fast comparison of buffers if possible
if (st1 is BufferND && st2 is BufferND && st1.strides == st2.strides) if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes)
return Buffer.contentEquals(st1.buffer, st2.buffer) return Buffer.contentEquals(st1.buffer, st2.buffer)
//element by element comparison if it could not be avoided //element by element comparison if it could not be avoided
@ -187,11 +186,11 @@ public fun <T : Comparable<T>> LinearSpace<T, Ring<T>>.contentEquals(
* Indicates whether some [StructureND] is equal to another one with [absoluteTolerance]. * Indicates whether some [StructureND] is equal to another one with [absoluteTolerance].
*/ */
@PerformancePitfall @PerformancePitfall
public fun <T : Comparable<T>> GroupND<T, Ring<T>>.contentEquals( public fun <T : Comparable<T>> GroupOpsND<T, Ring<T>>.contentEquals(
st1: StructureND<T>, st1: StructureND<T>,
st2: StructureND<T>, st2: StructureND<T>,
absoluteTolerance: T, absoluteTolerance: T,
): Boolean = st1.elements().all { (index, value) -> elementContext { (value - st2[index]) } < absoluteTolerance } ): Boolean = st1.elements().all { (index, value) -> elementAlgebra { (value - st2[index]) } < absoluteTolerance }
/** /**
* Indicates whether some [StructureND] is equal to another one with [absoluteTolerance]. * Indicates whether some [StructureND] is equal to another one with [absoluteTolerance].
@ -231,107 +230,10 @@ public interface MutableStructureND<T> : StructureND<T> {
* Transform a structure element-by element in place. * Transform a structure element-by element in place.
*/ */
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
public inline fun <T> MutableStructureND<T>.mapInPlace(action: (IntArray, T) -> T): Unit = public inline fun <T> MutableStructureND<T>.mapInPlace(action: (index: IntArray, t: T) -> T): Unit =
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) } elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
/** public inline fun <reified T : Any> StructureND<T>.zip(
* A way to convert ND indices to linear one and back.
*/
public interface Strides {
/**
* Shape of NDStructure
*/
public val shape: IntArray
/**
* Array strides
*/
public val strides: IntArray
/**
* Get linear index from multidimensional index
*/
public fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
value * strides[i]
}.sum()
/**
* Get multidimensional from linear
*/
public fun index(offset: Int): IntArray
/**
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
*/
public val linearSize: Int
// TODO introduce a fast way to calculate index of the next element?
/**
* Iterate over ND indices in a natural order
*/
public fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map(::index)
}
/**
* Simple implementation of [Strides].
*/
public class DefaultStrides private constructor(override val shape: IntArray) : Strides {
override val linearSize: Int
get() = strides[shape.size]
/**
* Strides for memory access
*/
override val strides: IntArray by lazy {
sequence {
var current = 1
yield(1)
shape.forEach {
current *= it
yield(current)
}
}.toList().toIntArray()
}
override fun index(offset: Int): IntArray {
val res = IntArray(shape.size)
var current = offset
var strideIndex = strides.size - 2
while (strideIndex >= 0) {
res[strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex]
strideIndex--
}
return res
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is DefaultStrides) return false
if (!shape.contentEquals(other.shape)) return false
return true
}
override fun hashCode(): Int = shape.contentHashCode()
@ThreadLocal
public companion object {
private val defaultStridesCache = HashMap<IntArray, Strides>()
/**
* Cached builder for default strides
*/
public operator fun invoke(shape: IntArray): Strides =
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
}
}
public inline fun <reified T : Any> StructureND<T>.combine(
struct: StructureND<T>, struct: StructureND<T>,
crossinline block: (T, T) -> T, crossinline block: (T, T) -> T,
): StructureND<T> { ): StructureND<T> {

View File

@ -0,0 +1,34 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.nd
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.Group
import space.kscience.kmath.operations.Ring
import kotlin.jvm.JvmName
public fun <T, A : Algebra<T>> AlgebraND<T, A>.produce(
shapeFirst: Int,
vararg shapeRest: Int,
initializer: A.(IntArray) -> T
): StructureND<T> = produce(Shape(shapeFirst, *shapeRest), initializer)
public fun <T, A : Group<T>> AlgebraND<T, A>.zero(shape: Shape): StructureND<T> = produce(shape) { zero }
@JvmName("zeroVarArg")
public fun <T, A : Group<T>> AlgebraND<T, A>.zero(
shapeFirst: Int,
vararg shapeRest: Int,
): StructureND<T> = produce(shapeFirst, *shapeRest) { zero }
public fun <T, A : Ring<T>> AlgebraND<T, A>.one(shape: Shape): StructureND<T> = produce(shape) { one }
@JvmName("oneVarArg")
public fun <T, A : Ring<T>> AlgebraND<T, A>.one(
shapeFirst: Int,
vararg shapeRest: Int,
): StructureND<T> = produce(shapeFirst, *shapeRest) { one }

View File

@ -117,7 +117,7 @@ public inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = r
* *
* @param T the type of element of this semispace. * @param T the type of element of this semispace.
*/ */
public interface GroupOperations<T> : Algebra<T> { public interface GroupOps<T> : Algebra<T> {
/** /**
* Addition of two elements. * Addition of two elements.
* *
@ -162,7 +162,7 @@ public interface GroupOperations<T> : Algebra<T> {
* @return the difference. * @return the difference.
*/ */
public operator fun T.minus(b: T): T = add(this, -b) public operator fun T.minus(b: T): T = add(this, -b)
// Dynamic dispatch of operations
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) { override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
PLUS_OPERATION -> { arg -> +arg } PLUS_OPERATION -> { arg -> +arg }
MINUS_OPERATION -> { arg -> -arg } MINUS_OPERATION -> { arg -> -arg }
@ -193,7 +193,7 @@ public interface GroupOperations<T> : Algebra<T> {
* *
* @param T the type of element of this semispace. * @param T the type of element of this semispace.
*/ */
public interface Group<T> : GroupOperations<T> { public interface Group<T> : GroupOps<T> {
/** /**
* The neutral element of addition. * The neutral element of addition.
*/ */
@ -206,7 +206,7 @@ public interface Group<T> : GroupOperations<T> {
* *
* @param T the type of element of this semiring. * @param T the type of element of this semiring.
*/ */
public interface RingOperations<T> : GroupOperations<T> { public interface RingOps<T> : GroupOps<T> {
/** /**
* Multiplies two elements. * Multiplies two elements.
* *
@ -242,7 +242,7 @@ public interface RingOperations<T> : GroupOperations<T> {
* *
* @param T the type of element of this ring. * @param T the type of element of this ring.
*/ */
public interface Ring<T> : Group<T>, RingOperations<T> { public interface Ring<T> : Group<T>, RingOps<T> {
/** /**
* The neutral element of multiplication * The neutral element of multiplication
*/ */
@ -256,7 +256,7 @@ public interface Ring<T> : Group<T>, RingOperations<T> {
* *
* @param T the type of element of this semifield. * @param T the type of element of this semifield.
*/ */
public interface FieldOperations<T> : RingOperations<T> { public interface FieldOps<T> : RingOps<T> {
/** /**
* Division of two elements. * Division of two elements.
* *
@ -295,6 +295,6 @@ public interface FieldOperations<T> : RingOperations<T> {
* *
* @param T the type of element of this field. * @param T the type of element of this field.
*/ */
public interface Field<T> : Ring<T>, FieldOperations<T>, ScaleOperations<T>, NumericAlgebra<T> { public interface Field<T> : Ring<T>, FieldOps<T>, ScaleOperations<T>, NumericAlgebra<T> {
override fun number(value: Number): T = scale(one, value.toDouble()) override fun number(value: Number): T = scale(one, value.toDouble())
} }

View File

@ -6,7 +6,7 @@
package space.kscience.kmath.operations package space.kscience.kmath.operations
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.BufferedRingND import space.kscience.kmath.nd.BufferedRingOpsND
import space.kscience.kmath.operations.BigInt.Companion.BASE import space.kscience.kmath.operations.BigInt.Companion.BASE
import space.kscience.kmath.operations.BigInt.Companion.BASE_SIZE import space.kscience.kmath.operations.BigInt.Companion.BASE_SIZE
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
@ -26,7 +26,7 @@ private typealias TBase = ULong
* @author Peter Klimai * @author Peter Klimai
*/ */
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public object BigIntField : Field<BigInt>, NumbersAddOperations<BigInt>, ScaleOperations<BigInt> { public object BigIntField : Field<BigInt>, NumbersAddOps<BigInt>, ScaleOperations<BigInt> {
override val zero: BigInt = BigInt.ZERO override val zero: BigInt = BigInt.ZERO
override val one: BigInt = BigInt.ONE override val one: BigInt = BigInt.ONE
@ -542,5 +542,5 @@ public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -
public inline fun BigInt.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> = public inline fun BigInt.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
Buffer.boxing(size, initializer) Buffer.boxing(size, initializer)
public fun BigIntField.nd(vararg shape: Int): BufferedRingND<BigInt, BigIntField> = public val BigIntField.nd: BufferedRingOpsND<BigInt, BigIntField>
BufferedRingND(shape, BigIntField, BigInt::buffer) get() = BufferedRingOpsND(BufferRingOps(BigIntField, BigInt::buffer))

View File

@ -5,32 +5,34 @@
package space.kscience.kmath.operations package space.kscience.kmath.operations
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.ShortBuffer
public interface WithSize {
public val size: Int
}
/** /**
* An algebra over [Buffer] * An algebra over [Buffer]
*/ */
@UnstableKMathAPI public interface BufferAlgebra<T, out A : Algebra<T>> : Algebra<Buffer<T>> {
public interface BufferAlgebra<T, A : Algebra<T>> : Algebra<Buffer<T>> {
public val bufferFactory: BufferFactory<T>
public val elementAlgebra: A public val elementAlgebra: A
public val size: Int public val bufferFactory: BufferFactory<T>
public fun buffer(vararg elements: T): Buffer<T> { public fun buffer(size: Int, vararg elements: T): Buffer<T> {
require(elements.size == size) { "Expected $size elements but found ${elements.size}" } require(elements.size == size) { "Expected $size elements but found ${elements.size}" }
return bufferFactory(size) { elements[it] } return bufferFactory(size) { elements[it] }
} }
//TODO move to multi-receiver inline extension //TODO move to multi-receiver inline extension
public fun Buffer<T>.map(block: (T) -> T): Buffer<T> = bufferFactory(size) { block(get(it)) } public fun Buffer<T>.map(block: A.(T) -> T): Buffer<T> = mapInline(this, block)
public fun Buffer<T>.zip(other: Buffer<T>, block: (left: T, right: T) -> T): Buffer<T> { public fun Buffer<T>.mapIndexed(block: A.(index: Int, arg: T) -> T): Buffer<T> = mapIndexedInline(this, block)
require(size == other.size) { "Incompatible buffer sizes. left: $size, right: ${other.size}" }
return bufferFactory(size) { block(this[it], other[it]) } public fun Buffer<T>.zip(other: Buffer<T>, block: A.(left: T, right: T) -> T): Buffer<T> =
} zipInline(this, other, block)
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> { override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> {
val operationFunction = elementAlgebra.unaryOperationFunction(operation) val operationFunction = elementAlgebra.unaryOperationFunction(operation)
@ -45,112 +47,149 @@ public interface BufferAlgebra<T, A : Algebra<T>> : Algebra<Buffer<T>> {
} }
} }
@UnstableKMathAPI /**
public fun <T> BufferField<T, *>.buffer(initializer: (Int) -> T): Buffer<T> { * Inline map
*/
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapInline(
buffer: Buffer<T>,
crossinline block: A.(T) -> T
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(buffer[it]) }
/**
* Inline map
*/
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapIndexedInline(
buffer: Buffer<T>,
crossinline block: A.(index: Int, arg: T) -> T
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(it, buffer[it]) }
/**
* Inline zip
*/
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.zipInline(
l: Buffer<T>,
r: Buffer<T>,
crossinline block: A.(l: T, r: T) -> T
): Buffer<T> {
require(l.size == r.size) { "Incompatible buffer sizes. left: ${l.size}, right: ${r.size}" }
return bufferFactory(l.size) { elementAlgebra.block(l[it], r[it]) }
}
public fun <T> BufferAlgebra<T, *>.buffer(size: Int, initializer: (Int) -> T): Buffer<T> {
return bufferFactory(size, initializer)
}
public fun <T, A> A.buffer(initializer: (Int) -> T): Buffer<T> where A : BufferAlgebra<T, *>, A : WithSize {
return bufferFactory(size, initializer) return bufferFactory(size, initializer)
} }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.sin(arg: Buffer<T>): Buffer<T> = public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.sin(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::sin) mapInline(arg) { sin(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.cos(arg: Buffer<T>): Buffer<T> = public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.cos(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::cos) mapInline(arg) { cos(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.tan(arg: Buffer<T>): Buffer<T> = public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.tan(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::tan) mapInline(arg) { tan(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.asin(arg: Buffer<T>): Buffer<T> = public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.asin(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::asin) mapInline(arg) { asin(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.acos(arg: Buffer<T>): Buffer<T> = public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.acos(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::acos) mapInline(arg) { acos(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.atan(arg: Buffer<T>): Buffer<T> = public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.atan(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::atan) mapInline(arg) { atan(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.exp(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.exp(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::exp) mapInline(arg) { exp(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.ln(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.ln(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::ln) mapInline(arg) { ln(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.sinh(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.sinh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::sinh) mapInline(arg) { sinh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.cosh(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.cosh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::cosh) mapInline(arg) { cosh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.tanh(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.tanh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::tanh) mapInline(arg) { tanh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.asinh(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.asinh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::asinh) mapInline(arg) { asinh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.acosh(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.acosh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::acosh) mapInline(arg) { acosh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.atanh(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.atanh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::atanh) mapInline(arg) { atanh(it) }
@UnstableKMathAPI
public fun <T, A : PowerOperations<T>> BufferAlgebra<T, A>.pow(arg: Buffer<T>, pow: Number): Buffer<T> = public fun <T, A : PowerOperations<T>> BufferAlgebra<T, A>.pow(arg: Buffer<T>, pow: Number): Buffer<T> =
with(elementAlgebra) { arg.map { power(it, pow) } } mapInline(arg) { power(it, pow) }
@UnstableKMathAPI public open class BufferRingOps<T, A: Ring<T>>(
public class BufferField<T, A : Field<T>>(
override val bufferFactory: BufferFactory<T>,
override val elementAlgebra: A, override val elementAlgebra: A,
override val bufferFactory: BufferFactory<T>,
) : BufferAlgebra<T, A>, RingOps<Buffer<T>>{
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l + r }
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l * r }
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> =
super<BufferAlgebra>.unaryOperationFunction(operation)
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> =
super<BufferAlgebra>.binaryOperationFunction(operation)
}
public val ShortRing.bufferAlgebra: BufferRingOps<Short, ShortRing>
get() = BufferRingOps(ShortRing, ::ShortBuffer)
public open class BufferFieldOps<T, A : Field<T>>(
elementAlgebra: A,
bufferFactory: BufferFactory<T>,
) : BufferRingOps<T, A>(elementAlgebra, bufferFactory), BufferAlgebra<T, A>, FieldOps<Buffer<T>>, ScaleOperations<Buffer<T>> {
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l + r }
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l * r }
override fun divide(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l / r }
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = a.map { scale(it, value) }
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> =
super<BufferRingOps>.binaryOperationFunction(operation)
}
public class BufferField<T, A : Field<T>>(
elementAlgebra: A,
bufferFactory: BufferFactory<T>,
override val size: Int override val size: Int
) : BufferAlgebra<T, A>, Field<Buffer<T>> { ) : BufferFieldOps<T, A>(elementAlgebra, bufferFactory), Field<Buffer<T>>, WithSize {
override val zero: Buffer<T> = bufferFactory(size) { elementAlgebra.zero } override val zero: Buffer<T> = bufferFactory(size) { elementAlgebra.zero }
override val one: Buffer<T> = bufferFactory(size) { elementAlgebra.one } override val one: Buffer<T> = bufferFactory(size) { elementAlgebra.one }
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::add)
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::multiply)
override fun divide(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::divide)
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = with(elementAlgebra) { a.map { scale(it, value) } }
override fun Buffer<T>.unaryMinus(): Buffer<T> = with(elementAlgebra) { map { -it } }
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> {
return super<BufferAlgebra>.unaryOperationFunction(operation)
}
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> {
return super<BufferAlgebra>.binaryOperationFunction(operation)
}
} }
/**
* Generate full buffer field from given buffer operations
*/
public fun <T, A : Field<T>> BufferFieldOps<T, A>.withSize(size: Int): BufferField<T, A> =
BufferField(elementAlgebra, bufferFactory, size)
//Double buffer specialization //Double buffer specialization
@UnstableKMathAPI
public fun BufferField<Double, *>.buffer(vararg elements: Number): Buffer<Double> { public fun BufferField<Double, *>.buffer(vararg elements: Number): Buffer<Double> {
require(elements.size == size) { "Expected $size elements but found ${elements.size}" } require(elements.size == size) { "Expected $size elements but found ${elements.size}" }
return bufferFactory(size) { elements[it].toDouble() } return bufferFactory(size) { elements[it].toDouble() }
} }
@UnstableKMathAPI public fun <T, A : Field<T>> A.bufferAlgebra(bufferFactory: BufferFactory<T>): BufferFieldOps<T, A> =
public fun <T, A : Field<T>> A.bufferAlgebra(bufferFactory: BufferFactory<T>, size: Int): BufferField<T, A> = BufferFieldOps(this, bufferFactory)
BufferField(bufferFactory, this, size)
@UnstableKMathAPI public val DoubleField.bufferAlgebra: BufferFieldOps<Double, DoubleField>
public fun DoubleField.bufferAlgebra(size: Int): BufferField<Double, DoubleField> = get() = BufferFieldOps(DoubleField, ::DoubleBuffer)
BufferField(::DoubleBuffer, DoubleField, size)

View File

@ -13,21 +13,21 @@ import space.kscience.kmath.structures.DoubleBuffer
* *
* @property size the size of buffers to operate on. * @property size the size of buffers to operate on.
*/ */
public class DoubleBufferField(public val size: Int) : ExtendedField<Buffer<Double>>, DoubleBufferOperations() { public class DoubleBufferField(public val size: Int) : ExtendedField<Buffer<Double>>, DoubleBufferOps() {
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } } override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }
override val one: Buffer<Double> by lazy { DoubleBuffer(size) { 1.0 } } override val one: Buffer<Double> by lazy { DoubleBuffer(size) { 1.0 } }
override fun sinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.sinh(arg) override fun sinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.sinh(arg)
override fun cosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.cosh(arg) override fun cosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.cosh(arg)
override fun tanh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.tanh(arg) override fun tanh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.tanh(arg)
override fun asinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.asinh(arg) override fun asinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.asinh(arg)
override fun acosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.acosh(arg) override fun acosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.acosh(arg)
override fun atanh(arg: Buffer<Double>): DoubleBuffer= super<DoubleBufferOperations>.atanh(arg) override fun atanh(arg: Buffer<Double>): DoubleBuffer= super<DoubleBufferOps>.atanh(arg)
// override fun number(value: Number): Buffer<Double> = DoubleBuffer(size) { value.toDouble() } // override fun number(value: Number): Buffer<Double> = DoubleBuffer(size) { value.toDouble() }
// //

View File

@ -12,9 +12,9 @@ import space.kscience.kmath.structures.DoubleBuffer
import kotlin.math.* import kotlin.math.*
/** /**
* [ExtendedFieldOperations] over [DoubleBuffer]. * [ExtendedFieldOps] over [DoubleBuffer].
*/ */
public abstract class DoubleBufferOperations : ExtendedFieldOperations<Buffer<Double>>, Norm<Buffer<Double>, Double> { public abstract class DoubleBufferOps : ExtendedFieldOps<Buffer<Double>>, Norm<Buffer<Double>, Double> {
override fun Buffer<Double>.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) { override fun Buffer<Double>.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) {
DoubleBuffer(size) { -array[it] } DoubleBuffer(size) { -array[it] }
} else { } else {
@ -185,7 +185,7 @@ public abstract class DoubleBufferOperations : ExtendedFieldOperations<Buffer<Do
DoubleBuffer(DoubleArray(a.size) { aArray[it] * value }) DoubleBuffer(DoubleArray(a.size) { aArray[it] * value })
} else DoubleBuffer(DoubleArray(a.size) { a[it] * value }) } else DoubleBuffer(DoubleArray(a.size) { a[it] * value })
public companion object : DoubleBufferOperations() public companion object : DoubleBufferOps()
} }
public object DoubleL2Norm : Norm<Point<Double>, Double> { public object DoubleL2Norm : Norm<Point<Double>, Double> {

View File

@ -150,7 +150,7 @@ public interface ScaleOperations<T> : Algebra<T> {
* TODO to be removed and replaced by extensions after multiple receivers are there * TODO to be removed and replaced by extensions after multiple receivers are there
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public interface NumbersAddOperations<T> : Ring<T>, NumericAlgebra<T> { public interface NumbersAddOps<T> : Ring<T>, NumericAlgebra<T> {
/** /**
* Addition of element and scalar. * Addition of element and scalar.
* *

View File

@ -10,8 +10,8 @@ import kotlin.math.pow as kpow
/** /**
* Advanced Number-like semifield that implements basic operations. * Advanced Number-like semifield that implements basic operations.
*/ */
public interface ExtendedFieldOperations<T> : public interface ExtendedFieldOps<T> :
FieldOperations<T>, FieldOps<T>,
TrigonometricOperations<T>, TrigonometricOperations<T>,
PowerOperations<T>, PowerOperations<T>,
ExponentialOperations<T>, ExponentialOperations<T>,
@ -35,14 +35,14 @@ public interface ExtendedFieldOperations<T> :
ExponentialOperations.ACOSH_OPERATION -> ::acosh ExponentialOperations.ACOSH_OPERATION -> ::acosh
ExponentialOperations.ASINH_OPERATION -> ::asinh ExponentialOperations.ASINH_OPERATION -> ::asinh
ExponentialOperations.ATANH_OPERATION -> ::atanh ExponentialOperations.ATANH_OPERATION -> ::atanh
else -> super<FieldOperations>.unaryOperationFunction(operation) else -> super<FieldOps>.unaryOperationFunction(operation)
} }
} }
/** /**
* Advanced Number-like field that implements basic operations. * Advanced Number-like field that implements basic operations.
*/ */
public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>, NumericAlgebra<T>{ public interface ExtendedField<T> : ExtendedFieldOps<T>, Field<T>, NumericAlgebra<T>{
override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0 override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0
override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0 override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0
override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))

View File

@ -7,6 +7,7 @@ package space.kscience.kmath.structures
import space.kscience.kmath.nd.get import space.kscience.kmath.nd.get
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.produce
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.testutils.FieldVerifier import space.kscience.kmath.testutils.FieldVerifier
@ -21,7 +22,7 @@ internal class NDFieldTest {
@Test @Test
fun testStrides() { fun testStrides() {
val ndArray = DoubleField.ndAlgebra(10, 10).produce { (it[0] + it[1]).toDouble() } val ndArray = DoubleField.ndAlgebra.produce(10, 10) { (it[0] + it[1]).toDouble() }
assertEquals(ndArray[5, 5], 10.0) assertEquals(ndArray[5, 5], 10.0)
} }
} }

View File

@ -7,10 +7,7 @@ package space.kscience.kmath.structures
import space.kscience.kmath.linear.linearSpace import space.kscience.kmath.linear.linearSpace
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.*
import space.kscience.kmath.nd.combine
import space.kscience.kmath.nd.get
import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Norm import space.kscience.kmath.operations.Norm
import space.kscience.kmath.operations.algebra import space.kscience.kmath.operations.algebra
@ -22,9 +19,9 @@ import kotlin.test.assertEquals
@Suppress("UNUSED_VARIABLE") @Suppress("UNUSED_VARIABLE")
class NumberNDFieldTest { class NumberNDFieldTest {
val algebra = DoubleField.ndAlgebra(3, 3) val algebra = DoubleField.ndAlgebra
val array1 = algebra.produce { (i, j) -> (i + j).toDouble() } val array1 = algebra.produce(3, 3) { (i, j) -> (i + j).toDouble() }
val array2 = algebra.produce { (i, j) -> (i - j).toDouble() } val array2 = algebra.produce(3, 3) { (i, j) -> (i - j).toDouble() }
@Test @Test
fun testSum() { fun testSum() {
@ -77,7 +74,7 @@ class NumberNDFieldTest {
@Test @Test
fun combineTest() { fun combineTest() {
val division = array1.combine(array2, Double::div) val division = array1.zip(array2, Double::div)
} }
object L2Norm : Norm<StructureND<Number>, Double> { object L2Norm : Norm<StructureND<Number>, Double> {

View File

@ -10,12 +10,12 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.runningReduce import kotlinx.coroutines.flow.runningReduce
import kotlinx.coroutines.flow.scan import kotlinx.coroutines.flow.scan
import space.kscience.kmath.operations.GroupOperations import space.kscience.kmath.operations.GroupOps
import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.ScaleOperations import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
public fun <T> Flow<T>.cumulativeSum(group: GroupOperations<T>): Flow<T> = public fun <T> Flow<T>.cumulativeSum(group: GroupOps<T>): Flow<T> =
group { runningReduce { sum, element -> sum + element } } group { runningReduce { sum, element -> sum + element } }
@ExperimentalCoroutinesApi @ExperimentalCoroutinesApi

View File

@ -13,8 +13,8 @@ import space.kscience.kmath.structures.DoubleBuffer
* Map one [BufferND] using function without indices. * Map one [BufferND] using function without indices.
*/ */
public inline fun BufferND<Double>.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND<Double> { public inline fun BufferND<Double>.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND<Double> {
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.transform(buffer[offset]) } val array = DoubleArray(indexes.linearSize) { offset -> DoubleField.transform(buffer[offset]) }
return BufferND(strides, DoubleBuffer(array)) return BufferND(indexes, DoubleBuffer(array))
} }
/** /**

View File

@ -28,10 +28,9 @@ public class DoubleHistogramSpace(
public val dimension: Int get() = lower.size public val dimension: Int get() = lower.size
private val shape = IntArray(binNums.size) { binNums[it] + 2 } override val shape: IntArray = IntArray(binNums.size) { binNums[it] + 2 }
override val histogramValueSpace: DoubleFieldND = DoubleField.ndAlgebra(*shape) override val histogramValueSpace: DoubleFieldND = DoubleField.ndAlgebra(*shape)
override val strides: Strides get() = histogramValueSpace.strides
private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
/** /**
@ -52,7 +51,7 @@ public class DoubleHistogramSpace(
val lowerBoundary = index.mapIndexed { axis, i -> val lowerBoundary = index.mapIndexed { axis, i ->
when (i) { when (i) {
0 -> Double.NEGATIVE_INFINITY 0 -> Double.NEGATIVE_INFINITY
strides.shape[axis] - 1 -> upper[axis] shape[axis] - 1 -> upper[axis]
else -> lower[axis] + (i.toDouble()) * binSize[axis] else -> lower[axis] + (i.toDouble()) * binSize[axis]
} }
}.asBuffer() }.asBuffer()
@ -60,7 +59,7 @@ public class DoubleHistogramSpace(
val upperBoundary = index.mapIndexed { axis, i -> val upperBoundary = index.mapIndexed { axis, i ->
when (i) { when (i) {
0 -> lower[axis] 0 -> lower[axis]
strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY shape[axis] - 1 -> Double.POSITIVE_INFINITY
else -> lower[axis] + (i.toDouble() + 1) * binSize[axis] else -> lower[axis] + (i.toDouble() + 1) * binSize[axis]
} }
}.asBuffer() }.asBuffer()
@ -75,7 +74,7 @@ public class DoubleHistogramSpace(
} }
override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> { override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> {
val ndCounter = StructureND.auto(strides) { Counter.real() } val ndCounter = StructureND.auto(shape) { Counter.real() }
val hBuilder = HistogramBuilder<Double> { point, value -> val hBuilder = HistogramBuilder<Double> { point, value ->
val index = getIndex(point) val index = getIndex(point)
ndCounter[index].add(value.toDouble()) ndCounter[index].add(value.toDouble())

View File

@ -8,8 +8,9 @@ package space.kscience.kmath.histogram
import space.kscience.kmath.domains.Domain import space.kscience.kmath.domains.Domain
import space.kscience.kmath.linear.Point import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.nd.FieldND import space.kscience.kmath.nd.FieldND
import space.kscience.kmath.nd.Strides import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Group import space.kscience.kmath.operations.Group
import space.kscience.kmath.operations.ScaleOperations import space.kscience.kmath.operations.ScaleOperations
@ -34,10 +35,10 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
return context.produceBin(index, values[index]) return context.produceBin(index, values[index])
} }
override val dimension: Int get() = context.strides.shape.size override val dimension: Int get() = context.shape.size
override val bins: Iterable<Bin<T>> override val bins: Iterable<Bin<T>>
get() = context.strides.indices().map { get() = DefaultStrides(context.shape).indices().map {
context.produceBin(it, values[it]) context.produceBin(it, values[it])
}.asIterable() }.asIterable()
@ -49,7 +50,7 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
public interface IndexedHistogramSpace<T : Comparable<T>, V : Any> public interface IndexedHistogramSpace<T : Comparable<T>, V : Any>
: Group<IndexedHistogram<T, V>>, ScaleOperations<IndexedHistogram<T, V>> { : Group<IndexedHistogram<T, V>>, ScaleOperations<IndexedHistogram<T, V>> {
//public val valueSpace: Space<V> //public val valueSpace: Space<V>
public val strides: Strides public val shape: Shape
public val histogramValueSpace: FieldND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape), public val histogramValueSpace: FieldND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape),
/** /**

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.histogram package space.kscience.kmath.histogram
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.real.DoubleVector import space.kscience.kmath.real.DoubleVector
import kotlin.random.Random import kotlin.random.Random
@ -69,7 +70,7 @@ internal class MultivariateHistogramTest {
} }
val res = histogram1 - histogram2 val res = histogram1 - histogram2
assertTrue { assertTrue {
strides.indices().all { index -> DefaultStrides(shape).indices().all { index ->
res.values[index] <= histogram1.values[index] res.values[index] <= histogram1.values[index]
} }
} }

View File

@ -106,8 +106,8 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
is Symbol -> toSVar() is Symbol -> toSVar()
is MST.Unary -> when (operation) { is MST.Unary -> when (operation) {
GroupOperations.PLUS_OPERATION -> +value.toSFun<X>() GroupOps.PLUS_OPERATION -> +value.toSFun<X>()
GroupOperations.MINUS_OPERATION -> -value.toSFun<X>() GroupOps.MINUS_OPERATION -> -value.toSFun<X>()
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun()) TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun()) TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun()) TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
@ -124,10 +124,10 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
} }
is MST.Binary -> when (operation) { is MST.Binary -> when (operation) {
GroupOperations.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun() GroupOps.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
GroupOperations.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun() GroupOps.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
RingOperations.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun() RingOps.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
FieldOperations.DIV_OPERATION -> left.toSFun<X>() / right.toSFun() FieldOps.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst() PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
else -> error("Binary operation $operation not defined in $this") else -> error("Binary operation $operation not defined in $this")
} }

View File

@ -15,13 +15,6 @@ import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.* import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
internal fun AlgebraND<*, *>.checkShape(array: INDArray): INDArray {
val arrayShape = array.shape().toIntArray()
if (!shape.contentEquals(arrayShape)) throw ShapeMismatchException(shape, arrayShape)
return array
}
/** /**
* Represents [AlgebraND] over [Nd4jArrayAlgebra]. * Represents [AlgebraND] over [Nd4jArrayAlgebra].
* *
@ -39,33 +32,34 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
*/ */
public val StructureND<T>.ndArray: INDArray public val StructureND<T>.ndArray: INDArray
override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> { override fun produce(shape: Shape, initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
val struct = Nd4j.create(*shape)!!.wrap() val struct = Nd4j.create(*shape)!!.wrap()
struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) } struct.indicesIterator().forEach { struct[it] = elementAlgebra.initializer(it) }
return struct return struct
} }
override fun StructureND<T>.map(transform: C.(T) -> T): Nd4jArrayStructure<T> { override fun StructureND<T>.map(transform: C.(T) -> T): Nd4jArrayStructure<T> {
val newStruct = ndArray.dup().wrap() val newStruct = ndArray.dup().wrap()
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) } newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementAlgebra.transform(value) }
return newStruct return newStruct
} }
override fun StructureND<T>.mapIndexed( override fun StructureND<T>.mapIndexed(
transform: C.(index: IntArray, T) -> T, transform: C.(index: IntArray, T) -> T,
): Nd4jArrayStructure<T> { ): Nd4jArrayStructure<T> {
val new = Nd4j.create(*this@Nd4jArrayAlgebra.shape).wrap() val new = Nd4j.create(*shape).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, this[idx]) } new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(idx, this[idx]) }
return new return new
} }
override fun combine( override fun zip(
a: StructureND<T>, left: StructureND<T>,
b: StructureND<T>, right: StructureND<T>,
transform: C.(T, T) -> T, transform: C.(T, T) -> T,
): Nd4jArrayStructure<T> { ): Nd4jArrayStructure<T> {
val new = Nd4j.create(*shape).wrap() require(left.shape.contentEquals(right.shape)) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) } val new = Nd4j.create(*left.shape).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) }
return new return new
} }
} }
@ -76,10 +70,7 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
* @param T the type of the element contained in ND structure. * @param T the type of the element contained in ND structure.
* @param S the type of space of structure elements. * @param S the type of space of structure elements.
*/ */
public sealed interface Nd4jArrayGroup<T, out S : Ring<T>> : GroupND<T, S>, Nd4jArrayAlgebra<T, S> { public sealed interface Nd4jArrayGroupOps<T, out S : Ring<T>> : GroupOpsND<T, S>, Nd4jArrayAlgebra<T, S> {
override val zero: Nd4jArrayStructure<T>
get() = Nd4j.zeros(*shape).wrap()
override fun add(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> = override fun add(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
a.ndArray.add(b.ndArray).wrap() a.ndArray.add(b.ndArray).wrap()
@ -101,10 +92,7 @@ public sealed interface Nd4jArrayGroup<T, out S : Ring<T>> : GroupND<T, S>, Nd4j
* @param R the type of ring of structure elements. * @param R the type of ring of structure elements.
*/ */
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jArrayGroup<T, R> { public sealed interface Nd4jArrayRingOps<T, out R : Ring<T>> : RingOpsND<T, R>, Nd4jArrayGroupOps<T, R> {
override val one: Nd4jArrayStructure<T>
get() = Nd4j.ones(*shape).wrap()
override fun multiply(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> = override fun multiply(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
a.ndArray.mul(b.ndArray).wrap() a.ndArray.mul(b.ndArray).wrap()
@ -125,21 +113,12 @@ public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jAr
// } // }
public companion object { public companion object {
private val intNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, IntNd4jArrayRing>> =
ThreadLocal.withInitial(::HashMap)
/**
* Creates an [RingND] for [Int] values or pull it from cache if it was created previously.
*/
public fun int(vararg shape: Int): Nd4jArrayRing<Int, IntRing> =
intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) }
/** /**
* Creates a most suitable implementation of [RingND] using reified class. * Creates a most suitable implementation of [RingND] using reified class.
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
public inline fun <reified T : Number> auto(vararg shape: Int): Nd4jArrayRing<T, Ring<T>> = when { public inline fun <reified T : Number> auto(vararg shape: Int): Nd4jArrayRingOps<T, Ring<T>> = when {
T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, Ring<T>> T::class == Int::class -> IntRing.nd4j as Nd4jArrayRingOps<T, Ring<T>>
else -> throw UnsupportedOperationException("This factory method only supports Long type.") else -> throw UnsupportedOperationException("This factory method only supports Long type.")
} }
} }
@ -151,38 +130,21 @@ public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jAr
* @param T the type of the element contained in ND structure. * @param T the type of the element contained in ND structure.
* @param F the type field of structure elements. * @param F the type field of structure elements.
*/ */
public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<T, F> { public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldOpsND<T, F>, Nd4jArrayRingOps<T, F> {
override fun divide(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> = override fun divide(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
a.ndArray.div(b.ndArray).wrap() a.ndArray.div(b.ndArray).wrap()
public operator fun Number.div(b: StructureND<T>): Nd4jArrayStructure<T> = b.ndArray.rdiv(this).wrap() public operator fun Number.div(b: StructureND<T>): Nd4jArrayStructure<T> = b.ndArray.rdiv(this).wrap()
public companion object { public companion object {
private val floatNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, FloatNd4jArrayField>> =
ThreadLocal.withInitial(::HashMap)
private val doubleNd4JArrayFieldCache: ThreadLocal<MutableMap<IntArray, DoubleNd4jArrayField>> =
ThreadLocal.withInitial(::HashMap)
/**
* Creates an [FieldND] for [Float] values or pull it from cache if it was created previously.
*/
public fun float(vararg shape: Int): Nd4jArrayRing<Float, FloatField> =
floatNd4jArrayFieldCache.get().getOrPut(shape) { FloatNd4jArrayField(shape) }
/**
* Creates an [FieldND] for [Double] values or pull it from cache if it was created previously.
*/
public fun real(vararg shape: Int): Nd4jArrayRing<Double, DoubleField> =
doubleNd4JArrayFieldCache.get().getOrPut(shape) { DoubleNd4jArrayField(shape) }
/** /**
* Creates a most suitable implementation of [FieldND] using reified class. * Creates a most suitable implementation of [FieldND] using reified class.
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, Field<T>> = when { public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, Field<T>> = when {
T::class == Float::class -> float(*shape) as Nd4jArrayField<T, Field<T>> T::class == Float::class -> FloatField.nd4j as Nd4jArrayField<T, Field<T>>
T::class == Double::class -> real(*shape) as Nd4jArrayField<T, Field<T>> T::class == Double::class -> DoubleField.nd4j as Nd4jArrayField<T, Field<T>>
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.") else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
} }
} }
@ -191,8 +153,9 @@ public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldND<T, F>, Nd4
/** /**
* Represents intersection of [ExtendedField] and [Field] over [Nd4jArrayStructure]. * Represents intersection of [ExtendedField] and [Field] over [Nd4jArrayStructure].
*/ */
public sealed interface Nd4jArrayExtendedField<T, out F : ExtendedField<T>> : ExtendedField<StructureND<T>>, public sealed interface Nd4jArrayExtendedFieldOps<T, out F : ExtendedField<T>> :
Nd4jArrayField<T, F> { ExtendedFieldOps<StructureND<T>>, Nd4jArrayField<T, F> {
override fun sin(arg: StructureND<T>): StructureND<T> = Transforms.sin(arg.ndArray).wrap() override fun sin(arg: StructureND<T>): StructureND<T> = Transforms.sin(arg.ndArray).wrap()
override fun cos(arg: StructureND<T>): StructureND<T> = Transforms.cos(arg.ndArray).wrap() override fun cos(arg: StructureND<T>): StructureND<T> = Transforms.cos(arg.ndArray).wrap()
override fun asin(arg: StructureND<T>): StructureND<T> = Transforms.asin(arg.ndArray).wrap() override fun asin(arg: StructureND<T>): StructureND<T> = Transforms.asin(arg.ndArray).wrap()
@ -221,63 +184,59 @@ public sealed interface Nd4jArrayExtendedField<T, out F : ExtendedField<T>> : Ex
/** /**
* Represents [FieldND] over [Nd4jArrayDoubleStructure]. * Represents [FieldND] over [Nd4jArrayDoubleStructure].
*/ */
public class DoubleNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtendedField<Double, DoubleField> { public open class DoubleNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps<Double, DoubleField> {
override val elementContext: DoubleField get() = DoubleField override val elementAlgebra: DoubleField get() = DoubleField
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asDoubleStructure() override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override val StructureND<Double>.ndArray: INDArray override val StructureND<Double>.ndArray: INDArray
get() = when (this) { get() = when (this) {
is Nd4jArrayStructure<Double> -> checkShape(ndArray) is Nd4jArrayStructure<Double> -> ndArray
else -> Nd4j.zeros(*shape).also { else -> Nd4j.zeros(*shape).also {
elements().forEach { (idx, value) -> it.putScalar(idx, value) } elements().forEach { (idx, value) -> it.putScalar(idx, value) }
} }
} }
override fun scale(a: StructureND<Double>, value: Double): Nd4jArrayStructure<Double> { override fun scale(a: StructureND<Double>, value: Double): Nd4jArrayStructure<Double> = a.ndArray.mul(value).wrap()
return a.ndArray.mul(value).wrap()
}
override operator fun StructureND<Double>.div(arg: Double): Nd4jArrayStructure<Double> { override operator fun StructureND<Double>.div(arg: Double): Nd4jArrayStructure<Double> = ndArray.div(arg).wrap()
return ndArray.div(arg).wrap()
}
override operator fun StructureND<Double>.plus(arg: Double): Nd4jArrayStructure<Double> { override operator fun StructureND<Double>.plus(arg: Double): Nd4jArrayStructure<Double> = ndArray.add(arg).wrap()
return ndArray.add(arg).wrap()
}
override operator fun StructureND<Double>.minus(arg: Double): Nd4jArrayStructure<Double> { override operator fun StructureND<Double>.minus(arg: Double): Nd4jArrayStructure<Double> = ndArray.sub(arg).wrap()
return ndArray.sub(arg).wrap()
}
override operator fun StructureND<Double>.times(arg: Double): Nd4jArrayStructure<Double> { override operator fun StructureND<Double>.times(arg: Double): Nd4jArrayStructure<Double> = ndArray.mul(arg).wrap()
return ndArray.mul(arg).wrap()
}
override operator fun Double.div(arg: StructureND<Double>): Nd4jArrayStructure<Double> { override operator fun Double.div(arg: StructureND<Double>): Nd4jArrayStructure<Double> =
return arg.ndArray.rdiv(this).wrap() arg.ndArray.rdiv(this).wrap()
}
override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> { override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> =
return arg.ndArray.rsub(this).wrap() arg.ndArray.rsub(this).wrap()
}
public companion object : DoubleNd4jArrayFieldOps()
} }
public fun DoubleField.nd4j(vararg shape: Int): DoubleNd4jArrayField = DoubleNd4jArrayField(intArrayOf(*shape)) public val DoubleField.nd4j: DoubleNd4jArrayFieldOps get() = DoubleNd4jArrayFieldOps
public class DoubleNd4jArrayField(override val shape: Shape) : DoubleNd4jArrayFieldOps(), FieldND<Double, DoubleField>
public fun DoubleField.nd4j(shapeFirst: Int, vararg shapeRest: Int): DoubleNd4jArrayField =
DoubleNd4jArrayField(intArrayOf(shapeFirst, * shapeRest))
/** /**
* Represents [FieldND] over [Nd4jArrayStructure] of [Float]. * Represents [FieldND] over [Nd4jArrayStructure] of [Float].
*/ */
public class FloatNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtendedField<Float, FloatField> { public open class FloatNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps<Float, FloatField> {
override val elementContext: FloatField get() = FloatField override val elementAlgebra: FloatField get() = FloatField
override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure() override fun INDArray.wrap(): Nd4jArrayStructure<Float> = asFloatStructure()
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override val StructureND<Float>.ndArray: INDArray override val StructureND<Float>.ndArray: INDArray
get() = when (this) { get() = when (this) {
is Nd4jArrayStructure<Float> -> checkShape(ndArray) is Nd4jArrayStructure<Float> -> ndArray
else -> Nd4j.zeros(*shape).also { else -> Nd4j.zeros(*shape).also {
elements().forEach { (idx, value) -> it.putScalar(idx, value) } elements().forEach { (idx, value) -> it.putScalar(idx, value) }
} }
@ -303,21 +262,29 @@ public class FloatNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtend
override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> = override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> =
arg.ndArray.rsub(this).wrap() arg.ndArray.rsub(this).wrap()
public companion object : FloatNd4jArrayFieldOps()
} }
public class FloatNd4jArrayField(override val shape: Shape) : FloatNd4jArrayFieldOps(), RingND<Float, FloatField>
public val FloatField.nd4j: FloatNd4jArrayFieldOps get() = FloatNd4jArrayFieldOps
public fun FloatField.nd4j(shapeFirst: Int, vararg shapeRest: Int): FloatNd4jArrayField =
FloatNd4jArrayField(intArrayOf(shapeFirst, * shapeRest))
/** /**
* Represents [RingND] over [Nd4jArrayIntStructure]. * Represents [RingND] over [Nd4jArrayIntStructure].
*/ */
public class IntNd4jArrayRing(override val shape: IntArray) : Nd4jArrayRing<Int, IntRing> { public open class IntNd4jArrayRingOps : Nd4jArrayRingOps<Int, IntRing> {
override val elementContext: IntRing override val elementAlgebra: IntRing get() = IntRing
get() = IntRing
override fun INDArray.wrap(): Nd4jArrayStructure<Int> = checkShape(this).asIntStructure() override fun INDArray.wrap(): Nd4jArrayStructure<Int> = asIntStructure()
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override val StructureND<Int>.ndArray: INDArray override val StructureND<Int>.ndArray: INDArray
get() = when (this) { get() = when (this) {
is Nd4jArrayStructure<Int> -> checkShape(ndArray) is Nd4jArrayStructure<Int> -> ndArray
else -> Nd4j.zeros(*shape).also { else -> Nd4j.zeros(*shape).also {
elements().forEach { (idx, value) -> it.putScalar(idx, value) } elements().forEach { (idx, value) -> it.putScalar(idx, value) }
} }
@ -334,4 +301,13 @@ public class IntNd4jArrayRing(override val shape: IntArray) : Nd4jArrayRing<Int,
override operator fun Int.minus(arg: StructureND<Int>): Nd4jArrayStructure<Int> = override operator fun Int.minus(arg: StructureND<Int>): Nd4jArrayStructure<Int> =
arg.ndArray.rsub(this).wrap() arg.ndArray.rsub(this).wrap()
public companion object : IntNd4jArrayRingOps()
} }
public val IntRing.nd4j: IntNd4jArrayRingOps get() = IntNd4jArrayRingOps
public class IntNd4jArrayRing(override val shape: Shape) : IntNd4jArrayRingOps(), RingND<Int, IntRing>
public fun IntRing.nd4j(shapeFirst: Int, vararg shapeRest: Int): IntNd4jArrayRing =
IntNd4jArrayRing(intArrayOf(shapeFirst, * shapeRest))

View File

@ -8,6 +8,10 @@ package space.kscience.kmath.nd4j
import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.Nd4j
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.one
import space.kscience.kmath.nd.produce
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import kotlin.math.PI import kotlin.math.PI
import kotlin.test.Test import kotlin.test.Test
@ -19,7 +23,7 @@ import kotlin.test.fail
internal class Nd4jArrayAlgebraTest { internal class Nd4jArrayAlgebraTest {
@Test @Test
fun testProduce() { fun testProduce() {
val res = with(DoubleNd4jArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } } val res = DoubleField.nd4j.produce(2, 2) { it.sum().toDouble() }
val expected = (Nd4j.create(2, 2) ?: fail()).asDoubleStructure() val expected = (Nd4j.create(2, 2) ?: fail()).asDoubleStructure()
expected[intArrayOf(0, 0)] = 0.0 expected[intArrayOf(0, 0)] = 0.0
expected[intArrayOf(0, 1)] = 1.0 expected[intArrayOf(0, 1)] = 1.0
@ -30,7 +34,9 @@ internal class Nd4jArrayAlgebraTest {
@Test @Test
fun testMap() { fun testMap() {
val res = with(IntNd4jArrayRing(intArrayOf(2, 2))) { one.map { it + it * 2 } } val res = IntRing.nd4j {
one(2, 2).map { it + it * 2 }
}
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure() val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
expected[intArrayOf(0, 0)] = 3 expected[intArrayOf(0, 0)] = 3
expected[intArrayOf(0, 1)] = 3 expected[intArrayOf(0, 1)] = 3
@ -41,7 +47,7 @@ internal class Nd4jArrayAlgebraTest {
@Test @Test
fun testAdd() { fun testAdd() {
val res = with(IntNd4jArrayRing(intArrayOf(2, 2))) { one + 25 } val res = IntRing.nd4j { one(2, 2) + 25 }
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure() val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
expected[intArrayOf(0, 0)] = 26 expected[intArrayOf(0, 0)] = 26
expected[intArrayOf(0, 1)] = 26 expected[intArrayOf(0, 1)] = 26
@ -51,10 +57,10 @@ internal class Nd4jArrayAlgebraTest {
} }
@Test @Test
fun testSin() = DoubleNd4jArrayField(intArrayOf(2, 2)).invoke { fun testSin() = DoubleField.nd4j{
val initial = produce { (i, j) -> if (i == j) PI / 2 else 0.0 } val initial = produce(2, 2) { (i, j) -> if (i == j) PI / 2 else 0.0 }
val transformed = sin(initial) val transformed = sin(initial)
val expected = produce { (i, j) -> if (i == j) 1.0 else 0.0 } val expected = produce(2, 2) { (i, j) -> if (i == j) 1.0 else 0.0 }
println(transformed) println(transformed)
assertTrue { StructureND.contentEquals(transformed, expected) } assertTrue { StructureND.contentEquals(transformed, expected) }

View File

@ -64,8 +64,8 @@ public fun MST.toIExpr(): IExpr = when (this) {
} }
is MST.Unary -> when (operation) { is MST.Unary -> when (operation) {
GroupOperations.PLUS_OPERATION -> value.toIExpr() GroupOps.PLUS_OPERATION -> value.toIExpr()
GroupOperations.MINUS_OPERATION -> F.Negate(value.toIExpr()) GroupOps.MINUS_OPERATION -> F.Negate(value.toIExpr())
TrigonometricOperations.SIN_OPERATION -> F.Sin(value.toIExpr()) TrigonometricOperations.SIN_OPERATION -> F.Sin(value.toIExpr())
TrigonometricOperations.COS_OPERATION -> F.Cos(value.toIExpr()) TrigonometricOperations.COS_OPERATION -> F.Cos(value.toIExpr())
TrigonometricOperations.TAN_OPERATION -> F.Tan(value.toIExpr()) TrigonometricOperations.TAN_OPERATION -> F.Tan(value.toIExpr())
@ -85,10 +85,10 @@ public fun MST.toIExpr(): IExpr = when (this) {
} }
is MST.Binary -> when (operation) { is MST.Binary -> when (operation) {
GroupOperations.PLUS_OPERATION -> left.toIExpr() + right.toIExpr() GroupOps.PLUS_OPERATION -> left.toIExpr() + right.toIExpr()
GroupOperations.MINUS_OPERATION -> left.toIExpr() - right.toIExpr() GroupOps.MINUS_OPERATION -> left.toIExpr() - right.toIExpr()
RingOperations.TIMES_OPERATION -> left.toIExpr() * right.toIExpr() RingOps.TIMES_OPERATION -> left.toIExpr() * right.toIExpr()
FieldOperations.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr()) FieldOps.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr())
PowerOperations.POW_OPERATION -> F.Power(left.toIExpr(), F.symjify((right as MST.Numeric).value)) PowerOperations.POW_OPERATION -> F.Power(left.toIExpr(), F.symjify((right as MST.Numeric).value))
else -> error("Binary operation $operation not defined in $this") else -> error("Binary operation $operation not defined in $this")
} }

View File

@ -373,8 +373,12 @@ public open class DoubleTensorAlgebra :
return resTensor return resTensor
} }
override fun diagonalEmbedding(diagonalEntries: Tensor<Double>, offset: Int, dim1: Int, dim2: Int): override fun diagonalEmbedding(
DoubleTensor { diagonalEntries: Tensor<Double>,
offset: Int,
dim1: Int,
dim2: Int
): DoubleTensor {
val n = diagonalEntries.shape.size val n = diagonalEntries.shape.size
val d1 = minusIndexFrom(n + 1, dim1) val d1 = minusIndexFrom(n + 1, dim1)
val d2 = minusIndexFrom(n + 1, dim2) val d2 = minusIndexFrom(n + 1, dim2)

View File

@ -44,7 +44,7 @@ internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArra
* *
* @param shape the shape of the tensor. * @param shape the shape of the tensor.
*/ */
internal class TensorLinearStructure(override val shape: IntArray) : Strides { internal class TensorLinearStructure(override val shape: IntArray) : Strides() {
override val strides: IntArray override val strides: IntArray
get() = stridesFromShape(shape) get() = stridesFromShape(shape)
@ -54,4 +54,18 @@ internal class TensorLinearStructure(override val shape: IntArray) : Strides {
override val linearSize: Int override val linearSize: Int
get() = shape.reduce(Int::times) get() = shape.reduce(Int::times)
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false
other as TensorLinearStructure
if (!shape.contentEquals(other.shape)) return false
return true
}
override fun hashCode(): Int {
return shape.contentHashCode()
}
} }

View File

@ -26,8 +26,11 @@ internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> =
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) { internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this is BufferedTensor<T> -> this
is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides) is MutableBufferND<T> -> if (this.indexes == TensorLinearStructure(this.shape)) {
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor() BufferedTensor(this.shape, this.mutableBuffer, 0)
} else {
this.copyToBufferedTensor()
}
else -> this.copyToBufferedTensor() else -> this.copyToBufferedTensor()
} }

View File

@ -11,7 +11,7 @@ import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.* import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.ExtendedField import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations import space.kscience.kmath.operations.NumbersAddOps
import space.kscience.kmath.operations.ScaleOperations import space.kscience.kmath.operations.ScaleOperations
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
@ -34,7 +34,7 @@ public fun F64Array.asStructure(): ViktorStructureND = ViktorStructureND(this)
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>, public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>,
NumbersAddOperations<StructureND<Double>>, ExtendedField<StructureND<Double>>, NumbersAddOps<StructureND<Double>>, ExtendedField<StructureND<Double>>,
ScaleOperations<StructureND<Double>> { ScaleOperations<StructureND<Double>> {
public val StructureND<Double>.f64Buffer: F64Array public val StructureND<Double>.f64Buffer: F64Array
@ -44,7 +44,7 @@ public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, Doubl
shape shape
) )
this is ViktorStructureND && this.f64Buffer.shape.contentEquals(this@ViktorFieldND.shape) -> this.f64Buffer this is ViktorStructureND && this.f64Buffer.shape.contentEquals(this@ViktorFieldND.shape) -> this.f64Buffer
else -> produce { this@f64Buffer[it] }.f64Buffer else -> produce(shape) { this@f64Buffer[it] }.f64Buffer
} }
override val zero: ViktorStructureND by lazy { F64Array.full(init = 0.0, shape = shape).asStructure() } override val zero: ViktorStructureND by lazy { F64Array.full(init = 0.0, shape = shape).asStructure() }
@ -52,9 +52,9 @@ public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, Doubl
private val strides: Strides = DefaultStrides(shape) private val strides: Strides = DefaultStrides(shape)
override val elementContext: DoubleField get() = DoubleField override val elementAlgebra: DoubleField get() = DoubleField
override fun produce(initializer: DoubleField.(IntArray) -> Double): ViktorStructureND = override fun produce(shape: IntArray, initializer: DoubleField.(IntArray) -> Double): ViktorStructureND =
F64Array(*shape).apply { F64Array(*shape).apply {
this@ViktorFieldND.strides.indices().forEach { index -> this@ViktorFieldND.strides.indices().forEach { index ->
set(value = DoubleField.initializer(index), indices = index) set(value = DoubleField.initializer(index), indices = index)
@ -78,13 +78,13 @@ public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, Doubl
} }
}.asStructure() }.asStructure()
override fun combine( override fun zip(
a: StructureND<Double>, left: StructureND<Double>,
b: StructureND<Double>, right: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double, transform: DoubleField.(Double, Double) -> Double,
): ViktorStructureND = F64Array(*shape).apply { ): ViktorStructureND = F64Array(*shape).apply {
this@ViktorFieldND.strides.indices().forEach { index -> this@ViktorFieldND.strides.indices().forEach { index ->
set(value = DoubleField.transform(a[index], b[index]), indices = index) set(value = DoubleField.transform(left[index], right[index]), indices = index)
} }
}.asStructure() }.asStructure()
@ -123,4 +123,4 @@ public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, Doubl
override fun ln(arg: StructureND<Double>): ViktorStructureND = arg.f64Buffer.log().asStructure() override fun ln(arg: StructureND<Double>): ViktorStructureND = arg.f64Buffer.log().asStructure()
} }
public fun ViktorNDField(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape) public fun ViktorFieldND(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape)

View File

@ -1,16 +1,18 @@
pluginManagement { pluginManagement {
repositories { repositories {
mavenLocal()
maven("https://repo.kotlin.link") maven("https://repo.kotlin.link")
mavenCentral() mavenCentral()
gradlePluginPortal() gradlePluginPortal()
} }
val kotlinVersion = "1.6.0-M1" val kotlinVersion = "1.6.0-RC"
val toolsVersion = "0.10.5"
plugins { plugins {
id("org.jetbrains.kotlinx.benchmark") version "0.3.1" id("org.jetbrains.kotlinx.benchmark") version "0.3.1"
id("ru.mipt.npm.gradle.project") version "0.10.5" id("ru.mipt.npm.gradle.project") version toolsVersion
id("ru.mipt.npm.gradle.jvm") version toolsVersion
id("ru.mipt.npm.gradle.mpp") version toolsVersion
kotlin("multiplatform") version kotlinVersion kotlin("multiplatform") version kotlinVersion
kotlin("plugin.allopen") version kotlinVersion kotlin("plugin.allopen") version kotlinVersion
} }