Shapeless ND and Buffer algebras

This commit is contained in:
Alexander Nozik 2021-10-17 11:12:35 +03:00
parent 8d2770c275
commit d0354da80a
56 changed files with 893 additions and 906 deletions

View File

@ -42,6 +42,9 @@
- Use `Symbol` factory function instead of `StringSymbol`
- New discoverability pattern: `<Type>.algebra.<nd/etc>`
- Adjusted commons-math API for linear solvers to match conventions.
- Buffer algebra does not require size anymore
- Operations -> Ops
- Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes.
### Deprecated
- Specialized `DoubleBufferAlgebra`

View File

@ -9,9 +9,10 @@ import kotlinx.benchmark.Benchmark
import kotlinx.benchmark.Blackhole
import kotlinx.benchmark.Scope
import kotlinx.benchmark.State
import space.kscience.kmath.nd.BufferedFieldOpsND
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.autoNdAlgebra
import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.one
import space.kscience.kmath.nd4j.nd4j
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.Buffer
@ -23,21 +24,21 @@ import space.kscience.kmath.tensors.core.tensorAlgebra
internal class NDFieldBenchmark {
@Benchmark
fun autoFieldAdd(blackhole: Blackhole) = with(autoField) {
var res: StructureND<Double> = one
repeat(n) { res += one }
var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 }
blackhole.consume(res)
}
@Benchmark
fun specializedFieldAdd(blackhole: Blackhole) = with(specializedField) {
var res: StructureND<Double> = one
var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 }
blackhole.consume(res)
}
@Benchmark
fun boxingFieldAdd(blackhole: Blackhole) = with(genericField) {
var res: StructureND<Double> = one
var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 }
blackhole.consume(res)
}
@ -56,19 +57,20 @@ internal class NDFieldBenchmark {
blackhole.consume(res)
}
// @Benchmark
// fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) {
// var res: StructureND<Double> = one
// repeat(n) { res += 1.0 }
// blackhole.consume(res)
// }
@Benchmark
fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) {
var res: StructureND<Double> = one(dim, dim)
repeat(n) { res += 1.0 }
blackhole.consume(res)
}
private companion object {
private const val dim = 1000
private const val n = 100
private val autoField = DoubleField.autoNdAlgebra(dim, dim)
private val specializedField = DoubleField.ndAlgebra(dim, dim)
private val genericField = DoubleField.ndAlgebra(Buffer.Companion::boxing, dim, dim)
private val nd4jField = DoubleField.nd4j(dim, dim)
private val shape = intArrayOf(dim, dim)
private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
private val specializedField = DoubleField.ndAlgebra
private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing)
private val nd4jField = DoubleField.nd4j
}
}

View File

@ -10,18 +10,17 @@ import kotlinx.benchmark.Blackhole
import kotlinx.benchmark.Scope
import kotlinx.benchmark.State
import org.jetbrains.bio.viktor.F64Array
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.autoNdAlgebra
import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.viktor.ViktorNDField
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.viktor.ViktorFieldND
@State(Scope.Benchmark)
internal class ViktorBenchmark {
@Benchmark
fun automaticFieldAddition(blackhole: Blackhole) {
with(autoField) {
var res: StructureND<Double> = one
var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 }
blackhole.consume(res)
}
@ -30,7 +29,7 @@ internal class ViktorBenchmark {
@Benchmark
fun realFieldAddition(blackhole: Blackhole) {
with(realField) {
var res: StructureND<Double> = one
var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 }
blackhole.consume(res)
}
@ -39,7 +38,7 @@ internal class ViktorBenchmark {
@Benchmark
fun viktorFieldAddition(blackhole: Blackhole) {
with(viktorField) {
var res = one
var res = one(shape)
repeat(n) { res += 1.0 }
blackhole.consume(res)
}
@ -56,10 +55,11 @@ internal class ViktorBenchmark {
private companion object {
private const val dim = 1000
private const val n = 100
private val shape = Shape(dim, dim)
// automatically build context most suited for given type.
private val autoField = DoubleField.autoNdAlgebra(dim, dim)
private val realField = DoubleField.ndAlgebra(dim, dim)
private val viktorField = ViktorNDField(dim, dim)
private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
private val realField = DoubleField.ndAlgebra
private val viktorField = ViktorFieldND(dim, dim)
}
}

View File

@ -10,18 +10,21 @@ import kotlinx.benchmark.Blackhole
import kotlinx.benchmark.Scope
import kotlinx.benchmark.State
import org.jetbrains.bio.viktor.F64Array
import space.kscience.kmath.nd.autoNdAlgebra
import space.kscience.kmath.nd.BufferedFieldOpsND
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.one
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.viktor.ViktorFieldND
@State(Scope.Benchmark)
internal class ViktorLogBenchmark {
@Benchmark
fun realFieldLog(blackhole: Blackhole) {
with(realNdField) {
val fortyTwo = produce { 42.0 }
var res = one
with(realField) {
val fortyTwo = produce(shape) { 42.0 }
var res = one(shape)
repeat(n) { res = ln(fortyTwo) }
blackhole.consume(res)
}
@ -30,7 +33,7 @@ internal class ViktorLogBenchmark {
@Benchmark
fun viktorFieldLog(blackhole: Blackhole) {
with(viktorField) {
val fortyTwo = produce { 42.0 }
val fortyTwo = produce(shape) { 42.0 }
var res = one
repeat(n) { res = ln(fortyTwo) }
blackhole.consume(res)
@ -48,10 +51,11 @@ internal class ViktorLogBenchmark {
private companion object {
private const val dim = 1000
private const val n = 100
private val shape = Shape(dim, dim)
// automatically build context most suited for given type.
private val autoField = DoubleField.autoNdAlgebra(dim, dim)
private val realNdField = DoubleField.ndAlgebra(dim, dim)
private val viktorField = ViktorFieldND(intArrayOf(dim, dim))
private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
private val realField = DoubleField.ndAlgebra
private val viktorField = ViktorFieldND(dim, dim)
}
}

View File

@ -11,27 +11,27 @@ import space.kscience.kmath.complex.bufferAlgebra
import space.kscience.kmath.complex.ndAlgebra
import space.kscience.kmath.nd.BufferND
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.produce
fun main() = Complex.algebra {
val complex = 2 + 2 * i
println(complex * 8 - 5 * i)
//flat buffer
val buffer = bufferAlgebra(8).run {
buffer { Complex(it, -it) }.map { Complex(it.im, it.re) }
val buffer = with(bufferAlgebra){
buffer(8) { Complex(it, -it) }.map { Complex(it.im, it.re) }
}
println(buffer)
// 2d element
val element: BufferND<Complex> = ndAlgebra(2, 2).produce { (i, j) ->
val element: BufferND<Complex> = ndAlgebra.produce(2, 2) { (i, j) ->
Complex(i - j, i + j)
}
println(element)
// 1d element operation
val result: StructureND<Complex> = ndAlgebra(8).run {
val a = produce { (it) -> i * it - it.toDouble() }
val result: StructureND<Complex> = ndAlgebra{
val a = produce(8) { (it) -> i * it - it.toDouble() }
val b = 3
val c = Complex(1.0, 1.0)

View File

@ -14,7 +14,7 @@ import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd4j.Nd4jArrayField
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.viktor.ViktorNDField
import space.kscience.kmath.viktor.ViktorFieldND
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.system.measureTimeMillis
@ -41,7 +41,7 @@ fun main() {
// Nd4j specialized field.
val nd4jField = Nd4jArrayField.real(dim, dim)
//viktor field
val viktorField = ViktorNDField(dim, dim)
val viktorField = ViktorFieldND(dim, dim)
//parallel processing based on Java Streams
val parallelField = DoubleField.ndStreaming(dim, dim)

View File

@ -8,7 +8,7 @@ package space.kscience.kmath.structures
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations
import space.kscience.kmath.operations.NumbersAddOps
import java.util.*
import java.util.stream.IntStream
@ -17,11 +17,11 @@ import java.util.stream.IntStream
* execution.
*/
class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>,
NumbersAddOperations<StructureND<Double>>,
NumbersAddOps<StructureND<Double>>,
ExtendedField<StructureND<Double>> {
private val strides = DefaultStrides(shape)
override val elementContext: DoubleField get() = DoubleField
override val elementAlgebra: DoubleField get() = DoubleField
override val zero: BufferND<Double> by lazy { produce { zero } }
override val one: BufferND<Double> by lazy { produce { one } }
@ -36,7 +36,7 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
this@StreamDoubleFieldND.shape,
shape
)
this is BufferND && this.strides == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer
this is BufferND && this.indexes == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
}
@ -69,7 +69,7 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
return BufferND(strides, array.asBuffer())
}
override fun combine(
override fun zip(
a: StructureND<Double>,
b: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double,

View File

@ -8,6 +8,7 @@ package space.kscience.kmath.structures
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.buffer
import space.kscience.kmath.operations.bufferAlgebra
import space.kscience.kmath.operations.withSize
inline fun <reified R : Any> MutableBuffer.Companion.same(
n: Int,
@ -16,7 +17,7 @@ inline fun <reified R : Any> MutableBuffer.Companion.same(
fun main() {
with(DoubleField.bufferAlgebra(5)) {
with(DoubleField.bufferAlgebra.withSize(5)) {
println(number(2.0) + buffer(1, 2, 3, 4, 5))
}
}

View File

@ -1,5 +1,5 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.1.1-bin.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-7.2-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

View File

@ -18,10 +18,10 @@ import com.github.h0tk3y.betterParse.parser.ParseResult
import com.github.h0tk3y.betterParse.parser.Parser
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.operations.FieldOperations
import space.kscience.kmath.operations.GroupOperations
import space.kscience.kmath.operations.FieldOps
import space.kscience.kmath.operations.GroupOps
import space.kscience.kmath.operations.PowerOperations
import space.kscience.kmath.operations.RingOperations
import space.kscience.kmath.operations.RingOps
/**
* better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4.
@ -60,7 +60,7 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
.or(binaryFunction)
.or(unaryFunction)
.or(singular)
.or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(GroupOperations.MINUS_OPERATION, it) })
.or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(GroupOps.MINUS_OPERATION, it) })
.or(-lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar)
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
@ -72,9 +72,9 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
operator = div or mul use TokenMatch::type
) { a, op, b ->
if (op == div)
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
MST.Binary(FieldOps.DIV_OPERATION, a, b)
else
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
MST.Binary(RingOps.TIMES_OPERATION, a, b)
}
private val subSumChain: Parser<MST> by leftAssociative(
@ -82,9 +82,9 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
operator = plus or minus use TokenMatch::type
) { a, op, b ->
if (op == plus)
MST.Binary(GroupOperations.PLUS_OPERATION, a, b)
MST.Binary(GroupOps.PLUS_OPERATION, a, b)
else
MST.Binary(GroupOperations.MINUS_OPERATION, a, b)
MST.Binary(GroupOps.MINUS_OPERATION, a, b)
}
override val rootParser: Parser<MST> by subSumChain

View File

@ -39,7 +39,7 @@ public val PrintNumeric: RenderFeature = RenderFeature { _, node ->
@UnstableKMathAPI
private fun printSignedNumberString(s: String): MathSyntax = if (s.startsWith('-'))
UnaryMinusSyntax(
operation = GroupOperations.MINUS_OPERATION,
operation = GroupOps.MINUS_OPERATION,
operand = OperandSyntax(
operand = NumberSyntax(string = s.removePrefix("-")),
parentheses = true,
@ -72,7 +72,7 @@ public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : Rend
val exponent = afterE.toDouble().toString().removeSuffix(".0")
return MultiplicationSyntax(
operation = RingOperations.TIMES_OPERATION,
operation = RingOps.TIMES_OPERATION,
left = OperandSyntax(operand = NumberSyntax(significand), parentheses = true),
right = OperandSyntax(
operand = SuperscriptSyntax(
@ -91,7 +91,7 @@ public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : Rend
if (toString.startsWith('-'))
return UnaryMinusSyntax(
operation = GroupOperations.MINUS_OPERATION,
operation = GroupOps.MINUS_OPERATION,
operand = OperandSyntax(operand = infty, parentheses = true),
)
@ -211,9 +211,9 @@ public class BinaryPlus(operations: Collection<String>?) : Binary(operations) {
public companion object {
/**
* The default instance configured with [GroupOperations.PLUS_OPERATION].
* The default instance configured with [GroupOps.PLUS_OPERATION].
*/
public val Default: BinaryPlus = BinaryPlus(setOf(GroupOperations.PLUS_OPERATION))
public val Default: BinaryPlus = BinaryPlus(setOf(GroupOps.PLUS_OPERATION))
}
}
@ -233,9 +233,9 @@ public class BinaryMinus(operations: Collection<String>?) : Binary(operations) {
public companion object {
/**
* The default instance configured with [GroupOperations.MINUS_OPERATION].
* The default instance configured with [GroupOps.MINUS_OPERATION].
*/
public val Default: BinaryMinus = BinaryMinus(setOf(GroupOperations.MINUS_OPERATION))
public val Default: BinaryMinus = BinaryMinus(setOf(GroupOps.MINUS_OPERATION))
}
}
@ -253,9 +253,9 @@ public class UnaryPlus(operations: Collection<String>?) : Unary(operations) {
public companion object {
/**
* The default instance configured with [GroupOperations.PLUS_OPERATION].
* The default instance configured with [GroupOps.PLUS_OPERATION].
*/
public val Default: UnaryPlus = UnaryPlus(setOf(GroupOperations.PLUS_OPERATION))
public val Default: UnaryPlus = UnaryPlus(setOf(GroupOps.PLUS_OPERATION))
}
}
@ -273,9 +273,9 @@ public class UnaryMinus(operations: Collection<String>?) : Unary(operations) {
public companion object {
/**
* The default instance configured with [GroupOperations.MINUS_OPERATION].
* The default instance configured with [GroupOps.MINUS_OPERATION].
*/
public val Default: UnaryMinus = UnaryMinus(setOf(GroupOperations.MINUS_OPERATION))
public val Default: UnaryMinus = UnaryMinus(setOf(GroupOps.MINUS_OPERATION))
}
}
@ -295,9 +295,9 @@ public class Fraction(operations: Collection<String>?) : Binary(operations) {
public companion object {
/**
* The default instance configured with [FieldOperations.DIV_OPERATION].
* The default instance configured with [FieldOps.DIV_OPERATION].
*/
public val Default: Fraction = Fraction(setOf(FieldOperations.DIV_OPERATION))
public val Default: Fraction = Fraction(setOf(FieldOps.DIV_OPERATION))
}
}
@ -422,9 +422,9 @@ public class Multiplication(operations: Collection<String>?) : Binary(operations
public companion object {
/**
* The default instance configured with [RingOperations.TIMES_OPERATION].
* The default instance configured with [RingOps.TIMES_OPERATION].
*/
public val Default: Multiplication = Multiplication(setOf(RingOperations.TIMES_OPERATION))
public val Default: Multiplication = Multiplication(setOf(RingOps.TIMES_OPERATION))
}
}

View File

@ -7,10 +7,10 @@ package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess.PostProcessPhase
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.FieldOperations
import space.kscience.kmath.operations.GroupOperations
import space.kscience.kmath.operations.FieldOps
import space.kscience.kmath.operations.GroupOps
import space.kscience.kmath.operations.PowerOperations
import space.kscience.kmath.operations.RingOperations
import space.kscience.kmath.operations.RingOps
/**
* Removes unnecessary times (&times;) symbols from [MultiplicationSyntax].
@ -306,10 +306,10 @@ public class SimplifyParentheses(public val precedenceFunction: (MathSyntax) ->
is BinarySyntax -> when (it.operation) {
PowerOperations.POW_OPERATION -> 1
RingOperations.TIMES_OPERATION -> 3
FieldOperations.DIV_OPERATION -> 3
GroupOperations.MINUS_OPERATION -> 4
GroupOperations.PLUS_OPERATION -> 4
RingOps.TIMES_OPERATION -> 3
FieldOps.DIV_OPERATION -> 3
GroupOps.MINUS_OPERATION -> 4
GroupOps.PLUS_OPERATION -> 4
else -> 0
}

View File

@ -7,7 +7,7 @@ package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.rendering.TestUtils.testLatex
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.operations.GroupOperations
import space.kscience.kmath.operations.GroupOps
import kotlin.test.Test
internal class TestLatex {
@ -36,7 +36,7 @@ internal class TestLatex {
fun unaryOperator() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)")
@Test
fun unaryPlus() = testLatex(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "+1")
fun unaryPlus() = testLatex(MST.Unary(GroupOps.PLUS_OPERATION, MST.Numeric(1)), "+1")
@Test
fun unaryMinus() = testLatex("-x", "-x")

View File

@ -7,7 +7,7 @@ package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.rendering.TestUtils.testMathML
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.operations.GroupOperations
import space.kscience.kmath.operations.GroupOps
import kotlin.test.Test
internal class TestMathML {
@ -47,7 +47,7 @@ internal class TestMathML {
@Test
fun unaryPlus() =
testMathML(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "<mo>+</mo><mn>1</mn>")
testMathML(MST.Unary(GroupOps.PLUS_OPERATION, MST.Numeric(1)), "<mo>+</mo><mn>1</mn>")
@Test
fun unaryMinus() = testMathML("-x", "<mo>-</mo><mi>x</mi>")

View File

@ -108,8 +108,8 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double>(f64, DoubleF
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.f64.const(mst.value)
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) {
GroupOperations.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value))
GroupOperations.PLUS_OPERATION -> visit(mst.value)
GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value))
GroupOps.PLUS_OPERATION -> visit(mst.value)
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value))
TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64)
TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64)
@ -129,10 +129,10 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double>(f64, DoubleF
}
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
GroupOperations.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
GroupOperations.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
RingOperations.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
FieldOperations.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right))
GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
FieldOps.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right))
PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), f64)
else -> super.visitBinary(mst)
}
@ -142,15 +142,15 @@ internal class IntWasmBuilder(target: MST) : WasmBuilder<Int>(i32, IntRing, targ
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.i32.const(mst.value)
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) {
GroupOperations.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value))
GroupOperations.PLUS_OPERATION -> visit(mst.value)
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value))
GroupOps.PLUS_OPERATION -> visit(mst.value)
else -> super.visitUnary(mst)
}
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
GroupOperations.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
GroupOperations.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
RingOperations.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
else -> super.visitBinary(mst)
}
}

View File

@ -9,7 +9,7 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations
import space.kscience.kmath.operations.NumbersAddOps
/**
* A field over commons-math [DerivativeStructure].
@ -22,7 +22,7 @@ public class DerivativeStructureField(
public val order: Int,
bindings: Map<Symbol, Double>,
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>,
NumbersAddOperations<DerivativeStructure> {
NumbersAddOps<DerivativeStructure> {
public val numberOfVariables: Int = bindings.size
override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }

View File

@ -52,7 +52,7 @@ private val PI_DIV_2 = Complex(PI / 2, 0)
public object ComplexField :
ExtendedField<Complex>,
Norm<Complex, Complex>,
NumbersAddOperations<Complex>,
NumbersAddOps<Complex>,
ScaleOperations<Complex> {
override val zero: Complex = 0.0.toComplex()
@ -216,7 +216,6 @@ public data class Complex(val re: Double, val im: Double) {
public val Complex.Companion.algebra: ComplexField get() = ComplexField
/**
* Creates a complex number with real part equal to this real.
*

View File

@ -6,13 +6,8 @@
package space.kscience.kmath.complex
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.BufferND
import space.kscience.kmath.nd.BufferedFieldND
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.BufferField
import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations
import space.kscience.kmath.operations.bufferAlgebra
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Buffer
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
@ -22,100 +17,61 @@ import kotlin.contracts.contract
* An optimized nd-field for complex numbers
*/
@OptIn(UnstableKMathAPI::class)
public class ComplexFieldND(
shape: IntArray,
) : BufferedFieldND<Complex, ComplexField>(shape, ComplexField, Buffer.Companion::complex),
NumbersAddOperations<StructureND<Complex>>,
ExtendedField<StructureND<Complex>> {
public sealed class ComplexFieldOpsND : BufferedFieldOpsND<Complex, ComplexField>(ComplexField.bufferAlgebra),
ScaleOperations<StructureND<Complex>>, ExtendedFieldOps<StructureND<Complex>> {
override val zero: BufferND<Complex> by lazy { produce { zero } }
override val one: BufferND<Complex> by lazy { produce { one } }
override fun number(value: Number): BufferND<Complex> {
val d = value.toComplex() // minimize conversions
return produce { d }
override fun StructureND<Complex>.toBufferND(): BufferND<Complex> = when (this) {
is BufferND -> this
else -> {
val indexer = indexerBuilder(shape)
BufferND(indexer, Buffer.complex(indexer.linearSize) { offset -> get(indexer.index(offset)) })
}
}
//
// @Suppress("OVERRIDE_BY_INLINE")
// override inline fun map(
// arg: AbstractNDBuffer<Double>,
// transform: DoubleField.(Double) -> Double,
// ): RealNDElement {
// check(arg)
// val array = RealBuffer(arg.strides.linearSize) { offset -> DoubleField.transform(arg.buffer[offset]) }
// return BufferedNDFieldElement(this, array)
// }
//
// @Suppress("OVERRIDE_BY_INLINE")
// override inline fun produce(initializer: DoubleField.(IntArray) -> Double): RealNDElement {
// val array = RealBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }
// return BufferedNDFieldElement(this, array)
// }
//
// @Suppress("OVERRIDE_BY_INLINE")
// override inline fun mapIndexed(
// arg: AbstractNDBuffer<Double>,
// transform: DoubleField.(index: IntArray, Double) -> Double,
// ): RealNDElement {
// check(arg)
// return BufferedNDFieldElement(
// this,
// RealBuffer(arg.strides.linearSize) { offset ->
// elementContext.transform(
// arg.strides.index(offset),
// arg.buffer[offset]
// )
// })
// }
//
// @Suppress("OVERRIDE_BY_INLINE")
// override inline fun combine(
// a: AbstractNDBuffer<Double>,
// b: AbstractNDBuffer<Double>,
// transform: DoubleField.(Double, Double) -> Double,
// ): RealNDElement {
// check(a, b)
// val buffer = RealBuffer(strides.linearSize) { offset ->
// elementContext.transform(a.buffer[offset], b.buffer[offset])
// }
// return BufferedNDFieldElement(this, buffer)
// }
//TODO do specialization
override fun power(arg: StructureND<Complex>, pow: Number): BufferND<Complex> = arg.map { power(it, pow) }
override fun scale(a: StructureND<Complex>, value: Double): BufferND<Complex> =
mapInline(a.toBufferND()) { it * value }
override fun exp(arg: StructureND<Complex>): BufferND<Complex> = arg.map { exp(it) }
override fun power(arg: StructureND<Complex>, pow: Number): BufferND<Complex> =
mapInline(arg.toBufferND()) { power(it, pow) }
override fun ln(arg: StructureND<Complex>): BufferND<Complex> = arg.map { ln(it) }
override fun exp(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { exp(it) }
override fun ln(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { ln(it) }
override fun sin(arg: StructureND<Complex>): BufferND<Complex> = arg.map { sin(it) }
override fun cos(arg: StructureND<Complex>): BufferND<Complex> = arg.map { cos(it) }
override fun tan(arg: StructureND<Complex>): BufferND<Complex> = arg.map { tan(it) }
override fun asin(arg: StructureND<Complex>): BufferND<Complex> = arg.map { asin(it) }
override fun acos(arg: StructureND<Complex>): BufferND<Complex> = arg.map { acos(it) }
override fun atan(arg: StructureND<Complex>): BufferND<Complex> = arg.map { atan(it) }
override fun sin(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { sin(it) }
override fun cos(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { cos(it) }
override fun tan(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { tan(it) }
override fun asin(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { asin(it) }
override fun acos(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { acos(it) }
override fun atan(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { atan(it) }
override fun sinh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { sinh(it) }
override fun cosh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { cosh(it) }
override fun tanh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { tanh(it) }
override fun asinh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { asinh(it) }
override fun acosh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { acosh(it) }
override fun atanh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { atanh(it) }
}
override fun sinh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { sinh(it) }
override fun cosh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { cosh(it) }
override fun tanh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { tanh(it) }
override fun asinh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { asinh(it) }
override fun acosh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { acosh(it) }
override fun atanh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { atanh(it) }
/**
* Fast element production using function inlining
*/
public inline fun BufferedFieldND<Complex, ComplexField>.produceInline(initializer: ComplexField.(Int) -> Complex): BufferND<Complex> {
contract { callsInPlace(initializer, InvocationKind.EXACTLY_ONCE) }
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) }
return BufferND(strides, buffer)
public companion object : ComplexFieldOpsND()
}
@UnstableKMathAPI
public fun ComplexField.bufferAlgebra(size: Int): BufferField<Complex, ComplexField> =
bufferAlgebra(Buffer.Companion::complex, size)
public val ComplexField.bufferAlgebra: BufferFieldOps<Complex, ComplexField>
get() = bufferAlgebra(Buffer.Companion::complex)
@OptIn(UnstableKMathAPI::class)
public class ComplexFieldND(override val shape: Shape) :
ComplexFieldOpsND(), FieldND<Complex, ComplexField>, NumbersAddOps<StructureND<Complex>> {
override fun number(value: Number): BufferND<Complex> {
val d = value.toDouble() // minimize conversions
return produce(shape) { d.toComplex() }
}
}
public val ComplexField.ndAlgebra: ComplexFieldOpsND get() = ComplexFieldOpsND
public fun ComplexField.ndAlgebra(vararg shape: Int): ComplexFieldND = ComplexFieldND(shape)

View File

@ -44,7 +44,7 @@ public val Quaternion.r: Double
*/
@OptIn(UnstableKMathAPI::class)
public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>, PowerOperations<Quaternion>,
ExponentialOperations<Quaternion>, NumbersAddOperations<Quaternion>, ScaleOperations<Quaternion> {
ExponentialOperations<Quaternion>, NumbersAddOps<Quaternion>, ScaleOperations<Quaternion> {
override val zero: Quaternion = 0.toQuaternion()
override val one: Quaternion = 1.toQuaternion()

View File

@ -52,13 +52,13 @@ public open class FunctionalExpressionGroup<T, out A : Group<T>>(
override val zero: Expression<T> get() = const(algebra.zero)
override fun Expression<T>.unaryMinus(): Expression<T> =
unaryOperation(GroupOperations.MINUS_OPERATION, this)
unaryOperation(GroupOps.MINUS_OPERATION, this)
/**
* Builds an Expression of addition of two another expressions.
*/
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(GroupOperations.PLUS_OPERATION, a, b)
binaryOperation(GroupOps.PLUS_OPERATION, a, b)
// /**
// * Builds an Expression of multiplication of expression by number.
@ -89,7 +89,7 @@ public open class FunctionalExpressionRing<T, out A : Ring<T>>(
* Builds an Expression of multiplication of two expressions.
*/
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b)
binaryOperationFunction(RingOps.TIMES_OPERATION)(a, b)
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
@ -108,7 +108,7 @@ public open class FunctionalExpressionField<T, out A : Field<T>>(
* Builds an Expression of division an expression by another one.
*/
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b)
binaryOperationFunction(FieldOps.DIV_OPERATION)(a, b)
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this

View File

@ -31,18 +31,18 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value)
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(GroupOperations.PLUS_OPERATION)(a, b)
override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(GroupOps.PLUS_OPERATION)(a, b)
override operator fun MST.unaryPlus(): MST.Unary =
unaryOperationFunction(GroupOperations.PLUS_OPERATION)(this)
unaryOperationFunction(GroupOps.PLUS_OPERATION)(this)
override operator fun MST.unaryMinus(): MST.Unary =
unaryOperationFunction(GroupOperations.MINUS_OPERATION)(this)
unaryOperationFunction(GroupOps.MINUS_OPERATION)(this)
override operator fun MST.minus(b: MST): MST.Binary =
binaryOperationFunction(GroupOperations.MINUS_OPERATION)(this, b)
binaryOperationFunction(GroupOps.MINUS_OPERATION)(this, b)
override fun scale(a: MST, value: Double): MST.Binary =
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, number(value))
binaryOperationFunction(RingOps.TIMES_OPERATION)(a, number(value))
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
MstNumericAlgebra.binaryOperationFunction(operation)
@ -56,7 +56,7 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
*/
@Suppress("OVERRIDE_BY_INLINE")
@OptIn(UnstableKMathAPI::class)
public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MST> {
public object MstRing : Ring<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
override inline val zero: MST.Numeric get() = MstGroup.zero
override val one: MST.Numeric = number(1.0)
@ -65,10 +65,10 @@ public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MS
override fun add(a: MST, b: MST): MST.Binary = MstGroup.add(a, b)
override fun scale(a: MST, value: Double): MST.Binary =
MstGroup.binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, MstGroup.number(value))
MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
override fun multiply(a: MST, b: MST): MST.Binary =
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b)
binaryOperationFunction(RingOps.TIMES_OPERATION)(a, b)
override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus }
override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus }
@ -86,7 +86,7 @@ public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MS
*/
@Suppress("OVERRIDE_BY_INLINE")
@OptIn(UnstableKMathAPI::class)
public object MstField : Field<MST>, NumbersAddOperations<MST>, ScaleOperations<MST> {
public object MstField : Field<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
override inline val zero: MST.Numeric get() = MstRing.zero
override inline val one: MST.Numeric get() = MstRing.one
@ -95,11 +95,11 @@ public object MstField : Field<MST>, NumbersAddOperations<MST>, ScaleOperations<
override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
override fun scale(a: MST, value: Double): MST.Binary =
MstGroup.binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, MstGroup.number(value))
MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
override fun divide(a: MST, b: MST): MST.Binary =
binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b)
binaryOperationFunction(FieldOps.DIV_OPERATION)(a, b)
override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus }
override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus }
@ -138,7 +138,7 @@ public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg)
override fun scale(a: MST, value: Double): MST =
binaryOperation(GroupOperations.PLUS_OPERATION, a, number(value))
binaryOperation(GroupOps.PLUS_OPERATION, a, number(value))
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)

View File

@ -59,7 +59,7 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
public val context: F,
bindings: Map<Symbol, T>,
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, NumbersAddOperations<AutoDiffValue<T>> {
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, NumbersAddOps<AutoDiffValue<T>> {
override val zero: AutoDiffValue<T> get() = const(context.zero)
override val one: AutoDiffValue<T> get() = const(context.one)

View File

@ -6,12 +6,10 @@
package space.kscience.kmath.linear
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.BufferedRingND
import space.kscience.kmath.nd.BufferedRingOpsND
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.nd.asND
import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.VirtualBuffer
@ -19,31 +17,28 @@ import space.kscience.kmath.structures.indices
public class BufferedLinearSpace<T, out A : Ring<T>>(
override val elementAlgebra: A,
private val bufferFactory: BufferFactory<T>,
private val bufferAlgebra: BufferAlgebra<T, A>
) : LinearSpace<T, A> {
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
private fun ndRing(
rows: Int,
cols: Int,
): BufferedRingND<T, A> = elementAlgebra.ndAlgebra(bufferFactory, rows, cols)
private val ndAlgebra = BufferedRingOpsND(bufferAlgebra)
override fun buildMatrix(rows: Int, columns: Int, initializer: A.(i: Int, j: Int) -> T): Matrix<T> =
ndRing(rows, columns).produce { (i, j) -> elementAlgebra.initializer(i, j) }.as2D()
ndAlgebra.produce(intArrayOf(rows, columns)) { (i, j) -> elementAlgebra.initializer(i, j) }.as2D()
override fun buildVector(size: Int, initializer: A.(Int) -> T): Point<T> =
bufferFactory(size) { elementAlgebra.initializer(it) }
bufferAlgebra.buffer(size) { elementAlgebra.initializer(it) }
override fun Matrix<T>.unaryMinus(): Matrix<T> = ndRing(rowNum, colNum).run {
override fun Matrix<T>.unaryMinus(): Matrix<T> = ndAlgebra {
asND().map { -it }.as2D()
}
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndRing(rowNum, colNum).run {
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndAlgebra {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D()
}
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndRing(rowNum, colNum).run {
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndAlgebra {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D()
}
@ -88,11 +83,11 @@ public class BufferedLinearSpace<T, out A : Ring<T>>(
}
}
override fun Matrix<T>.times(value: T): Matrix<T> = ndRing(rowNum, colNum).run {
override fun Matrix<T>.times(value: T): Matrix<T> = ndAlgebra {
asND().map { it * value }.as2D()
}
}
public fun <T, A : Ring<T>> A.linearSpace(bufferFactory: BufferFactory<T>): BufferedLinearSpace<T, A> =
BufferedLinearSpace(this, bufferFactory)
BufferedLinearSpace(BufferRingOps(this, bufferFactory))

View File

@ -6,11 +6,12 @@
package space.kscience.kmath.linear
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.DoubleFieldND
import space.kscience.kmath.nd.DoubleFieldOpsND
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.nd.asND
import space.kscience.kmath.operations.DoubleBufferOperations
import space.kscience.kmath.operations.DoubleBufferOps
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.DoubleBuffer
@ -18,30 +19,27 @@ public object DoubleLinearSpace : LinearSpace<Double, DoubleField> {
override val elementAlgebra: DoubleField get() = DoubleField
private fun ndRing(
rows: Int,
cols: Int,
): DoubleFieldND = DoubleFieldND(intArrayOf(rows, cols))
override fun buildMatrix(
rows: Int,
columns: Int,
initializer: DoubleField.(i: Int, j: Int) -> Double
): Matrix<Double> = ndRing(rows, columns).produce { (i, j) -> DoubleField.initializer(i, j) }.as2D()
): Matrix<Double> = DoubleFieldOpsND.produce(intArrayOf(rows, columns)) { (i, j) ->
DoubleField.initializer(i, j)
}.as2D()
override fun buildVector(size: Int, initializer: DoubleField.(Int) -> Double): DoubleBuffer =
DoubleBuffer(size) { DoubleField.initializer(it) }
override fun Matrix<Double>.unaryMinus(): Matrix<Double> = ndRing(rowNum, colNum).run {
override fun Matrix<Double>.unaryMinus(): Matrix<Double> = DoubleFieldOpsND {
asND().map { -it }.as2D()
}
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = ndRing(rowNum, colNum).run {
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = DoubleFieldOpsND {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D()
}
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = ndRing(rowNum, colNum).run {
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = DoubleFieldOpsND {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D()
}
@ -84,23 +82,23 @@ public object DoubleLinearSpace : LinearSpace<Double, DoubleField> {
}
override fun Matrix<Double>.times(value: Double): Matrix<Double> = ndRing(rowNum, colNum).run {
override fun Matrix<Double>.times(value: Double): Matrix<Double> = DoubleFieldOpsND {
asND().map { it * value }.as2D()
}
public override fun Point<Double>.plus(other: Point<Double>): DoubleBuffer = DoubleBufferOperations.run {
public override fun Point<Double>.plus(other: Point<Double>): DoubleBuffer = DoubleBufferOps.run {
this@plus + other
}
public override fun Point<Double>.minus(other: Point<Double>): DoubleBuffer = DoubleBufferOperations.run {
public override fun Point<Double>.minus(other: Point<Double>): DoubleBuffer = DoubleBufferOps.run {
this@minus - other
}
public override fun Point<Double>.times(value: Double): DoubleBuffer = DoubleBufferOperations.run {
public override fun Point<Double>.times(value: Double): DoubleBuffer = DoubleBufferOps.run {
scale(this@times, value)
}
public operator fun Point<Double>.div(value: Double): DoubleBuffer = DoubleBufferOperations.run {
public operator fun Point<Double>.div(value: Double): DoubleBuffer = DoubleBufferOps.run {
scale(this@div, 1.0 / value)
}

View File

@ -10,6 +10,7 @@ import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.nd.StructureFeature
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.operations.BufferRingOps
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.invoke
@ -188,7 +189,7 @@ public interface LinearSpace<T, out A : Ring<T>> {
public fun <T : Any, A : Ring<T>> buffered(
algebra: A,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
): LinearSpace<T, A> = BufferedLinearSpace(algebra, bufferFactory)
): LinearSpace<T, A> = BufferedLinearSpace(BufferRingOps(algebra, bufferFactory))
@Deprecated("use DoubleField.linearSpace")
public val double: LinearSpace<Double, DoubleField> = buffered(DoubleField, ::DoubleBuffer)

View File

@ -7,7 +7,6 @@ package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.*
import kotlin.reflect.KClass
/**
@ -19,6 +18,14 @@ import kotlin.reflect.KClass
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
public typealias Shape = IntArray
public fun Shape(shapeFirst: Int, vararg shapeRest: Int): Shape = intArrayOf(shapeFirst, *shapeRest)
public interface WithShape {
public val shape: Shape
}
/**
* The base interface for all ND-algebra implementations.
*
@ -26,20 +33,15 @@ public class ShapeMismatchException(public val expected: IntArray, public val ac
* @param C the type of the element context.
*/
public interface AlgebraND<T, out C : Algebra<T>> {
/**
* The shape of ND-structures this algebra operates on.
*/
public val shape: IntArray
/**
* The algebra over elements of ND structure.
*/
public val elementContext: C
public val elementAlgebra: C
/**
* Produces a new NDStructure using given initializer function.
*/
public fun produce(initializer: C.(IntArray) -> T): StructureND<T>
public fun produce(shape: Shape, initializer: C.(IntArray) -> T): StructureND<T>
/**
* Maps elements from one structure to another one by applying [transform] to them.
@ -54,7 +56,7 @@ public interface AlgebraND<T, out C : Algebra<T>> {
/**
* Combines two structures into one.
*/
public fun combine(a: StructureND<T>, b: StructureND<T>, transform: C.(T, T) -> T): StructureND<T>
public fun zip(left: StructureND<T>, right: StructureND<T>, transform: C.(T, T) -> T): StructureND<T>
/**
* Element-wise invocation of function working on [T] on a [StructureND].
@ -77,7 +79,6 @@ public interface AlgebraND<T, out C : Algebra<T>> {
public companion object
}
/**
* Get a feature of the structure in this scope. Structure features take precedence other context features.
*
@ -89,37 +90,13 @@ public interface AlgebraND<T, out C : Algebra<T>> {
public inline fun <T : Any, reified F : StructureFeature> AlgebraND<T, *>.getFeature(structure: StructureND<T>): F? =
getFeature(structure, F::class)
/**
* Checks if given elements are consistent with this context.
*
* @param structures the structures to check.
* @return the array of valid structures.
*/
internal fun <T, C : Algebra<T>> AlgebraND<T, C>.checkShape(vararg structures: StructureND<T>): Array<out StructureND<T>> =
structures
.map(StructureND<T>::shape)
.singleOrNull { !shape.contentEquals(it) }
?.let<IntArray, Array<out StructureND<T>>> { throw ShapeMismatchException(shape, it) }
?: structures
/**
* Checks if given element is consistent with this context.
*
* @param element the structure to check.
* @return the valid structure.
*/
internal fun <T, C : Algebra<T>> AlgebraND<T, C>.checkShape(element: StructureND<T>): StructureND<T> {
if (!element.shape.contentEquals(shape)) throw ShapeMismatchException(shape, element.shape)
return element
}
/**
* Space of [StructureND].
*
* @param T the type of the element contained in ND structure.
* @param S the type of group over structure elements.
* @param A the type of group over structure elements.
*/
public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND<T, S> {
public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>, AlgebraND<T, A> {
/**
* Element-wise addition.
*
@ -128,7 +105,7 @@ public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND
* @return the sum.
*/
override fun add(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
combine(a, b) { aValue, bValue -> add(aValue, bValue) }
zip(a, b) { aValue, bValue -> add(aValue, bValue) }
// TODO move to extensions after KEEP-176
@ -171,13 +148,17 @@ public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND
public companion object
}
public interface GroupND<T, out A : Group<T>> : Group<StructureND<T>>, GroupOpsND<T, A>, WithShape {
override val zero: StructureND<T> get() = produce(shape) { elementAlgebra.zero }
}
/**
* Ring of [StructureND].
*
* @param T the type of the element contained in ND structure.
* @param R the type of ring over structure elements.
* @param A the type of ring over structure elements.
*/
public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R> {
public interface RingOpsND<T, out A : RingOps<T>> : RingOps<StructureND<T>>, GroupOpsND<T, A> {
/**
* Element-wise multiplication.
*
@ -186,7 +167,7 @@ public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R
* @return the product.
*/
override fun multiply(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
zip(a, b) { aValue, bValue -> multiply(aValue, bValue) }
//TODO move to extensions after KEEP-176
@ -211,13 +192,19 @@ public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R
public companion object
}
public interface RingND<T, out A : Ring<T>> : Ring<StructureND<T>>, RingOpsND<T, A>, GroupND<T, A>, WithShape {
override val one: StructureND<T> get() = produce(shape) { elementAlgebra.one }
}
/**
* Field of [StructureND].
*
* @param T the type of the element contained in ND structure.
* @param F the type field over structure elements.
* @param A the type field over structure elements.
*/
public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T, F> {
public interface FieldOpsND<T, out A : Field<T>> : FieldOps<StructureND<T>>, RingOpsND<T, A>,
ScaleOperations<StructureND<T>> {
/**
* Element-wise division.
*
@ -226,9 +213,9 @@ public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T,
* @return the quotient.
*/
override fun divide(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
zip(a, b) { aValue, bValue -> divide(aValue, bValue) }
//TODO move to extensions after KEEP-176
//TODO move to extensions after https://github.com/Kotlin/KEEP/blob/master/proposals/context-receivers.md
/**
* Divides an ND structure by an element of it.
*
@ -247,42 +234,9 @@ public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T,
*/
public operator fun T.div(arg: StructureND<T>): StructureND<T> = arg.map { divide(it, this@div) }
/**
* Element-wise scaling.
*
* @param a the multiplicand.
* @param value the multiplier.
* @return the product.
*/
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { scale(it, value) }
// @ThreadLocal
// public companion object {
// private val realNDFieldCache: MutableMap<IntArray, RealNDField> = hashMapOf()
//
// /**
// * Create a nd-field for [Double] values or pull it from cache if it was created previously.
// */
// public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
//
// /**
// * Create an ND field with boxing generic buffer.
// */
// public fun <T : Any, F : Field<T>> boxing(
// field: F,
// vararg shape: Int,
// bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
// ): BufferedNDField<T, F> = BufferedNDField(shape, field, bufferFactory)
//
// /**
// * Create a most suitable implementation for nd-field using reified class.
// */
// @Suppress("UNCHECKED_CAST")
// public inline fun <reified T : Any, F : Field<T>> auto(field: F, vararg shape: Int): NDField<T, F> =
// when {
// T::class == Double::class -> real(*shape) as NDField<T, F>
// T::class == Complex::class -> complex(*shape) as BufferedNDField<T, F>
// else -> BoxingNDField(shape, field, Buffer.Companion::auto)
// }
// }
}
public interface FieldND<T, out A : Field<T>> : Field<StructureND<T>>, FieldOpsND<T, A>, RingND<T, A>, WithShape {
override val one: StructureND<T> get() = produce(shape) { elementAlgebra.one }
}

View File

@ -3,145 +3,173 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
@file:OptIn(UnstableKMathAPI::class)
package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.jvm.JvmName
public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
public val strides: Strides
public val bufferFactory: BufferFactory<T>
public val indexerBuilder: (IntArray) -> ShapeIndex
public val bufferAlgebra: BufferAlgebra<T, A>
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
override fun produce(initializer: A.(IntArray) -> T): BufferND<T> = BufferND(
strides,
bufferFactory(strides.linearSize) { offset ->
elementContext.initializer(strides.index(offset))
override fun produce(shape: Shape, initializer: A.(IntArray) -> T): BufferND<T> {
val indexer = indexerBuilder(shape)
return BufferND(
indexer,
bufferAlgebra.buffer(indexer.linearSize) { offset ->
elementAlgebra.initializer(indexer.index(offset))
}
)
}
public fun StructureND<T>.toBufferND(): BufferND<T> = when (this) {
is BufferND -> this
else -> {
val indexer = indexerBuilder(shape)
BufferND(indexer, bufferAlgebra.buffer(indexer.linearSize) { offset -> get(indexer.index(offset)) })
}
}
override fun StructureND<T>.map(transform: A.(T) -> T): BufferND<T> = mapInline(toBufferND(), transform)
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): BufferND<T> =
mapIndexedInline(toBufferND(), transform)
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): BufferND<T> =
zipInline(left.toBufferND(), right.toBufferND(), transform)
public companion object {
public val defaultIndexerBuilder: (IntArray) -> ShapeIndex = DefaultStrides.Companion::invoke
}
}
public inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mapInline(
arg: BufferND<T>,
crossinline transform: A.(T) -> T
): BufferND<T> {
val indexes = arg.indexes
return BufferND(indexes, bufferAlgebra.mapInline(arg.buffer, transform))
}
internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mapIndexedInline(
arg: BufferND<T>,
crossinline transform: A.(index: IntArray, arg: T) -> T
): BufferND<T> {
val indexes = arg.indexes
return BufferND(
indexes,
bufferAlgebra.mapIndexedInline(arg.buffer) { offset, value ->
transform(indexes.index(offset), value)
}
)
public val StructureND<T>.buffer: Buffer<T>
get() = when {
!shape.contentEquals(this@BufferAlgebraND.shape) -> throw ShapeMismatchException(
this@BufferAlgebraND.shape,
shape
)
this is BufferND && this.strides == this@BufferAlgebraND.strides -> this.buffer
else -> bufferFactory(strides.linearSize) { offset -> get(strides.index(offset)) }
}
override fun StructureND<T>.map(transform: A.(T) -> T): BufferND<T> {
val buffer = bufferFactory(strides.linearSize) { offset ->
elementContext.transform(buffer[offset])
}
return BufferND(strides, buffer)
}
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): BufferND<T> {
val buffer = bufferFactory(strides.linearSize) { offset ->
elementContext.transform(
strides.index(offset),
buffer[offset]
)
}
return BufferND(strides, buffer)
}
override fun combine(a: StructureND<T>, b: StructureND<T>, transform: A.(T, T) -> T): BufferND<T> {
val buffer = bufferFactory(strides.linearSize) { offset ->
elementContext.transform(a.buffer[offset], b.buffer[offset])
}
return BufferND(strides, buffer)
}
}
public open class BufferedGroupND<T, out A : Group<T>>(
final override val shape: IntArray,
final override val elementContext: A,
final override val bufferFactory: BufferFactory<T>,
) : GroupND<T, A>, BufferAlgebraND<T, A> {
override val strides: Strides = DefaultStrides(shape)
override val zero: BufferND<T> by lazy { produce { zero } }
override fun StructureND<T>.unaryMinus(): StructureND<T> = produce { -get(it) }
internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.zipInline(
l: BufferND<T>,
r: BufferND<T>,
crossinline block: A.(l: T, r: T) -> T
): BufferND<T> {
require(l.indexes == r.indexes)
val indexes = l.indexes
return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block))
}
public open class BufferedRingND<T, out R : Ring<T>>(
shape: IntArray,
elementContext: R,
bufferFactory: BufferFactory<T>,
) : BufferedGroupND<T, R>(shape, elementContext, bufferFactory), RingND<T, R> {
override val one: BufferND<T> by lazy { produce { one } }
public open class BufferedGroupNDOps<T, out A : Group<T>>(
override val bufferAlgebra: BufferAlgebra<T, A>,
override val indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
}
public open class BufferedFieldND<T, out R : Field<T>>(
shape: IntArray,
elementContext: R,
bufferFactory: BufferFactory<T>,
) : BufferedRingND<T, R>(shape, elementContext, bufferFactory), FieldND<T, R> {
public open class BufferedRingOpsND<T, out A : Ring<T>>(
bufferAlgebra: BufferAlgebra<T, A>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
) : BufferedGroupNDOps<T, A>(bufferAlgebra, indexerBuilder), RingOpsND<T, A>
public open class BufferedFieldOpsND<T, out A : Field<T>>(
bufferAlgebra: BufferAlgebra<T, A>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
) : BufferedRingOpsND<T, A>(bufferAlgebra, indexerBuilder), FieldOpsND<T, A> {
public constructor(
elementAlgebra: A,
bufferFactory: BufferFactory<T>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder)
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value }
}
// group factories
public fun <T, A : Group<T>> A.ndAlgebra(
bufferFactory: BufferFactory<T>,
vararg shape: Int,
): BufferedGroupND<T, A> = BufferedGroupND(shape, this, bufferFactory)
public val <T, A : Group<T>> BufferAlgebra<T, A>.nd: BufferedGroupNDOps<T, A> get() = BufferedGroupNDOps(this)
public val <T, A : Ring<T>> BufferAlgebra<T, A>.nd: BufferedRingOpsND<T, A> get() = BufferedRingOpsND(this)
public val <T, A : Field<T>> BufferAlgebra<T, A>.nd: BufferedFieldOpsND<T, A> get() = BufferedFieldOpsND(this)
@JvmName("withNdGroup")
public inline fun <T, A : Group<T>, R> A.withNdAlgebra(
noinline bufferFactory: BufferFactory<T>,
vararg shape: Int,
action: BufferedGroupND<T, A>.() -> R,
): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return ndAlgebra(bufferFactory, *shape).run(action)
}
//ring factories
public fun <T, A : Ring<T>> A.ndAlgebra(
bufferFactory: BufferFactory<T>,
public fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.produce(
vararg shape: Int,
): BufferedRingND<T, A> = BufferedRingND(shape, this, bufferFactory)
initializer: A.(IntArray) -> T
): BufferND<T> = produce(shape, initializer)
@JvmName("withNdRing")
public inline fun <T, A : Ring<T>, R> A.withNdAlgebra(
noinline bufferFactory: BufferFactory<T>,
vararg shape: Int,
action: BufferedRingND<T, A>.() -> R,
): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return ndAlgebra(bufferFactory, *shape).run(action)
}
//// group factories
//public fun <T, A : Group<T>> A.ndAlgebra(
// bufferAlgebra: BufferAlgebra<T, A>,
// vararg shape: Int,
//): BufferedGroupNDOps<T, A> = BufferedGroupNDOps(bufferAlgebra)
//
//@JvmName("withNdGroup")
//public inline fun <T, A : Group<T>, R> A.withNdAlgebra(
// noinline bufferFactory: BufferFactory<T>,
// vararg shape: Int,
// action: BufferedGroupNDOps<T, A>.() -> R,
//): R {
// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
// return ndAlgebra(bufferFactory, *shape).run(action)
//}
//field factories
public fun <T, A : Field<T>> A.ndAlgebra(
bufferFactory: BufferFactory<T>,
vararg shape: Int,
): BufferedFieldND<T, A> = BufferedFieldND(shape, this, bufferFactory)
/**
* Create a [FieldND] for this [Field] inferring proper buffer factory from the type
*/
@UnstableKMathAPI
@Suppress("UNCHECKED_CAST")
public inline fun <reified T : Any, A : Field<T>> A.autoNdAlgebra(
vararg shape: Int,
): FieldND<T, A> = when (this) {
DoubleField -> DoubleFieldND(shape) as FieldND<T, A>
else -> BufferedFieldND(shape, this, Buffer.Companion::auto)
}
@JvmName("withNdField")
public inline fun <T, A : Field<T>, R> A.withNdAlgebra(
noinline bufferFactory: BufferFactory<T>,
vararg shape: Int,
action: BufferedFieldND<T, A>.() -> R,
): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return ndAlgebra(bufferFactory, *shape).run(action)
}
////ring factories
//public fun <T, A : Ring<T>> A.ndAlgebra(
// bufferFactory: BufferFactory<T>,
// vararg shape: Int,
//): BufferedRingNDOps<T, A> = BufferedRingNDOps(shape, this, bufferFactory)
//
//@JvmName("withNdRing")
//public inline fun <T, A : Ring<T>, R> A.withNdAlgebra(
// noinline bufferFactory: BufferFactory<T>,
// vararg shape: Int,
// action: BufferedRingNDOps<T, A>.() -> R,
//): R {
// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
// return ndAlgebra(bufferFactory, *shape).run(action)
//}
//
////field factories
//public fun <T, A : Field<T>> A.ndAlgebra(
// bufferFactory: BufferFactory<T>,
// vararg shape: Int,
//): BufferedFieldNDOps<T, A> = BufferedFieldNDOps(shape, this, bufferFactory)
//
///**
// * Create a [FieldND] for this [Field] inferring proper buffer factory from the type
// */
//@UnstableKMathAPI
//@Suppress("UNCHECKED_CAST")
//public inline fun <reified T : Any, A : Field<T>> A.autoNdAlgebra(
// vararg shape: Int,
//): FieldND<T, A> = when (this) {
// DoubleField -> DoubleFieldND(shape) as FieldND<T, A>
// else -> BufferedFieldNDOps(shape, this, Buffer.Companion::auto)
//}
//
//@JvmName("withNdField")
//public inline fun <T, A : Field<T>, R> A.withNdAlgebra(
// noinline bufferFactory: BufferFactory<T>,
// vararg shape: Int,
// action: BufferedFieldNDOps<T, A>.() -> R,
//): R {
// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
// return ndAlgebra(bufferFactory, *shape).run(action)
//}

View File

@ -15,26 +15,20 @@ import space.kscience.kmath.structures.MutableBufferFactory
* Represents [StructureND] over [Buffer].
*
* @param T the type of items.
* @param strides The strides to access elements of [Buffer] by linear indices.
* @param indexes The strides to access elements of [Buffer] by linear indices.
* @param buffer The underlying buffer.
*/
public open class BufferND<out T>(
public val strides: Strides,
public val indexes: ShapeIndex,
public val buffer: Buffer<T>,
) : StructureND<T> {
init {
if (strides.linearSize != buffer.size) {
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
}
}
override operator fun get(index: IntArray): T = buffer[indexes.offset(index)]
override operator fun get(index: IntArray): T = buffer[strides.offset(index)]
override val shape: IntArray get() = strides.shape
override val shape: IntArray get() = indexes.shape
@PerformancePitfall
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map {
override fun elements(): Sequence<Pair<IntArray, T>> = indexes.indices().map {
it to this[it]
}
@ -49,7 +43,7 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
crossinline transform: (T) -> R,
): BufferND<R> {
return if (this is BufferND<T>)
BufferND(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
BufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) })
else {
val strides = DefaultStrides(shape)
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
@ -64,11 +58,11 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
* @param mutableBuffer The underlying buffer.
*/
public class MutableBufferND<T>(
strides: Strides,
strides: ShapeIndex,
public val mutableBuffer: MutableBuffer<T>,
) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) {
override fun set(index: IntArray, value: T) {
mutableBuffer[strides.offset(index)] = value
mutableBuffer[indexes.offset(index)] = value
}
}
@ -80,7 +74,7 @@ public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
crossinline transform: (T) -> R,
): MutableBufferND<R> {
return if (this is MutableBufferND<T>)
MutableBufferND(this.strides, factory.invoke(strides.linearSize) { transform(mutableBuffer[it]) })
MutableBufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(mutableBuffer[it]) })
else {
val strides = DefaultStrides(shape)
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })

View File

@ -6,108 +6,68 @@
package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations
import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.DoubleBuffer
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
@OptIn(UnstableKMathAPI::class)
public class DoubleFieldND(
shape: IntArray,
) : BufferedFieldND<Double, DoubleField>(shape, DoubleField, ::DoubleBuffer),
NumbersAddOperations<StructureND<Double>>,
ScaleOperations<StructureND<Double>>,
ExtendedField<StructureND<Double>> {
public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(DoubleField.bufferAlgebra),
ScaleOperations<StructureND<Double>>, ExtendedFieldOps<StructureND<Double>> {
override val zero: BufferND<Double> by lazy { produce { zero } }
override val one: BufferND<Double> by lazy { produce { one } }
override fun StructureND<Double>.toBufferND(): BufferND<Double> = when (this) {
is BufferND -> this
else -> {
val indexer = indexerBuilder(shape)
BufferND(indexer, DoubleBuffer(indexer.linearSize) { offset -> get(indexer.index(offset)) })
}
}
//TODO do specialization
override fun scale(a: StructureND<Double>, value: Double): BufferND<Double> =
mapInline(a.toBufferND()) { it * value }
override fun power(arg: StructureND<Double>, pow: Number): BufferND<Double> =
mapInline(arg.toBufferND()) { power(it, pow) }
override fun exp(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { exp(it) }
override fun ln(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { ln(it) }
override fun sin(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { sin(it) }
override fun cos(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { cos(it) }
override fun tan(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { tan(it) }
override fun asin(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { asin(it) }
override fun acos(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { acos(it) }
override fun atan(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { atan(it) }
override fun sinh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { sinh(it) }
override fun cosh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { cosh(it) }
override fun tanh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { tanh(it) }
override fun asinh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { asinh(it) }
override fun acosh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { acosh(it) }
override fun atanh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { atanh(it) }
public companion object : DoubleFieldOpsND()
}
@OptIn(UnstableKMathAPI::class)
public class DoubleFieldND(override val shape: Shape) :
DoubleFieldOpsND(), FieldND<Double, DoubleField>, NumbersAddOps<StructureND<Double>> {
override fun number(value: Number): BufferND<Double> {
val d = value.toDouble() // minimize conversions
return produce { d }
return produce(shape) { d }
}
override val StructureND<Double>.buffer: DoubleBuffer
get() = when {
!shape.contentEquals(this@DoubleFieldND.shape) -> throw ShapeMismatchException(
this@DoubleFieldND.shape,
shape
)
this is BufferND && this.strides == this@DoubleFieldND.strides -> this.buffer as DoubleBuffer
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun StructureND<Double>.map(
transform: DoubleField.(Double) -> Double,
): BufferND<Double> {
val buffer = DoubleBuffer(strides.linearSize) { offset -> DoubleField.transform(buffer.array[offset]) }
return BufferND(strides, buffer)
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
val array = DoubleArray(strides.linearSize) { offset ->
val index = strides.index(offset)
DoubleField.initializer(index)
}
return BufferND(strides, DoubleBuffer(array))
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun StructureND<Double>.mapIndexed(
transform: DoubleField.(index: IntArray, Double) -> Double,
): BufferND<Double> = BufferND(
strides,
buffer = DoubleBuffer(strides.linearSize) { offset ->
DoubleField.transform(
strides.index(offset),
buffer.array[offset]
)
})
@Suppress("OVERRIDE_BY_INLINE")
override inline fun combine(
a: StructureND<Double>,
b: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double,
): BufferND<Double> {
val buffer = DoubleBuffer(strides.linearSize) { offset ->
DoubleField.transform(a.buffer.array[offset], b.buffer.array[offset])
}
return BufferND(strides, buffer)
}
override fun scale(a: StructureND<Double>, value: Double): StructureND<Double> = a.map { it * value }
override fun power(arg: StructureND<Double>, pow: Number): BufferND<Double> = arg.map { power(it, pow) }
override fun exp(arg: StructureND<Double>): BufferND<Double> = arg.map { exp(it) }
override fun ln(arg: StructureND<Double>): BufferND<Double> = arg.map { ln(it) }
override fun sin(arg: StructureND<Double>): BufferND<Double> = arg.map { sin(it) }
override fun cos(arg: StructureND<Double>): BufferND<Double> = arg.map { cos(it) }
override fun tan(arg: StructureND<Double>): BufferND<Double> = arg.map { tan(it) }
override fun asin(arg: StructureND<Double>): BufferND<Double> = arg.map { asin(it) }
override fun acos(arg: StructureND<Double>): BufferND<Double> = arg.map { acos(it) }
override fun atan(arg: StructureND<Double>): BufferND<Double> = arg.map { atan(it) }
override fun sinh(arg: StructureND<Double>): BufferND<Double> = arg.map { sinh(it) }
override fun cosh(arg: StructureND<Double>): BufferND<Double> = arg.map { cosh(it) }
override fun tanh(arg: StructureND<Double>): BufferND<Double> = arg.map { tanh(it) }
override fun asinh(arg: StructureND<Double>): BufferND<Double> = arg.map { asinh(it) }
override fun acosh(arg: StructureND<Double>): BufferND<Double> = arg.map { acosh(it) }
override fun atanh(arg: StructureND<Double>): BufferND<Double> = arg.map { atanh(it) }
}
public val DoubleField.ndAlgebra: DoubleFieldOpsND get() = DoubleFieldOpsND
public fun DoubleField.ndAlgebra(vararg shape: Int): DoubleFieldND = DoubleFieldND(shape)
/**
* Produce a context for n-dimensional operations inside this real field
*/
@UnstableKMathAPI
public inline fun <R> DoubleField.withNdAlgebra(vararg shape: Int, action: DoubleFieldND.() -> R): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return DoubleFieldND(shape).run(action)

View File

@ -0,0 +1,120 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.nd
import kotlin.native.concurrent.ThreadLocal
/**
* A converter from linear index to multivariate index
*/
public interface ShapeIndex{
public val shape: Shape
/**
* Get linear index from multidimensional index
*/
public fun offset(index: IntArray): Int
/**
* Get multidimensional from linear
*/
public fun index(offset: Int): IntArray
/**
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
*/
public val linearSize: Int
// TODO introduce a fast way to calculate index of the next element?
/**
* Iterate over ND indices in a natural order
*/
public fun indices(): Sequence<IntArray>
override fun equals(other: Any?): Boolean
override fun hashCode(): Int
}
/**
* Linear transformation of indexes
*/
public abstract class Strides: ShapeIndex {
/**
* Array strides
*/
public abstract val strides: IntArray
public override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
value * strides[i]
}.sum()
// TODO introduce a fast way to calculate index of the next element?
/**
* Iterate over ND indices in a natural order
*/
public override fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map(::index)
}
/**
* Simple implementation of [Strides].
*/
public class DefaultStrides private constructor(override val shape: IntArray) : Strides() {
override val linearSize: Int get() = strides[shape.size]
/**
* Strides for memory access
*/
override val strides: IntArray by lazy {
sequence {
var current = 1
yield(1)
shape.forEach {
current *= it
yield(current)
}
}.toList().toIntArray()
}
override fun index(offset: Int): IntArray {
val res = IntArray(shape.size)
var current = offset
var strideIndex = strides.size - 2
while (strideIndex >= 0) {
res[strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex]
strideIndex--
}
return res
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is DefaultStrides) return false
if (!shape.contentEquals(other.shape)) return false
return true
}
override fun hashCode(): Int = shape.contentHashCode()
@ThreadLocal
public companion object {
//private val defaultStridesCache = HashMap<IntArray, Strides>()
/**
* Cached builder for default strides
*/
public operator fun invoke(shape: IntArray): Strides = DefaultStrides(shape)
//defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
//TODO fix cache
}
}

View File

@ -6,34 +6,27 @@
package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.NumbersAddOperations
import space.kscience.kmath.operations.NumbersAddOps
import space.kscience.kmath.operations.ShortRing
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.ShortBuffer
import space.kscience.kmath.operations.bufferAlgebra
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
public sealed class ShortRingOpsND : BufferedRingOpsND<Short, ShortRing>(ShortRing.bufferAlgebra) {
public companion object : ShortRingOpsND()
}
@OptIn(UnstableKMathAPI::class)
public class ShortRingND(
shape: IntArray,
) : BufferedRingND<Short, ShortRing>(shape, ShortRing, Buffer.Companion::auto),
NumbersAddOperations<StructureND<Short>> {
override val zero: BufferND<Short> by lazy { produce { zero } }
override val one: BufferND<Short> by lazy { produce { one } }
override val shape: Shape
) : ShortRingOpsND(), RingND<Short, ShortRing>, NumbersAddOps<StructureND<Short>> {
override fun number(value: Number): BufferND<Short> {
val d = value.toShort() // minimize conversions
return produce { d }
return produce(shape) { d }
}
}
/**
* Fast element production using function inlining.
*/
public inline fun BufferedRingND<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): BufferND<Short> =
BufferND(strides, ShortBuffer(ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }))
public inline fun <R> ShortRing.withNdAlgebra(vararg shape: Int, action: ShortRingND.() -> R): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return ShortRingND(shape).run(action)

View File

@ -15,7 +15,6 @@ import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory
import kotlin.jvm.JvmName
import kotlin.math.abs
import kotlin.native.concurrent.ThreadLocal
import kotlin.reflect.KClass
public interface StructureFeature : Feature<StructureFeature>
@ -72,7 +71,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
if (st1 === st2) return true
// fast comparison of buffers if possible
if (st1 is BufferND && st2 is BufferND && st1.strides == st2.strides)
if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes)
return Buffer.contentEquals(st1.buffer, st2.buffer)
//element by element comparison if it could not be avoided
@ -88,7 +87,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
if (st1 === st2) return true
// fast comparison of buffers if possible
if (st1 is BufferND && st2 is BufferND && st1.strides == st2.strides)
if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes)
return Buffer.contentEquals(st1.buffer, st2.buffer)
//element by element comparison if it could not be avoided
@ -187,11 +186,11 @@ public fun <T : Comparable<T>> LinearSpace<T, Ring<T>>.contentEquals(
* Indicates whether some [StructureND] is equal to another one with [absoluteTolerance].
*/
@PerformancePitfall
public fun <T : Comparable<T>> GroupND<T, Ring<T>>.contentEquals(
public fun <T : Comparable<T>> GroupOpsND<T, Ring<T>>.contentEquals(
st1: StructureND<T>,
st2: StructureND<T>,
absoluteTolerance: T,
): Boolean = st1.elements().all { (index, value) -> elementContext { (value - st2[index]) } < absoluteTolerance }
): Boolean = st1.elements().all { (index, value) -> elementAlgebra { (value - st2[index]) } < absoluteTolerance }
/**
* Indicates whether some [StructureND] is equal to another one with [absoluteTolerance].
@ -231,107 +230,10 @@ public interface MutableStructureND<T> : StructureND<T> {
* Transform a structure element-by element in place.
*/
@OptIn(PerformancePitfall::class)
public inline fun <T> MutableStructureND<T>.mapInPlace(action: (IntArray, T) -> T): Unit =
public inline fun <T> MutableStructureND<T>.mapInPlace(action: (index: IntArray, t: T) -> T): Unit =
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
/**
* A way to convert ND indices to linear one and back.
*/
public interface Strides {
/**
* Shape of NDStructure
*/
public val shape: IntArray
/**
* Array strides
*/
public val strides: IntArray
/**
* Get linear index from multidimensional index
*/
public fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
value * strides[i]
}.sum()
/**
* Get multidimensional from linear
*/
public fun index(offset: Int): IntArray
/**
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
*/
public val linearSize: Int
// TODO introduce a fast way to calculate index of the next element?
/**
* Iterate over ND indices in a natural order
*/
public fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map(::index)
}
/**
* Simple implementation of [Strides].
*/
public class DefaultStrides private constructor(override val shape: IntArray) : Strides {
override val linearSize: Int
get() = strides[shape.size]
/**
* Strides for memory access
*/
override val strides: IntArray by lazy {
sequence {
var current = 1
yield(1)
shape.forEach {
current *= it
yield(current)
}
}.toList().toIntArray()
}
override fun index(offset: Int): IntArray {
val res = IntArray(shape.size)
var current = offset
var strideIndex = strides.size - 2
while (strideIndex >= 0) {
res[strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex]
strideIndex--
}
return res
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is DefaultStrides) return false
if (!shape.contentEquals(other.shape)) return false
return true
}
override fun hashCode(): Int = shape.contentHashCode()
@ThreadLocal
public companion object {
private val defaultStridesCache = HashMap<IntArray, Strides>()
/**
* Cached builder for default strides
*/
public operator fun invoke(shape: IntArray): Strides =
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
}
}
public inline fun <reified T : Any> StructureND<T>.combine(
public inline fun <reified T : Any> StructureND<T>.zip(
struct: StructureND<T>,
crossinline block: (T, T) -> T,
): StructureND<T> {

View File

@ -0,0 +1,34 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.nd
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.Group
import space.kscience.kmath.operations.Ring
import kotlin.jvm.JvmName
public fun <T, A : Algebra<T>> AlgebraND<T, A>.produce(
shapeFirst: Int,
vararg shapeRest: Int,
initializer: A.(IntArray) -> T
): StructureND<T> = produce(Shape(shapeFirst, *shapeRest), initializer)
public fun <T, A : Group<T>> AlgebraND<T, A>.zero(shape: Shape): StructureND<T> = produce(shape) { zero }
@JvmName("zeroVarArg")
public fun <T, A : Group<T>> AlgebraND<T, A>.zero(
shapeFirst: Int,
vararg shapeRest: Int,
): StructureND<T> = produce(shapeFirst, *shapeRest) { zero }
public fun <T, A : Ring<T>> AlgebraND<T, A>.one(shape: Shape): StructureND<T> = produce(shape) { one }
@JvmName("oneVarArg")
public fun <T, A : Ring<T>> AlgebraND<T, A>.one(
shapeFirst: Int,
vararg shapeRest: Int,
): StructureND<T> = produce(shapeFirst, *shapeRest) { one }

View File

@ -117,7 +117,7 @@ public inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = r
*
* @param T the type of element of this semispace.
*/
public interface GroupOperations<T> : Algebra<T> {
public interface GroupOps<T> : Algebra<T> {
/**
* Addition of two elements.
*
@ -162,7 +162,7 @@ public interface GroupOperations<T> : Algebra<T> {
* @return the difference.
*/
public operator fun T.minus(b: T): T = add(this, -b)
// Dynamic dispatch of operations
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
PLUS_OPERATION -> { arg -> +arg }
MINUS_OPERATION -> { arg -> -arg }
@ -193,7 +193,7 @@ public interface GroupOperations<T> : Algebra<T> {
*
* @param T the type of element of this semispace.
*/
public interface Group<T> : GroupOperations<T> {
public interface Group<T> : GroupOps<T> {
/**
* The neutral element of addition.
*/
@ -206,7 +206,7 @@ public interface Group<T> : GroupOperations<T> {
*
* @param T the type of element of this semiring.
*/
public interface RingOperations<T> : GroupOperations<T> {
public interface RingOps<T> : GroupOps<T> {
/**
* Multiplies two elements.
*
@ -242,7 +242,7 @@ public interface RingOperations<T> : GroupOperations<T> {
*
* @param T the type of element of this ring.
*/
public interface Ring<T> : Group<T>, RingOperations<T> {
public interface Ring<T> : Group<T>, RingOps<T> {
/**
* The neutral element of multiplication
*/
@ -256,7 +256,7 @@ public interface Ring<T> : Group<T>, RingOperations<T> {
*
* @param T the type of element of this semifield.
*/
public interface FieldOperations<T> : RingOperations<T> {
public interface FieldOps<T> : RingOps<T> {
/**
* Division of two elements.
*
@ -295,6 +295,6 @@ public interface FieldOperations<T> : RingOperations<T> {
*
* @param T the type of element of this field.
*/
public interface Field<T> : Ring<T>, FieldOperations<T>, ScaleOperations<T>, NumericAlgebra<T> {
public interface Field<T> : Ring<T>, FieldOps<T>, ScaleOperations<T>, NumericAlgebra<T> {
override fun number(value: Number): T = scale(one, value.toDouble())
}

View File

@ -6,7 +6,7 @@
package space.kscience.kmath.operations
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.BufferedRingND
import space.kscience.kmath.nd.BufferedRingOpsND
import space.kscience.kmath.operations.BigInt.Companion.BASE
import space.kscience.kmath.operations.BigInt.Companion.BASE_SIZE
import space.kscience.kmath.structures.Buffer
@ -26,7 +26,7 @@ private typealias TBase = ULong
* @author Peter Klimai
*/
@OptIn(UnstableKMathAPI::class)
public object BigIntField : Field<BigInt>, NumbersAddOperations<BigInt>, ScaleOperations<BigInt> {
public object BigIntField : Field<BigInt>, NumbersAddOps<BigInt>, ScaleOperations<BigInt> {
override val zero: BigInt = BigInt.ZERO
override val one: BigInt = BigInt.ONE
@ -542,5 +542,5 @@ public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -
public inline fun BigInt.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
Buffer.boxing(size, initializer)
public fun BigIntField.nd(vararg shape: Int): BufferedRingND<BigInt, BigIntField> =
BufferedRingND(shape, BigIntField, BigInt::buffer)
public val BigIntField.nd: BufferedRingOpsND<BigInt, BigIntField>
get() = BufferedRingOpsND(BufferRingOps(BigIntField, BigInt::buffer))

View File

@ -5,32 +5,34 @@
package space.kscience.kmath.operations
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.ShortBuffer
public interface WithSize {
public val size: Int
}
/**
* An algebra over [Buffer]
*/
@UnstableKMathAPI
public interface BufferAlgebra<T, A : Algebra<T>> : Algebra<Buffer<T>> {
public val bufferFactory: BufferFactory<T>
public interface BufferAlgebra<T, out A : Algebra<T>> : Algebra<Buffer<T>> {
public val elementAlgebra: A
public val size: Int
public val bufferFactory: BufferFactory<T>
public fun buffer(vararg elements: T): Buffer<T> {
public fun buffer(size: Int, vararg elements: T): Buffer<T> {
require(elements.size == size) { "Expected $size elements but found ${elements.size}" }
return bufferFactory(size) { elements[it] }
}
//TODO move to multi-receiver inline extension
public fun Buffer<T>.map(block: (T) -> T): Buffer<T> = bufferFactory(size) { block(get(it)) }
public fun Buffer<T>.map(block: A.(T) -> T): Buffer<T> = mapInline(this, block)
public fun Buffer<T>.zip(other: Buffer<T>, block: (left: T, right: T) -> T): Buffer<T> {
require(size == other.size) { "Incompatible buffer sizes. left: $size, right: ${other.size}" }
return bufferFactory(size) { block(this[it], other[it]) }
}
public fun Buffer<T>.mapIndexed(block: A.(index: Int, arg: T) -> T): Buffer<T> = mapIndexedInline(this, block)
public fun Buffer<T>.zip(other: Buffer<T>, block: A.(left: T, right: T) -> T): Buffer<T> =
zipInline(this, other, block)
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> {
val operationFunction = elementAlgebra.unaryOperationFunction(operation)
@ -45,112 +47,149 @@ public interface BufferAlgebra<T, A : Algebra<T>> : Algebra<Buffer<T>> {
}
}
@UnstableKMathAPI
public fun <T> BufferField<T, *>.buffer(initializer: (Int) -> T): Buffer<T> {
/**
* Inline map
*/
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapInline(
buffer: Buffer<T>,
crossinline block: A.(T) -> T
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(buffer[it]) }
/**
* Inline map
*/
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapIndexedInline(
buffer: Buffer<T>,
crossinline block: A.(index: Int, arg: T) -> T
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(it, buffer[it]) }
/**
* Inline zip
*/
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.zipInline(
l: Buffer<T>,
r: Buffer<T>,
crossinline block: A.(l: T, r: T) -> T
): Buffer<T> {
require(l.size == r.size) { "Incompatible buffer sizes. left: ${l.size}, right: ${r.size}" }
return bufferFactory(l.size) { elementAlgebra.block(l[it], r[it]) }
}
public fun <T> BufferAlgebra<T, *>.buffer(size: Int, initializer: (Int) -> T): Buffer<T> {
return bufferFactory(size, initializer)
}
public fun <T, A> A.buffer(initializer: (Int) -> T): Buffer<T> where A : BufferAlgebra<T, *>, A : WithSize {
return bufferFactory(size, initializer)
}
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.sin(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::sin)
mapInline(arg) { sin(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.cos(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::cos)
mapInline(arg) { cos(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.tan(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::tan)
mapInline(arg) { tan(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.asin(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::asin)
mapInline(arg) { asin(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.acos(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::acos)
mapInline(arg) { acos(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.atan(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::atan)
mapInline(arg) { atan(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.exp(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::exp)
mapInline(arg) { exp(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.ln(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::ln)
mapInline(arg) { ln(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.sinh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::sinh)
mapInline(arg) { sinh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.cosh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::cosh)
mapInline(arg) { cosh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.tanh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::tanh)
mapInline(arg) { tanh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.asinh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::asinh)
mapInline(arg) { asinh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.acosh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::acosh)
mapInline(arg) { acosh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.atanh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::atanh)
mapInline(arg) { atanh(it) }
@UnstableKMathAPI
public fun <T, A : PowerOperations<T>> BufferAlgebra<T, A>.pow(arg: Buffer<T>, pow: Number): Buffer<T> =
with(elementAlgebra) { arg.map { power(it, pow) } }
mapInline(arg) { power(it, pow) }
@UnstableKMathAPI
public class BufferField<T, A : Field<T>>(
override val bufferFactory: BufferFactory<T>,
public open class BufferRingOps<T, A: Ring<T>>(
override val elementAlgebra: A,
override val bufferFactory: BufferFactory<T>,
) : BufferAlgebra<T, A>, RingOps<Buffer<T>>{
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l + r }
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l * r }
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> =
super<BufferAlgebra>.unaryOperationFunction(operation)
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> =
super<BufferAlgebra>.binaryOperationFunction(operation)
}
public val ShortRing.bufferAlgebra: BufferRingOps<Short, ShortRing>
get() = BufferRingOps(ShortRing, ::ShortBuffer)
public open class BufferFieldOps<T, A : Field<T>>(
elementAlgebra: A,
bufferFactory: BufferFactory<T>,
) : BufferRingOps<T, A>(elementAlgebra, bufferFactory), BufferAlgebra<T, A>, FieldOps<Buffer<T>>, ScaleOperations<Buffer<T>> {
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l + r }
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l * r }
override fun divide(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l / r }
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = a.map { scale(it, value) }
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> =
super<BufferRingOps>.binaryOperationFunction(operation)
}
public class BufferField<T, A : Field<T>>(
elementAlgebra: A,
bufferFactory: BufferFactory<T>,
override val size: Int
) : BufferAlgebra<T, A>, Field<Buffer<T>> {
) : BufferFieldOps<T, A>(elementAlgebra, bufferFactory), Field<Buffer<T>>, WithSize {
override val zero: Buffer<T> = bufferFactory(size) { elementAlgebra.zero }
override val one: Buffer<T> = bufferFactory(size) { elementAlgebra.one }
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::add)
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::multiply)
override fun divide(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::divide)
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = with(elementAlgebra) { a.map { scale(it, value) } }
override fun Buffer<T>.unaryMinus(): Buffer<T> = with(elementAlgebra) { map { -it } }
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> {
return super<BufferAlgebra>.unaryOperationFunction(operation)
}
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> {
return super<BufferAlgebra>.binaryOperationFunction(operation)
}
}
/**
* Generate full buffer field from given buffer operations
*/
public fun <T, A : Field<T>> BufferFieldOps<T, A>.withSize(size: Int): BufferField<T, A> =
BufferField(elementAlgebra, bufferFactory, size)
//Double buffer specialization
@UnstableKMathAPI
public fun BufferField<Double, *>.buffer(vararg elements: Number): Buffer<Double> {
require(elements.size == size) { "Expected $size elements but found ${elements.size}" }
return bufferFactory(size) { elements[it].toDouble() }
}
@UnstableKMathAPI
public fun <T, A : Field<T>> A.bufferAlgebra(bufferFactory: BufferFactory<T>, size: Int): BufferField<T, A> =
BufferField(bufferFactory, this, size)
public fun <T, A : Field<T>> A.bufferAlgebra(bufferFactory: BufferFactory<T>): BufferFieldOps<T, A> =
BufferFieldOps(this, bufferFactory)
@UnstableKMathAPI
public fun DoubleField.bufferAlgebra(size: Int): BufferField<Double, DoubleField> =
BufferField(::DoubleBuffer, DoubleField, size)
public val DoubleField.bufferAlgebra: BufferFieldOps<Double, DoubleField>
get() = BufferFieldOps(DoubleField, ::DoubleBuffer)

View File

@ -13,21 +13,21 @@ import space.kscience.kmath.structures.DoubleBuffer
*
* @property size the size of buffers to operate on.
*/
public class DoubleBufferField(public val size: Int) : ExtendedField<Buffer<Double>>, DoubleBufferOperations() {
public class DoubleBufferField(public val size: Int) : ExtendedField<Buffer<Double>>, DoubleBufferOps() {
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }
override val one: Buffer<Double> by lazy { DoubleBuffer(size) { 1.0 } }
override fun sinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.sinh(arg)
override fun sinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.sinh(arg)
override fun cosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.cosh(arg)
override fun cosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.cosh(arg)
override fun tanh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.tanh(arg)
override fun tanh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.tanh(arg)
override fun asinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.asinh(arg)
override fun asinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.asinh(arg)
override fun acosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.acosh(arg)
override fun acosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.acosh(arg)
override fun atanh(arg: Buffer<Double>): DoubleBuffer= super<DoubleBufferOperations>.atanh(arg)
override fun atanh(arg: Buffer<Double>): DoubleBuffer= super<DoubleBufferOps>.atanh(arg)
// override fun number(value: Number): Buffer<Double> = DoubleBuffer(size) { value.toDouble() }
//

View File

@ -12,9 +12,9 @@ import space.kscience.kmath.structures.DoubleBuffer
import kotlin.math.*
/**
* [ExtendedFieldOperations] over [DoubleBuffer].
* [ExtendedFieldOps] over [DoubleBuffer].
*/
public abstract class DoubleBufferOperations : ExtendedFieldOperations<Buffer<Double>>, Norm<Buffer<Double>, Double> {
public abstract class DoubleBufferOps : ExtendedFieldOps<Buffer<Double>>, Norm<Buffer<Double>, Double> {
override fun Buffer<Double>.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) {
DoubleBuffer(size) { -array[it] }
} else {
@ -185,7 +185,7 @@ public abstract class DoubleBufferOperations : ExtendedFieldOperations<Buffer<Do
DoubleBuffer(DoubleArray(a.size) { aArray[it] * value })
} else DoubleBuffer(DoubleArray(a.size) { a[it] * value })
public companion object : DoubleBufferOperations()
public companion object : DoubleBufferOps()
}
public object DoubleL2Norm : Norm<Point<Double>, Double> {

View File

@ -150,7 +150,7 @@ public interface ScaleOperations<T> : Algebra<T> {
* TODO to be removed and replaced by extensions after multiple receivers are there
*/
@UnstableKMathAPI
public interface NumbersAddOperations<T> : Ring<T>, NumericAlgebra<T> {
public interface NumbersAddOps<T> : Ring<T>, NumericAlgebra<T> {
/**
* Addition of element and scalar.
*

View File

@ -10,8 +10,8 @@ import kotlin.math.pow as kpow
/**
* Advanced Number-like semifield that implements basic operations.
*/
public interface ExtendedFieldOperations<T> :
FieldOperations<T>,
public interface ExtendedFieldOps<T> :
FieldOps<T>,
TrigonometricOperations<T>,
PowerOperations<T>,
ExponentialOperations<T>,
@ -35,14 +35,14 @@ public interface ExtendedFieldOperations<T> :
ExponentialOperations.ACOSH_OPERATION -> ::acosh
ExponentialOperations.ASINH_OPERATION -> ::asinh
ExponentialOperations.ATANH_OPERATION -> ::atanh
else -> super<FieldOperations>.unaryOperationFunction(operation)
else -> super<FieldOps>.unaryOperationFunction(operation)
}
}
/**
* Advanced Number-like field that implements basic operations.
*/
public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>, NumericAlgebra<T>{
public interface ExtendedField<T> : ExtendedFieldOps<T>, Field<T>, NumericAlgebra<T>{
override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0
override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0
override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))

View File

@ -7,6 +7,7 @@ package space.kscience.kmath.structures
import space.kscience.kmath.nd.get
import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.produce
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.testutils.FieldVerifier
@ -21,7 +22,7 @@ internal class NDFieldTest {
@Test
fun testStrides() {
val ndArray = DoubleField.ndAlgebra(10, 10).produce { (it[0] + it[1]).toDouble() }
val ndArray = DoubleField.ndAlgebra.produce(10, 10) { (it[0] + it[1]).toDouble() }
assertEquals(ndArray[5, 5], 10.0)
}
}

View File

@ -7,10 +7,7 @@ package space.kscience.kmath.structures
import space.kscience.kmath.linear.linearSpace
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.combine
import space.kscience.kmath.nd.get
import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Norm
import space.kscience.kmath.operations.algebra
@ -22,9 +19,9 @@ import kotlin.test.assertEquals
@Suppress("UNUSED_VARIABLE")
class NumberNDFieldTest {
val algebra = DoubleField.ndAlgebra(3, 3)
val array1 = algebra.produce { (i, j) -> (i + j).toDouble() }
val array2 = algebra.produce { (i, j) -> (i - j).toDouble() }
val algebra = DoubleField.ndAlgebra
val array1 = algebra.produce(3, 3) { (i, j) -> (i + j).toDouble() }
val array2 = algebra.produce(3, 3) { (i, j) -> (i - j).toDouble() }
@Test
fun testSum() {
@ -77,7 +74,7 @@ class NumberNDFieldTest {
@Test
fun combineTest() {
val division = array1.combine(array2, Double::div)
val division = array1.zip(array2, Double::div)
}
object L2Norm : Norm<StructureND<Number>, Double> {

View File

@ -10,12 +10,12 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.runningReduce
import kotlinx.coroutines.flow.scan
import space.kscience.kmath.operations.GroupOperations
import space.kscience.kmath.operations.GroupOps
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.operations.invoke
public fun <T> Flow<T>.cumulativeSum(group: GroupOperations<T>): Flow<T> =
public fun <T> Flow<T>.cumulativeSum(group: GroupOps<T>): Flow<T> =
group { runningReduce { sum, element -> sum + element } }
@ExperimentalCoroutinesApi

View File

@ -13,8 +13,8 @@ import space.kscience.kmath.structures.DoubleBuffer
* Map one [BufferND] using function without indices.
*/
public inline fun BufferND<Double>.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND<Double> {
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.transform(buffer[offset]) }
return BufferND(strides, DoubleBuffer(array))
val array = DoubleArray(indexes.linearSize) { offset -> DoubleField.transform(buffer[offset]) }
return BufferND(indexes, DoubleBuffer(array))
}
/**

View File

@ -28,10 +28,9 @@ public class DoubleHistogramSpace(
public val dimension: Int get() = lower.size
private val shape = IntArray(binNums.size) { binNums[it] + 2 }
override val shape: IntArray = IntArray(binNums.size) { binNums[it] + 2 }
override val histogramValueSpace: DoubleFieldND = DoubleField.ndAlgebra(*shape)
override val strides: Strides get() = histogramValueSpace.strides
private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
/**
@ -52,7 +51,7 @@ public class DoubleHistogramSpace(
val lowerBoundary = index.mapIndexed { axis, i ->
when (i) {
0 -> Double.NEGATIVE_INFINITY
strides.shape[axis] - 1 -> upper[axis]
shape[axis] - 1 -> upper[axis]
else -> lower[axis] + (i.toDouble()) * binSize[axis]
}
}.asBuffer()
@ -60,7 +59,7 @@ public class DoubleHistogramSpace(
val upperBoundary = index.mapIndexed { axis, i ->
when (i) {
0 -> lower[axis]
strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY
shape[axis] - 1 -> Double.POSITIVE_INFINITY
else -> lower[axis] + (i.toDouble() + 1) * binSize[axis]
}
}.asBuffer()
@ -75,7 +74,7 @@ public class DoubleHistogramSpace(
}
override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> {
val ndCounter = StructureND.auto(strides) { Counter.real() }
val ndCounter = StructureND.auto(shape) { Counter.real() }
val hBuilder = HistogramBuilder<Double> { point, value ->
val index = getIndex(point)
ndCounter[index].add(value.toDouble())

View File

@ -8,8 +8,9 @@ package space.kscience.kmath.histogram
import space.kscience.kmath.domains.Domain
import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.nd.FieldND
import space.kscience.kmath.nd.Strides
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Group
import space.kscience.kmath.operations.ScaleOperations
@ -34,10 +35,10 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
return context.produceBin(index, values[index])
}
override val dimension: Int get() = context.strides.shape.size
override val dimension: Int get() = context.shape.size
override val bins: Iterable<Bin<T>>
get() = context.strides.indices().map {
get() = DefaultStrides(context.shape).indices().map {
context.produceBin(it, values[it])
}.asIterable()
@ -49,7 +50,7 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
public interface IndexedHistogramSpace<T : Comparable<T>, V : Any>
: Group<IndexedHistogram<T, V>>, ScaleOperations<IndexedHistogram<T, V>> {
//public val valueSpace: Space<V>
public val strides: Strides
public val shape: Shape
public val histogramValueSpace: FieldND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape),
/**

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.histogram
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.real.DoubleVector
import kotlin.random.Random
@ -69,7 +70,7 @@ internal class MultivariateHistogramTest {
}
val res = histogram1 - histogram2
assertTrue {
strides.indices().all { index ->
DefaultStrides(shape).indices().all { index ->
res.values[index] <= histogram1.values[index]
}
}

View File

@ -106,8 +106,8 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
is Symbol -> toSVar()
is MST.Unary -> when (operation) {
GroupOperations.PLUS_OPERATION -> +value.toSFun<X>()
GroupOperations.MINUS_OPERATION -> -value.toSFun<X>()
GroupOps.PLUS_OPERATION -> +value.toSFun<X>()
GroupOps.MINUS_OPERATION -> -value.toSFun<X>()
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
@ -124,10 +124,10 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
}
is MST.Binary -> when (operation) {
GroupOperations.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
GroupOperations.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
RingOperations.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
FieldOperations.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
GroupOps.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
GroupOps.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
RingOps.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
FieldOps.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
else -> error("Binary operation $operation not defined in $this")
}

View File

@ -15,13 +15,6 @@ import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.*
internal fun AlgebraND<*, *>.checkShape(array: INDArray): INDArray {
val arrayShape = array.shape().toIntArray()
if (!shape.contentEquals(arrayShape)) throw ShapeMismatchException(shape, arrayShape)
return array
}
/**
* Represents [AlgebraND] over [Nd4jArrayAlgebra].
*
@ -39,33 +32,34 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
*/
public val StructureND<T>.ndArray: INDArray
override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
override fun produce(shape: Shape, initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
val struct = Nd4j.create(*shape)!!.wrap()
struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) }
struct.indicesIterator().forEach { struct[it] = elementAlgebra.initializer(it) }
return struct
}
override fun StructureND<T>.map(transform: C.(T) -> T): Nd4jArrayStructure<T> {
val newStruct = ndArray.dup().wrap()
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementAlgebra.transform(value) }
return newStruct
}
override fun StructureND<T>.mapIndexed(
transform: C.(index: IntArray, T) -> T,
): Nd4jArrayStructure<T> {
val new = Nd4j.create(*this@Nd4jArrayAlgebra.shape).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, this[idx]) }
val new = Nd4j.create(*shape).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(idx, this[idx]) }
return new
}
override fun combine(
a: StructureND<T>,
b: StructureND<T>,
override fun zip(
left: StructureND<T>,
right: StructureND<T>,
transform: C.(T, T) -> T,
): Nd4jArrayStructure<T> {
val new = Nd4j.create(*shape).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) }
require(left.shape.contentEquals(right.shape)) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
val new = Nd4j.create(*left.shape).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) }
return new
}
}
@ -76,10 +70,7 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
* @param T the type of the element contained in ND structure.
* @param S the type of space of structure elements.
*/
public sealed interface Nd4jArrayGroup<T, out S : Ring<T>> : GroupND<T, S>, Nd4jArrayAlgebra<T, S> {
override val zero: Nd4jArrayStructure<T>
get() = Nd4j.zeros(*shape).wrap()
public sealed interface Nd4jArrayGroupOps<T, out S : Ring<T>> : GroupOpsND<T, S>, Nd4jArrayAlgebra<T, S> {
override fun add(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
a.ndArray.add(b.ndArray).wrap()
@ -101,10 +92,7 @@ public sealed interface Nd4jArrayGroup<T, out S : Ring<T>> : GroupND<T, S>, Nd4j
* @param R the type of ring of structure elements.
*/
@OptIn(UnstableKMathAPI::class)
public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jArrayGroup<T, R> {
override val one: Nd4jArrayStructure<T>
get() = Nd4j.ones(*shape).wrap()
public sealed interface Nd4jArrayRingOps<T, out R : Ring<T>> : RingOpsND<T, R>, Nd4jArrayGroupOps<T, R> {
override fun multiply(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
a.ndArray.mul(b.ndArray).wrap()
@ -125,21 +113,12 @@ public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jAr
// }
public companion object {
private val intNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, IntNd4jArrayRing>> =
ThreadLocal.withInitial(::HashMap)
/**
* Creates an [RingND] for [Int] values or pull it from cache if it was created previously.
*/
public fun int(vararg shape: Int): Nd4jArrayRing<Int, IntRing> =
intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) }
/**
* Creates a most suitable implementation of [RingND] using reified class.
*/
@Suppress("UNCHECKED_CAST")
public inline fun <reified T : Number> auto(vararg shape: Int): Nd4jArrayRing<T, Ring<T>> = when {
T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, Ring<T>>
public inline fun <reified T : Number> auto(vararg shape: Int): Nd4jArrayRingOps<T, Ring<T>> = when {
T::class == Int::class -> IntRing.nd4j as Nd4jArrayRingOps<T, Ring<T>>
else -> throw UnsupportedOperationException("This factory method only supports Long type.")
}
}
@ -151,38 +130,21 @@ public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jAr
* @param T the type of the element contained in ND structure.
* @param F the type field of structure elements.
*/
public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<T, F> {
public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldOpsND<T, F>, Nd4jArrayRingOps<T, F> {
override fun divide(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
a.ndArray.div(b.ndArray).wrap()
public operator fun Number.div(b: StructureND<T>): Nd4jArrayStructure<T> = b.ndArray.rdiv(this).wrap()
public companion object {
private val floatNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, FloatNd4jArrayField>> =
ThreadLocal.withInitial(::HashMap)
private val doubleNd4JArrayFieldCache: ThreadLocal<MutableMap<IntArray, DoubleNd4jArrayField>> =
ThreadLocal.withInitial(::HashMap)
/**
* Creates an [FieldND] for [Float] values or pull it from cache if it was created previously.
*/
public fun float(vararg shape: Int): Nd4jArrayRing<Float, FloatField> =
floatNd4jArrayFieldCache.get().getOrPut(shape) { FloatNd4jArrayField(shape) }
/**
* Creates an [FieldND] for [Double] values or pull it from cache if it was created previously.
*/
public fun real(vararg shape: Int): Nd4jArrayRing<Double, DoubleField> =
doubleNd4JArrayFieldCache.get().getOrPut(shape) { DoubleNd4jArrayField(shape) }
/**
* Creates a most suitable implementation of [FieldND] using reified class.
*/
@Suppress("UNCHECKED_CAST")
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, Field<T>> = when {
T::class == Float::class -> float(*shape) as Nd4jArrayField<T, Field<T>>
T::class == Double::class -> real(*shape) as Nd4jArrayField<T, Field<T>>
T::class == Float::class -> FloatField.nd4j as Nd4jArrayField<T, Field<T>>
T::class == Double::class -> DoubleField.nd4j as Nd4jArrayField<T, Field<T>>
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
}
}
@ -191,8 +153,9 @@ public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldND<T, F>, Nd4
/**
* Represents intersection of [ExtendedField] and [Field] over [Nd4jArrayStructure].
*/
public sealed interface Nd4jArrayExtendedField<T, out F : ExtendedField<T>> : ExtendedField<StructureND<T>>,
Nd4jArrayField<T, F> {
public sealed interface Nd4jArrayExtendedFieldOps<T, out F : ExtendedField<T>> :
ExtendedFieldOps<StructureND<T>>, Nd4jArrayField<T, F> {
override fun sin(arg: StructureND<T>): StructureND<T> = Transforms.sin(arg.ndArray).wrap()
override fun cos(arg: StructureND<T>): StructureND<T> = Transforms.cos(arg.ndArray).wrap()
override fun asin(arg: StructureND<T>): StructureND<T> = Transforms.asin(arg.ndArray).wrap()
@ -221,63 +184,59 @@ public sealed interface Nd4jArrayExtendedField<T, out F : ExtendedField<T>> : Ex
/**
* Represents [FieldND] over [Nd4jArrayDoubleStructure].
*/
public class DoubleNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtendedField<Double, DoubleField> {
override val elementContext: DoubleField get() = DoubleField
public open class DoubleNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps<Double, DoubleField> {
override val elementAlgebra: DoubleField get() = DoubleField
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asDoubleStructure()
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
@OptIn(PerformancePitfall::class)
override val StructureND<Double>.ndArray: INDArray
get() = when (this) {
is Nd4jArrayStructure<Double> -> checkShape(ndArray)
is Nd4jArrayStructure<Double> -> ndArray
else -> Nd4j.zeros(*shape).also {
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
}
}
override fun scale(a: StructureND<Double>, value: Double): Nd4jArrayStructure<Double> {
return a.ndArray.mul(value).wrap()
}
override fun scale(a: StructureND<Double>, value: Double): Nd4jArrayStructure<Double> = a.ndArray.mul(value).wrap()
override operator fun StructureND<Double>.div(arg: Double): Nd4jArrayStructure<Double> {
return ndArray.div(arg).wrap()
}
override operator fun StructureND<Double>.div(arg: Double): Nd4jArrayStructure<Double> = ndArray.div(arg).wrap()
override operator fun StructureND<Double>.plus(arg: Double): Nd4jArrayStructure<Double> {
return ndArray.add(arg).wrap()
}
override operator fun StructureND<Double>.plus(arg: Double): Nd4jArrayStructure<Double> = ndArray.add(arg).wrap()
override operator fun StructureND<Double>.minus(arg: Double): Nd4jArrayStructure<Double> {
return ndArray.sub(arg).wrap()
}
override operator fun StructureND<Double>.minus(arg: Double): Nd4jArrayStructure<Double> = ndArray.sub(arg).wrap()
override operator fun StructureND<Double>.times(arg: Double): Nd4jArrayStructure<Double> {
return ndArray.mul(arg).wrap()
}
override operator fun StructureND<Double>.times(arg: Double): Nd4jArrayStructure<Double> = ndArray.mul(arg).wrap()
override operator fun Double.div(arg: StructureND<Double>): Nd4jArrayStructure<Double> {
return arg.ndArray.rdiv(this).wrap()
}
override operator fun Double.div(arg: StructureND<Double>): Nd4jArrayStructure<Double> =
arg.ndArray.rdiv(this).wrap()
override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> {
return arg.ndArray.rsub(this).wrap()
}
override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> =
arg.ndArray.rsub(this).wrap()
public companion object : DoubleNd4jArrayFieldOps()
}
public fun DoubleField.nd4j(vararg shape: Int): DoubleNd4jArrayField = DoubleNd4jArrayField(intArrayOf(*shape))
public val DoubleField.nd4j: DoubleNd4jArrayFieldOps get() = DoubleNd4jArrayFieldOps
public class DoubleNd4jArrayField(override val shape: Shape) : DoubleNd4jArrayFieldOps(), FieldND<Double, DoubleField>
public fun DoubleField.nd4j(shapeFirst: Int, vararg shapeRest: Int): DoubleNd4jArrayField =
DoubleNd4jArrayField(intArrayOf(shapeFirst, * shapeRest))
/**
* Represents [FieldND] over [Nd4jArrayStructure] of [Float].
*/
public class FloatNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtendedField<Float, FloatField> {
override val elementContext: FloatField get() = FloatField
public open class FloatNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps<Float, FloatField> {
override val elementAlgebra: FloatField get() = FloatField
override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()
override fun INDArray.wrap(): Nd4jArrayStructure<Float> = asFloatStructure()
@OptIn(PerformancePitfall::class)
override val StructureND<Float>.ndArray: INDArray
get() = when (this) {
is Nd4jArrayStructure<Float> -> checkShape(ndArray)
is Nd4jArrayStructure<Float> -> ndArray
else -> Nd4j.zeros(*shape).also {
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
}
@ -303,21 +262,29 @@ public class FloatNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtend
override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> =
arg.ndArray.rsub(this).wrap()
public companion object : FloatNd4jArrayFieldOps()
}
public class FloatNd4jArrayField(override val shape: Shape) : FloatNd4jArrayFieldOps(), RingND<Float, FloatField>
public val FloatField.nd4j: FloatNd4jArrayFieldOps get() = FloatNd4jArrayFieldOps
public fun FloatField.nd4j(shapeFirst: Int, vararg shapeRest: Int): FloatNd4jArrayField =
FloatNd4jArrayField(intArrayOf(shapeFirst, * shapeRest))
/**
* Represents [RingND] over [Nd4jArrayIntStructure].
*/
public class IntNd4jArrayRing(override val shape: IntArray) : Nd4jArrayRing<Int, IntRing> {
override val elementContext: IntRing
get() = IntRing
public open class IntNd4jArrayRingOps : Nd4jArrayRingOps<Int, IntRing> {
override val elementAlgebra: IntRing get() = IntRing
override fun INDArray.wrap(): Nd4jArrayStructure<Int> = checkShape(this).asIntStructure()
override fun INDArray.wrap(): Nd4jArrayStructure<Int> = asIntStructure()
@OptIn(PerformancePitfall::class)
override val StructureND<Int>.ndArray: INDArray
get() = when (this) {
is Nd4jArrayStructure<Int> -> checkShape(ndArray)
is Nd4jArrayStructure<Int> -> ndArray
else -> Nd4j.zeros(*shape).also {
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
}
@ -334,4 +301,13 @@ public class IntNd4jArrayRing(override val shape: IntArray) : Nd4jArrayRing<Int,
override operator fun Int.minus(arg: StructureND<Int>): Nd4jArrayStructure<Int> =
arg.ndArray.rsub(this).wrap()
public companion object : IntNd4jArrayRingOps()
}
public val IntRing.nd4j: IntNd4jArrayRingOps get() = IntNd4jArrayRingOps
public class IntNd4jArrayRing(override val shape: Shape) : IntNd4jArrayRingOps(), RingND<Int, IntRing>
public fun IntRing.nd4j(shapeFirst: Int, vararg shapeRest: Int): IntNd4jArrayRing =
IntNd4jArrayRing(intArrayOf(shapeFirst, * shapeRest))

View File

@ -8,6 +8,10 @@ package space.kscience.kmath.nd4j
import org.nd4j.linalg.factory.Nd4j
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.one
import space.kscience.kmath.nd.produce
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.invoke
import kotlin.math.PI
import kotlin.test.Test
@ -19,7 +23,7 @@ import kotlin.test.fail
internal class Nd4jArrayAlgebraTest {
@Test
fun testProduce() {
val res = with(DoubleNd4jArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } }
val res = DoubleField.nd4j.produce(2, 2) { it.sum().toDouble() }
val expected = (Nd4j.create(2, 2) ?: fail()).asDoubleStructure()
expected[intArrayOf(0, 0)] = 0.0
expected[intArrayOf(0, 1)] = 1.0
@ -30,7 +34,9 @@ internal class Nd4jArrayAlgebraTest {
@Test
fun testMap() {
val res = with(IntNd4jArrayRing(intArrayOf(2, 2))) { one.map { it + it * 2 } }
val res = IntRing.nd4j {
one(2, 2).map { it + it * 2 }
}
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
expected[intArrayOf(0, 0)] = 3
expected[intArrayOf(0, 1)] = 3
@ -41,7 +47,7 @@ internal class Nd4jArrayAlgebraTest {
@Test
fun testAdd() {
val res = with(IntNd4jArrayRing(intArrayOf(2, 2))) { one + 25 }
val res = IntRing.nd4j { one(2, 2) + 25 }
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
expected[intArrayOf(0, 0)] = 26
expected[intArrayOf(0, 1)] = 26
@ -51,10 +57,10 @@ internal class Nd4jArrayAlgebraTest {
}
@Test
fun testSin() = DoubleNd4jArrayField(intArrayOf(2, 2)).invoke {
val initial = produce { (i, j) -> if (i == j) PI / 2 else 0.0 }
fun testSin() = DoubleField.nd4j{
val initial = produce(2, 2) { (i, j) -> if (i == j) PI / 2 else 0.0 }
val transformed = sin(initial)
val expected = produce { (i, j) -> if (i == j) 1.0 else 0.0 }
val expected = produce(2, 2) { (i, j) -> if (i == j) 1.0 else 0.0 }
println(transformed)
assertTrue { StructureND.contentEquals(transformed, expected) }

View File

@ -64,8 +64,8 @@ public fun MST.toIExpr(): IExpr = when (this) {
}
is MST.Unary -> when (operation) {
GroupOperations.PLUS_OPERATION -> value.toIExpr()
GroupOperations.MINUS_OPERATION -> F.Negate(value.toIExpr())
GroupOps.PLUS_OPERATION -> value.toIExpr()
GroupOps.MINUS_OPERATION -> F.Negate(value.toIExpr())
TrigonometricOperations.SIN_OPERATION -> F.Sin(value.toIExpr())
TrigonometricOperations.COS_OPERATION -> F.Cos(value.toIExpr())
TrigonometricOperations.TAN_OPERATION -> F.Tan(value.toIExpr())
@ -85,10 +85,10 @@ public fun MST.toIExpr(): IExpr = when (this) {
}
is MST.Binary -> when (operation) {
GroupOperations.PLUS_OPERATION -> left.toIExpr() + right.toIExpr()
GroupOperations.MINUS_OPERATION -> left.toIExpr() - right.toIExpr()
RingOperations.TIMES_OPERATION -> left.toIExpr() * right.toIExpr()
FieldOperations.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr())
GroupOps.PLUS_OPERATION -> left.toIExpr() + right.toIExpr()
GroupOps.MINUS_OPERATION -> left.toIExpr() - right.toIExpr()
RingOps.TIMES_OPERATION -> left.toIExpr() * right.toIExpr()
FieldOps.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr())
PowerOperations.POW_OPERATION -> F.Power(left.toIExpr(), F.symjify((right as MST.Numeric).value))
else -> error("Binary operation $operation not defined in $this")
}

View File

@ -373,8 +373,12 @@ public open class DoubleTensorAlgebra :
return resTensor
}
override fun diagonalEmbedding(diagonalEntries: Tensor<Double>, offset: Int, dim1: Int, dim2: Int):
DoubleTensor {
override fun diagonalEmbedding(
diagonalEntries: Tensor<Double>,
offset: Int,
dim1: Int,
dim2: Int
): DoubleTensor {
val n = diagonalEntries.shape.size
val d1 = minusIndexFrom(n + 1, dim1)
val d2 = minusIndexFrom(n + 1, dim2)

View File

@ -44,7 +44,7 @@ internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArra
*
* @param shape the shape of the tensor.
*/
internal class TensorLinearStructure(override val shape: IntArray) : Strides {
internal class TensorLinearStructure(override val shape: IntArray) : Strides() {
override val strides: IntArray
get() = stridesFromShape(shape)
@ -54,4 +54,18 @@ internal class TensorLinearStructure(override val shape: IntArray) : Strides {
override val linearSize: Int
get() = shape.reduce(Int::times)
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false
other as TensorLinearStructure
if (!shape.contentEquals(other.shape)) return false
return true
}
override fun hashCode(): Int {
return shape.contentHashCode()
}
}

View File

@ -26,8 +26,11 @@ internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> =
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this
is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides)
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
is MutableBufferND<T> -> if (this.indexes == TensorLinearStructure(this.shape)) {
BufferedTensor(this.shape, this.mutableBuffer, 0)
} else {
this.copyToBufferedTensor()
}
else -> this.copyToBufferedTensor()
}

View File

@ -11,7 +11,7 @@ import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations
import space.kscience.kmath.operations.NumbersAddOps
import space.kscience.kmath.operations.ScaleOperations
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
@ -34,7 +34,7 @@ public fun F64Array.asStructure(): ViktorStructureND = ViktorStructureND(this)
@OptIn(UnstableKMathAPI::class)
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>,
NumbersAddOperations<StructureND<Double>>, ExtendedField<StructureND<Double>>,
NumbersAddOps<StructureND<Double>>, ExtendedField<StructureND<Double>>,
ScaleOperations<StructureND<Double>> {
public val StructureND<Double>.f64Buffer: F64Array
@ -44,7 +44,7 @@ public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, Doubl
shape
)
this is ViktorStructureND && this.f64Buffer.shape.contentEquals(this@ViktorFieldND.shape) -> this.f64Buffer
else -> produce { this@f64Buffer[it] }.f64Buffer
else -> produce(shape) { this@f64Buffer[it] }.f64Buffer
}
override val zero: ViktorStructureND by lazy { F64Array.full(init = 0.0, shape = shape).asStructure() }
@ -52,9 +52,9 @@ public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, Doubl
private val strides: Strides = DefaultStrides(shape)
override val elementContext: DoubleField get() = DoubleField
override val elementAlgebra: DoubleField get() = DoubleField
override fun produce(initializer: DoubleField.(IntArray) -> Double): ViktorStructureND =
override fun produce(shape: IntArray, initializer: DoubleField.(IntArray) -> Double): ViktorStructureND =
F64Array(*shape).apply {
this@ViktorFieldND.strides.indices().forEach { index ->
set(value = DoubleField.initializer(index), indices = index)
@ -78,13 +78,13 @@ public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, Doubl
}
}.asStructure()
override fun combine(
a: StructureND<Double>,
b: StructureND<Double>,
override fun zip(
left: StructureND<Double>,
right: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double,
): ViktorStructureND = F64Array(*shape).apply {
this@ViktorFieldND.strides.indices().forEach { index ->
set(value = DoubleField.transform(a[index], b[index]), indices = index)
set(value = DoubleField.transform(left[index], right[index]), indices = index)
}
}.asStructure()
@ -123,4 +123,4 @@ public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, Doubl
override fun ln(arg: StructureND<Double>): ViktorStructureND = arg.f64Buffer.log().asStructure()
}
public fun ViktorNDField(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape)
public fun ViktorFieldND(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape)

View File

@ -1,16 +1,18 @@
pluginManagement {
repositories {
mavenLocal()
maven("https://repo.kotlin.link")
mavenCentral()
gradlePluginPortal()
}
val kotlinVersion = "1.6.0-M1"
val kotlinVersion = "1.6.0-RC"
val toolsVersion = "0.10.5"
plugins {
id("org.jetbrains.kotlinx.benchmark") version "0.3.1"
id("ru.mipt.npm.gradle.project") version "0.10.5"
id("ru.mipt.npm.gradle.project") version toolsVersion
id("ru.mipt.npm.gradle.jvm") version toolsVersion
id("ru.mipt.npm.gradle.mpp") version toolsVersion
kotlin("multiplatform") version kotlinVersion
kotlin("plugin.allopen") version kotlinVersion
}