Shapeless ND and Buffer algebras

This commit is contained in:
Alexander Nozik 2021-10-17 12:42:35 +03:00
parent d0354da80a
commit 688382eed6
32 changed files with 382 additions and 281 deletions

View File

@ -57,12 +57,12 @@ internal class NDFieldBenchmark {
blackhole.consume(res) blackhole.consume(res)
} }
@Benchmark // @Benchmark
fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) { // fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) {
var res: StructureND<Double> = one(dim, dim) // var res: StructureND<Double> = one(dim, dim)
repeat(n) { res += 1.0 } // repeat(n) { res += 1.0 }
blackhole.consume(res) // blackhole.consume(res)
} // }
private companion object { private companion object {
private const val dim = 1000 private const val dim = 1000

View File

@ -9,6 +9,7 @@ import space.kscience.kmath.integration.gaussIntegrator
import space.kscience.kmath.integration.integrate import space.kscience.kmath.integration.integrate
import space.kscience.kmath.integration.value import space.kscience.kmath.integration.value
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.produce
import space.kscience.kmath.nd.withNdAlgebra import space.kscience.kmath.nd.withNdAlgebra
import space.kscience.kmath.operations.algebra import space.kscience.kmath.operations.algebra
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke

View File

@ -12,6 +12,7 @@ import space.kscience.kmath.linear.transpose
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.produce
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis

View File

@ -8,10 +8,8 @@ package space.kscience.kmath.structures
import kotlinx.coroutines.DelicateCoroutinesApi import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.GlobalScope
import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.Nd4j
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.*
import space.kscience.kmath.nd.autoNdAlgebra import space.kscience.kmath.nd4j.nd4j
import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd4j.Nd4jArrayField
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.viktor.ViktorFieldND import space.kscience.kmath.viktor.ViktorFieldND
@ -31,15 +29,17 @@ fun main() {
Nd4j.zeros(0) Nd4j.zeros(0)
val dim = 1000 val dim = 1000
val n = 1000 val n = 1000
val shape = Shape(dim, dim)
// automatically build context most suited for given type. // automatically build context most suited for given type.
val autoField = DoubleField.autoNdAlgebra(dim, dim) val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
// specialized nd-field for Double. It works as generic Double field as well. // specialized nd-field for Double. It works as generic Double field as well.
val realField = DoubleField.ndAlgebra(dim, dim) val realField = DoubleField.ndAlgebra
//A generic boxing field. It should be used for objects, not primitives. //A generic boxing field. It should be used for objects, not primitives.
val boxingField = DoubleField.ndAlgebra(Buffer.Companion::boxing, dim, dim) val boxingField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing)
// Nd4j specialized field. // Nd4j specialized field.
val nd4jField = Nd4jArrayField.real(dim, dim) val nd4jField = DoubleField.nd4j
//viktor field //viktor field
val viktorField = ViktorFieldND(dim, dim) val viktorField = ViktorFieldND(dim, dim)
//parallel processing based on Java Streams //parallel processing based on Java Streams
@ -47,21 +47,21 @@ fun main() {
measureAndPrint("Boxing addition") { measureAndPrint("Boxing addition") {
boxingField { boxingField {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
} }
} }
measureAndPrint("Specialized addition") { measureAndPrint("Specialized addition") {
realField { realField {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
} }
} }
measureAndPrint("Nd4j specialized addition") { measureAndPrint("Nd4j specialized addition") {
nd4jField { nd4jField {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
} }
} }
@ -82,13 +82,13 @@ fun main() {
measureAndPrint("Automatic field addition") { measureAndPrint("Automatic field addition") {
autoField { autoField {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
} }
} }
measureAndPrint("Lazy addition") { measureAndPrint("Lazy addition") {
val res = realField.one.mapAsync(GlobalScope) { val res = realField.one(shape).mapAsync(GlobalScope) {
var c = 0.0 var c = 0.0
repeat(n) { repeat(n) {
c += 1.0 c += 1.0

View File

@ -22,12 +22,12 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
private val strides = DefaultStrides(shape) private val strides = DefaultStrides(shape)
override val elementAlgebra: DoubleField get() = DoubleField override val elementAlgebra: DoubleField get() = DoubleField
override val zero: BufferND<Double> by lazy { produce { zero } } override val zero: BufferND<Double> by lazy { produce(shape) { zero } }
override val one: BufferND<Double> by lazy { produce { one } } override val one: BufferND<Double> by lazy { produce(shape) { one } }
override fun number(value: Number): BufferND<Double> { override fun number(value: Number): BufferND<Double> {
val d = value.toDouble() // minimize conversions val d = value.toDouble() // minimize conversions
return produce { d } return produce(shape) { d }
} }
private val StructureND<Double>.buffer: DoubleBuffer private val StructureND<Double>.buffer: DoubleBuffer
@ -40,7 +40,7 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) } else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
} }
override fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND<Double> { override fun produce(shape: Shape, initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset -> val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
val index = strides.index(offset) val index = strides.index(offset)
DoubleField.initializer(index) DoubleField.initializer(index)
@ -70,12 +70,12 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
} }
override fun zip( override fun zip(
a: StructureND<Double>, left: StructureND<Double>,
b: StructureND<Double>, right: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double, transform: DoubleField.(Double, Double) -> Double,
): BufferND<Double> { ): BufferND<Double> {
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset -> val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
DoubleField.transform(a.buffer.array[offset], b.buffer.array[offset]) DoubleField.transform(left.buffer.array[offset], right.buffer.array[offset])
}.toArray() }.toArray()
return BufferND(strides, array.asBuffer()) return BufferND(strides, array.asBuffer())
} }

View File

@ -70,12 +70,12 @@ public class DerivativeStructureField(
override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate() override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate()
override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b) override fun add(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.add(right)
override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value) override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value)
override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b) override fun multiply(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.multiply(right)
override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b) override fun divide(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.divide(right)
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan() override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()

View File

@ -77,33 +77,33 @@ public object ComplexField :
override fun scale(a: Complex, value: Double): Complex = Complex(a.re * value, a.im * value) override fun scale(a: Complex, value: Double): Complex = Complex(a.re * value, a.im * value)
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im) override fun add(left: Complex, right: Complex): Complex = Complex(left.re + right.re, left.im + right.im)
// override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble()) // override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble())
override fun multiply(a: Complex, b: Complex): Complex = override fun multiply(left: Complex, right: Complex): Complex =
Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re) Complex(left.re * right.re - left.im * right.im, left.re * right.im + left.im * right.re)
override fun divide(a: Complex, b: Complex): Complex = when { override fun divide(left: Complex, right: Complex): Complex = when {
abs(b.im) < abs(b.re) -> { abs(right.im) < abs(right.re) -> {
val wr = b.im / b.re val wr = right.im / right.re
val wd = b.re + wr * b.im val wd = right.re + wr * right.im
if (wd.isNaN() || wd == 0.0) if (wd.isNaN() || wd == 0.0)
throw ArithmeticException("Division by zero or infinity") throw ArithmeticException("Division by zero or infinity")
else else
Complex((a.re + a.im * wr) / wd, (a.im - a.re * wr) / wd) Complex((left.re + left.im * wr) / wd, (left.im - left.re * wr) / wd)
} }
b.im == 0.0 -> throw ArithmeticException("Division by zero") right.im == 0.0 -> throw ArithmeticException("Division by zero")
else -> { else -> {
val wr = b.re / b.im val wr = right.re / right.im
val wd = b.im + wr * b.re val wd = right.im + wr * right.re
if (wd.isNaN() || wd == 0.0) if (wd.isNaN() || wd == 0.0)
throw ArithmeticException("Division by zero or infinity") throw ArithmeticException("Division by zero or infinity")
else else
Complex((a.re * wr + a.im) / wd, (a.im * wr - a.re) / wd) Complex((left.re * wr + left.im) / wd, (left.im * wr - left.re) / wd)
} }
} }

View File

@ -63,27 +63,27 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
*/ */
public val k: Quaternion = Quaternion(0, 0, 0, 1) public val k: Quaternion = Quaternion(0, 0, 0, 1)
override fun add(a: Quaternion, b: Quaternion): Quaternion = override fun add(left: Quaternion, right: Quaternion): Quaternion =
Quaternion(a.w + b.w, a.x + b.x, a.y + b.y, a.z + b.z) Quaternion(left.w + right.w, left.x + right.x, left.y + right.y, left.z + right.z)
override fun scale(a: Quaternion, value: Double): Quaternion = override fun scale(a: Quaternion, value: Double): Quaternion =
Quaternion(a.w * value, a.x * value, a.y * value, a.z * value) Quaternion(a.w * value, a.x * value, a.y * value, a.z * value)
override fun multiply(a: Quaternion, b: Quaternion): Quaternion = Quaternion( override fun multiply(left: Quaternion, right: Quaternion): Quaternion = Quaternion(
a.w * b.w - a.x * b.x - a.y * b.y - a.z * b.z, left.w * right.w - left.x * right.x - left.y * right.y - left.z * right.z,
a.w * b.x + a.x * b.w + a.y * b.z - a.z * b.y, left.w * right.x + left.x * right.w + left.y * right.z - left.z * right.y,
a.w * b.y - a.x * b.z + a.y * b.w + a.z * b.x, left.w * right.y - left.x * right.z + left.y * right.w + left.z * right.x,
a.w * b.z + a.x * b.y - a.y * b.x + a.z * b.w, left.w * right.z + left.x * right.y - left.y * right.x + left.z * right.w,
) )
override fun divide(a: Quaternion, b: Quaternion): Quaternion { override fun divide(left: Quaternion, right: Quaternion): Quaternion {
val s = b.w * b.w + b.x * b.x + b.y * b.y + b.z * b.z val s = right.w * right.w + right.x * right.x + right.y * right.y + right.z * right.z
return Quaternion( return Quaternion(
(b.w * a.w + b.x * a.x + b.y * a.y + b.z * a.z) / s, (right.w * left.w + right.x * left.x + right.y * left.y + right.z * left.z) / s,
(b.w * a.x - b.x * a.w - b.y * a.z + b.z * a.y) / s, (right.w * left.x - right.x * left.w - right.y * left.z + right.z * left.y) / s,
(b.w * a.y + b.x * a.z - b.y * a.w - b.z * a.x) / s, (right.w * left.y + right.x * left.z - right.y * left.w - right.z * left.x) / s,
(b.w * a.z - b.x * a.y + b.y * a.x - b.z * a.w) / s, (right.w * left.z - right.x * left.y + right.y * left.x - right.z * left.w) / s,
) )
} }

View File

@ -57,8 +57,8 @@ public open class FunctionalExpressionGroup<T, out A : Group<T>>(
/** /**
* Builds an Expression of addition of two another expressions. * Builds an Expression of addition of two another expressions.
*/ */
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = override fun add(left: Expression<T>, right: Expression<T>): Expression<T> =
binaryOperation(GroupOps.PLUS_OPERATION, a, b) binaryOperation(GroupOps.PLUS_OPERATION, left, right)
// /** // /**
// * Builds an Expression of multiplication of expression by number. // * Builds an Expression of multiplication of expression by number.
@ -88,8 +88,8 @@ public open class FunctionalExpressionRing<T, out A : Ring<T>>(
/** /**
* Builds an Expression of multiplication of two expressions. * Builds an Expression of multiplication of two expressions.
*/ */
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = override fun multiply(left: Expression<T>, right: Expression<T>): Expression<T> =
binaryOperationFunction(RingOps.TIMES_OPERATION)(a, b) binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right)
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg) public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
@ -107,8 +107,8 @@ public open class FunctionalExpressionField<T, out A : Field<T>>(
/** /**
* Builds an Expression of division an expression by another one. * Builds an Expression of division an expression by another one.
*/ */
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = override fun divide(left: Expression<T>, right: Expression<T>): Expression<T> =
binaryOperationFunction(FieldOps.DIV_OPERATION)(a, b) binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right)
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg) public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this

View File

@ -31,15 +31,15 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value) override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value)
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(GroupOps.PLUS_OPERATION)(a, b) override fun add(left: MST, right: MST): MST.Binary = binaryOperationFunction(GroupOps.PLUS_OPERATION)(left, right)
override operator fun MST.unaryPlus(): MST.Unary = override operator fun MST.unaryPlus(): MST.Unary =
unaryOperationFunction(GroupOps.PLUS_OPERATION)(this) unaryOperationFunction(GroupOps.PLUS_OPERATION)(this)
override operator fun MST.unaryMinus(): MST.Unary = override operator fun MST.unaryMinus(): MST.Unary =
unaryOperationFunction(GroupOps.MINUS_OPERATION)(this) unaryOperationFunction(GroupOps.MINUS_OPERATION)(this)
override operator fun MST.minus(b: MST): MST.Binary = override operator fun MST.minus(other: MST): MST.Binary =
binaryOperationFunction(GroupOps.MINUS_OPERATION)(this, b) binaryOperationFunction(GroupOps.MINUS_OPERATION)(this, other)
override fun scale(a: MST, value: Double): MST.Binary = override fun scale(a: MST, value: Double): MST.Binary =
binaryOperationFunction(RingOps.TIMES_OPERATION)(a, number(value)) binaryOperationFunction(RingOps.TIMES_OPERATION)(a, number(value))
@ -62,17 +62,17 @@ public object MstRing : Ring<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
override fun number(value: Number): MST.Numeric = MstGroup.number(value) override fun number(value: Number): MST.Numeric = MstGroup.number(value)
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
override fun add(a: MST, b: MST): MST.Binary = MstGroup.add(a, b) override fun add(left: MST, right: MST): MST.Binary = MstGroup.add(left, right)
override fun scale(a: MST, value: Double): MST.Binary = override fun scale(a: MST, value: Double): MST.Binary =
MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value)) MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
override fun multiply(a: MST, b: MST): MST.Binary = override fun multiply(left: MST, right: MST): MST.Binary =
binaryOperationFunction(RingOps.TIMES_OPERATION)(a, b) binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right)
override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus } override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus }
override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus } override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus }
override operator fun MST.minus(b: MST): MST.Binary = MstGroup { this@minus - b } override operator fun MST.minus(other: MST): MST.Binary = MstGroup { this@minus - other }
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
MstGroup.binaryOperationFunction(operation) MstGroup.binaryOperationFunction(operation)
@ -92,18 +92,18 @@ public object MstField : Field<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
override fun number(value: Number): MST.Numeric = MstRing.number(value) override fun number(value: Number): MST.Numeric = MstRing.number(value)
override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) override fun add(left: MST, right: MST): MST.Binary = MstRing.add(left, right)
override fun scale(a: MST, value: Double): MST.Binary = override fun scale(a: MST, value: Double): MST.Binary =
MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value)) MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) override fun multiply(left: MST, right: MST): MST.Binary = MstRing.multiply(left, right)
override fun divide(a: MST, b: MST): MST.Binary = override fun divide(left: MST, right: MST): MST.Binary =
binaryOperationFunction(FieldOps.DIV_OPERATION)(a, b) binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right)
override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus } override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus }
override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus } override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus }
override operator fun MST.minus(b: MST): MST.Binary = MstRing { this@minus - b } override operator fun MST.minus(other: MST): MST.Binary = MstRing { this@minus - other }
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
MstRing.binaryOperationFunction(operation) MstRing.binaryOperationFunction(operation)
@ -134,17 +134,17 @@ public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
override fun asinh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ASINH_OPERATION)(arg) override fun asinh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ASINH_OPERATION)(arg)
override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ACOSH_OPERATION)(arg) override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ACOSH_OPERATION)(arg)
override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ATANH_OPERATION)(arg) override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ATANH_OPERATION)(arg)
override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) override fun add(left: MST, right: MST): MST.Binary = MstField.add(left, right)
override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg) override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg)
override fun scale(a: MST, value: Double): MST = override fun scale(a: MST, value: Double): MST =
binaryOperation(GroupOps.PLUS_OPERATION, a, number(value)) binaryOperation(GroupOps.PLUS_OPERATION, a, number(value))
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) override fun multiply(left: MST, right: MST): MST.Binary = MstField.multiply(left, right)
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) override fun divide(left: MST, right: MST): MST.Binary = MstField.divide(left, right)
override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus } override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus }
override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus } override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus }
override operator fun MST.minus(b: MST): MST.Binary = MstField { this@minus - b } override operator fun MST.minus(other: MST): MST.Binary = MstField { this@minus - other }
override fun power(arg: MST, pow: Number): MST.Binary = override fun power(arg: MST, pow: Number): MST.Binary =
binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow)) binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow))

View File

@ -168,22 +168,22 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
// Basic math (+, -, *, /) // Basic math (+, -, *, /)
override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> = override fun add(left: AutoDiffValue<T>, right: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value + b.value }) { z -> derive(const { left.value + right.value }) { z ->
a.d += z.d left.d += z.d
b.d += z.d right.d += z.d
} }
override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> = override fun multiply(left: AutoDiffValue<T>, right: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value * b.value }) { z -> derive(const { left.value * right.value }) { z ->
a.d += z.d * b.value left.d += z.d * right.value
b.d += z.d * a.value right.d += z.d * left.value
} }
override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> = override fun divide(left: AutoDiffValue<T>, right: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value / b.value }) { z -> derive(const { left.value / right.value }) { z ->
a.d += z.d / b.value left.d += z.d / right.value
b.d -= z.d * a.value / (b.value * b.value) right.d -= z.d * left.value / (right.value * right.value)
} }
override fun scale(a: AutoDiffValue<T>, value: Double): AutoDiffValue<T> = override fun scale(a: AutoDiffValue<T>, value: Double): AutoDiffValue<T> =

View File

@ -100,12 +100,12 @@ public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>,
/** /**
* Element-wise addition. * Element-wise addition.
* *
* @param a the augend. * @param left the augend.
* @param b the addend. * @param right the addend.
* @return the sum. * @return the sum.
*/ */
override fun add(a: StructureND<T>, b: StructureND<T>): StructureND<T> = override fun add(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
zip(a, b) { aValue, bValue -> add(aValue, bValue) } zip(left, right) { aValue, bValue -> add(aValue, bValue) }
// TODO move to extensions after KEEP-176 // TODO move to extensions after KEEP-176
@ -134,7 +134,7 @@ public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>,
* @param arg the addend. * @param arg the addend.
* @return the sum. * @return the sum.
*/ */
public operator fun T.plus(arg: StructureND<T>): StructureND<T> = arg.map { value -> add(this@plus, value) } public operator fun T.plus(arg: StructureND<T>): StructureND<T> = arg + this
/** /**
* Subtracts an ND structure from an element of it. * Subtracts an ND structure from an element of it.
@ -162,12 +162,12 @@ public interface RingOpsND<T, out A : RingOps<T>> : RingOps<StructureND<T>>, Gro
/** /**
* Element-wise multiplication. * Element-wise multiplication.
* *
* @param a the multiplicand. * @param left the multiplicand.
* @param b the multiplier. * @param right the multiplier.
* @return the product. * @return the product.
*/ */
override fun multiply(a: StructureND<T>, b: StructureND<T>): StructureND<T> = override fun multiply(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
zip(a, b) { aValue, bValue -> multiply(aValue, bValue) } zip(left, right) { aValue, bValue -> multiply(aValue, bValue) }
//TODO move to extensions after KEEP-176 //TODO move to extensions after KEEP-176
@ -208,12 +208,12 @@ public interface FieldOpsND<T, out A : Field<T>> : FieldOps<StructureND<T>>, Rin
/** /**
* Element-wise division. * Element-wise division.
* *
* @param a the dividend. * @param left the dividend.
* @param b the divisor. * @param right the divisor.
* @return the quotient. * @return the quotient.
*/ */
override fun divide(a: StructureND<T>, b: StructureND<T>): StructureND<T> = override fun divide(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
zip(a, b) { aValue, bValue -> divide(aValue, bValue) } zip(left, right) { aValue, bValue -> divide(aValue, bValue) }
//TODO move to extensions after https://github.com/Kotlin/KEEP/blob/master/proposals/context-receivers.md //TODO move to extensions after https://github.com/Kotlin/KEEP/blob/master/proposals/context-receivers.md
/** /**

View File

@ -73,7 +73,7 @@ internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.zipInline(
r: BufferND<T>, r: BufferND<T>,
crossinline block: A.(l: T, r: T) -> T crossinline block: A.(l: T, r: T) -> T
): BufferND<T> { ): BufferND<T> {
require(l.indexes == r.indexes) require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" }
val indexes = l.indexes val indexes = l.indexes
return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block)) return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block))
} }
@ -114,6 +114,10 @@ public fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.produce(
initializer: A.(IntArray) -> T initializer: A.(IntArray) -> T
): BufferND<T> = produce(shape, initializer) ): BufferND<T> = produce(shape, initializer)
public fun <T, EA : Algebra<T>, A> A.produce(
initializer: EA.(IntArray) -> T
): BufferND<T> where A : BufferAlgebraND<T, EA>, A : WithShape = produce(shape, initializer)
//// group factories //// group factories
//public fun <T, A : Group<T>> A.ndAlgebra( //public fun <T, A : Group<T>> A.ndAlgebra(
// bufferAlgebra: BufferAlgebra<T, A>, // bufferAlgebra: BufferAlgebra<T, A>,

View File

@ -20,7 +20,7 @@ import space.kscience.kmath.structures.MutableBufferFactory
*/ */
public open class BufferND<out T>( public open class BufferND<out T>(
public val indexes: ShapeIndex, public val indexes: ShapeIndex,
public val buffer: Buffer<T>, public open val buffer: Buffer<T>,
) : StructureND<T> { ) : StructureND<T> {
override operator fun get(index: IntArray): T = buffer[indexes.offset(index)] override operator fun get(index: IntArray): T = buffer[indexes.offset(index)]
@ -55,14 +55,14 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
* *
* @param T the type of items. * @param T the type of items.
* @param strides The strides to access elements of [MutableBuffer] by linear indices. * @param strides The strides to access elements of [MutableBuffer] by linear indices.
* @param mutableBuffer The underlying buffer. * @param buffer The underlying buffer.
*/ */
public class MutableBufferND<T>( public class MutableBufferND<T>(
strides: ShapeIndex, strides: ShapeIndex,
public val mutableBuffer: MutableBuffer<T>, override val buffer: MutableBuffer<T>,
) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) { ) : MutableStructureND<T>, BufferND<T>(strides, buffer) {
override fun set(index: IntArray, value: T) { override fun set(index: IntArray, value: T) {
mutableBuffer[indexes.offset(index)] = value buffer[indexes.offset(index)] = value
} }
} }
@ -74,7 +74,7 @@ public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
crossinline transform: (T) -> R, crossinline transform: (T) -> R,
): MutableBufferND<R> { ): MutableBufferND<R> {
return if (this is MutableBufferND<T>) return if (this is MutableBufferND<T>)
MutableBufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(mutableBuffer[it]) }) MutableBufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) })
else { else {
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })

View File

@ -10,42 +10,132 @@ import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
import kotlin.math.pow
public class DoubleBufferND(
indexes: ShapeIndex,
override val buffer: DoubleBuffer,
) : BufferND<Double>(indexes, buffer)
public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(DoubleField.bufferAlgebra), public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(DoubleField.bufferAlgebra),
ScaleOperations<StructureND<Double>>, ExtendedFieldOps<StructureND<Double>> { ScaleOperations<StructureND<Double>>, ExtendedFieldOps<StructureND<Double>> {
override fun StructureND<Double>.toBufferND(): BufferND<Double> = when (this) { override fun StructureND<Double>.toBufferND(): DoubleBufferND = when (this) {
is BufferND -> this is DoubleBufferND -> this
else -> { else -> {
val indexer = indexerBuilder(shape) val indexer = indexerBuilder(shape)
BufferND(indexer, DoubleBuffer(indexer.linearSize) { offset -> get(indexer.index(offset)) }) DoubleBufferND(indexer, DoubleBuffer(indexer.linearSize) { offset -> get(indexer.index(offset)) })
} }
} }
//TODO do specialization private inline fun mapInline(
arg: DoubleBufferND,
transform: (Double) -> Double
): DoubleBufferND {
val indexes = arg.indexes
val array = arg.buffer.array
return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { transform(array[it]) })
}
override fun scale(a: StructureND<Double>, value: Double): BufferND<Double> = private inline fun zipInline(
l: DoubleBufferND,
r: DoubleBufferND,
block: (l: Double, r: Double) -> Double
): DoubleBufferND {
require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" }
val indexes = l.indexes
val lArray = l.buffer.array
val rArray = r.buffer.array
return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { block(lArray[it], rArray[it]) })
}
override fun StructureND<Double>.map(transform: DoubleField.(Double) -> Double): BufferND<Double> =
mapInline(toBufferND()) { DoubleField.transform(it) }
override fun zip(
left: StructureND<Double>,
right: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double
): BufferND<Double> = zipInline(left.toBufferND(), right.toBufferND()) { l, r -> DoubleField.transform(l, r) }
override fun produce(shape: Shape, initializer: DoubleField.(IntArray) -> Double): DoubleBufferND {
val indexer = indexerBuilder(shape)
return DoubleBufferND(
indexer,
DoubleBuffer(indexer.linearSize) { offset ->
elementAlgebra.initializer(indexer.index(offset))
}
)
}
override fun add(left: StructureND<Double>, right: StructureND<Double>): DoubleBufferND =
zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l + r }
override fun multiply(left: StructureND<Double>, right: StructureND<Double>): DoubleBufferND =
zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l * r }
override fun StructureND<Double>.unaryMinus(): DoubleBufferND = mapInline(toBufferND()) { -it }
override fun StructureND<Double>.div(other: StructureND<Double>): DoubleBufferND =
zipInline(toBufferND(), other.toBufferND()) { l, r -> l / r }
override fun StructureND<Double>.plus(arg: Double): DoubleBufferND = mapInline(toBufferND()) { it + arg }
override fun StructureND<Double>.minus(arg: Double): StructureND<Double> = mapInline(toBufferND()) { it - arg }
override fun Double.plus(arg: StructureND<Double>): StructureND<Double> = arg + this
override fun Double.minus(arg: StructureND<Double>): StructureND<Double> = mapInline(arg.toBufferND()) { this - it }
override fun scale(a: StructureND<Double>, value: Double): DoubleBufferND =
mapInline(a.toBufferND()) { it * value } mapInline(a.toBufferND()) { it * value }
override fun power(arg: StructureND<Double>, pow: Number): BufferND<Double> = override fun power(arg: StructureND<Double>, pow: Number): DoubleBufferND =
mapInline(arg.toBufferND()) { power(it, pow) } mapInline(arg.toBufferND()) { it.pow(pow.toDouble()) }
override fun exp(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { exp(it) } override fun exp(arg: StructureND<Double>): DoubleBufferND =
override fun ln(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { ln(it) } mapInline(arg.toBufferND()) { kotlin.math.exp(it) }
override fun sin(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { sin(it) } override fun ln(arg: StructureND<Double>): DoubleBufferND =
override fun cos(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { cos(it) } mapInline(arg.toBufferND()) { kotlin.math.ln(it) }
override fun tan(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { tan(it) }
override fun asin(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { asin(it) }
override fun acos(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { acos(it) }
override fun atan(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { atan(it) }
override fun sinh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { sinh(it) } override fun sin(arg: StructureND<Double>): DoubleBufferND =
override fun cosh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { cosh(it) } mapInline(arg.toBufferND()) { kotlin.math.sin(it) }
override fun tanh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { tanh(it) }
override fun asinh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { asinh(it) } override fun cos(arg: StructureND<Double>): DoubleBufferND =
override fun acosh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { acosh(it) } mapInline(arg.toBufferND()) { kotlin.math.cos(it) }
override fun atanh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { atanh(it) }
override fun tan(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.tan(it) }
override fun asin(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.asin(it) }
override fun acos(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.acos(it) }
override fun atan(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.atan(it) }
override fun sinh(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.sinh(it) }
override fun cosh(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.cosh(it) }
override fun tanh(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.tanh(it) }
override fun asinh(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.asinh(it) }
override fun acosh(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.acosh(it) }
override fun atanh(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.atanh(it) }
public companion object : DoubleFieldOpsND() public companion object : DoubleFieldOpsND()
} }
@ -54,7 +144,7 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(D
public class DoubleFieldND(override val shape: Shape) : public class DoubleFieldND(override val shape: Shape) :
DoubleFieldOpsND(), FieldND<Double, DoubleField>, NumbersAddOps<StructureND<Double>> { DoubleFieldOpsND(), FieldND<Double, DoubleField>, NumbersAddOps<StructureND<Double>> {
override fun number(value: Number): BufferND<Double> { override fun number(value: Number): DoubleBufferND {
val d = value.toDouble() // minimize conversions val d = value.toDouble() // minimize conversions
return produce(shape) { d } return produce(shape) { d }
} }

View File

@ -121,11 +121,11 @@ public interface GroupOps<T> : Algebra<T> {
/** /**
* Addition of two elements. * Addition of two elements.
* *
* @param a the augend. * @param left the augend.
* @param b the addend. * @param right the addend.
* @return the sum. * @return the sum.
*/ */
public fun add(a: T, b: T): T public fun add(left: T, right: T): T
// Operations to be performed in this context. Could be moved to extensions in case of KEEP-176. // Operations to be performed in this context. Could be moved to extensions in case of KEEP-176.
@ -149,19 +149,19 @@ public interface GroupOps<T> : Algebra<T> {
* Addition of two elements. * Addition of two elements.
* *
* @receiver the augend. * @receiver the augend.
* @param b the addend. * @param other the addend.
* @return the sum. * @return the sum.
*/ */
public operator fun T.plus(b: T): T = add(this, b) public operator fun T.plus(other: T): T = add(this, other)
/** /**
* Subtraction of two elements. * Subtraction of two elements.
* *
* @receiver the minuend. * @receiver the minuend.
* @param b the subtrahend. * @param other the subtrahend.
* @return the difference. * @return the difference.
*/ */
public operator fun T.minus(b: T): T = add(this, -b) public operator fun T.minus(other: T): T = add(this, -other)
// Dynamic dispatch of operations // Dynamic dispatch of operations
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) { override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
PLUS_OPERATION -> { arg -> +arg } PLUS_OPERATION -> { arg -> +arg }
@ -210,18 +210,18 @@ public interface RingOps<T> : GroupOps<T> {
/** /**
* Multiplies two elements. * Multiplies two elements.
* *
* @param a the multiplier. * @param left the multiplier.
* @param b the multiplicand. * @param right the multiplicand.
*/ */
public fun multiply(a: T, b: T): T public fun multiply(left: T, right: T): T
/** /**
* Multiplies this element by scalar. * Multiplies this element by scalar.
* *
* @receiver the multiplier. * @receiver the multiplier.
* @param b the multiplicand. * @param other the multiplicand.
*/ */
public operator fun T.times(b: T): T = multiply(this, b) public operator fun T.times(other: T): T = multiply(this, other)
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
TIMES_OPERATION -> ::multiply TIMES_OPERATION -> ::multiply
@ -260,20 +260,20 @@ public interface FieldOps<T> : RingOps<T> {
/** /**
* Division of two elements. * Division of two elements.
* *
* @param a the dividend. * @param left the dividend.
* @param b the divisor. * @param right the divisor.
* @return the quotient. * @return the quotient.
*/ */
public fun divide(a: T, b: T): T public fun divide(left: T, right: T): T
/** /**
* Division of two elements. * Division of two elements.
* *
* @receiver the dividend. * @receiver the dividend.
* @param b the divisor. * @param other the divisor.
* @return the quotient. * @return the quotient.
*/ */
public operator fun T.div(b: T): T = divide(this, b) public operator fun T.div(other: T): T = divide(this, other)
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
DIV_OPERATION -> ::divide DIV_OPERATION -> ::divide

View File

@ -34,10 +34,10 @@ public object BigIntField : Field<BigInt>, NumbersAddOps<BigInt>, ScaleOperation
@Suppress("EXTENSION_SHADOWED_BY_MEMBER") @Suppress("EXTENSION_SHADOWED_BY_MEMBER")
override fun BigInt.unaryMinus(): BigInt = -this override fun BigInt.unaryMinus(): BigInt = -this
override fun add(a: BigInt, b: BigInt): BigInt = a.plus(b) override fun add(left: BigInt, right: BigInt): BigInt = left.plus(right)
override fun scale(a: BigInt, value: Double): BigInt = a.times(number(value)) override fun scale(a: BigInt, value: Double): BigInt = a.times(number(value))
override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b) override fun multiply(left: BigInt, right: BigInt): BigInt = left.times(right)
override fun divide(a: BigInt, b: BigInt): BigInt = a.div(b) override fun divide(left: BigInt, right: BigInt): BigInt = left.div(right)
public operator fun String.unaryPlus(): BigInt = this.parseBigInteger() ?: error("Can't parse $this as big integer") public operator fun String.unaryPlus(): BigInt = this.parseBigInteger() ?: error("Can't parse $this as big integer")
public operator fun String.unaryMinus(): BigInt = public operator fun String.unaryMinus(): BigInt =

View File

@ -134,8 +134,8 @@ public open class BufferRingOps<T, A: Ring<T>>(
override val bufferFactory: BufferFactory<T>, override val bufferFactory: BufferFactory<T>,
) : BufferAlgebra<T, A>, RingOps<Buffer<T>>{ ) : BufferAlgebra<T, A>, RingOps<Buffer<T>>{
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l + r } override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r }
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l * r } override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r }
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it } override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> = override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> =
@ -153,9 +153,9 @@ public open class BufferFieldOps<T, A : Field<T>>(
bufferFactory: BufferFactory<T>, bufferFactory: BufferFactory<T>,
) : BufferRingOps<T, A>(elementAlgebra, bufferFactory), BufferAlgebra<T, A>, FieldOps<Buffer<T>>, ScaleOperations<Buffer<T>> { ) : BufferRingOps<T, A>(elementAlgebra, bufferFactory), BufferAlgebra<T, A>, FieldOps<Buffer<T>>, ScaleOperations<Buffer<T>> {
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l + r } override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r }
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l * r } override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r }
override fun divide(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l / r } override fun divide(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l / r }
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = a.map { scale(it, value) } override fun scale(a: Buffer<T>, value: Double): Buffer<T> = a.map { scale(it, value) }
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it } override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }

View File

@ -15,36 +15,37 @@ import kotlin.math.*
* [ExtendedFieldOps] over [DoubleBuffer]. * [ExtendedFieldOps] over [DoubleBuffer].
*/ */
public abstract class DoubleBufferOps : ExtendedFieldOps<Buffer<Double>>, Norm<Buffer<Double>, Double> { public abstract class DoubleBufferOps : ExtendedFieldOps<Buffer<Double>>, Norm<Buffer<Double>, Double> {
override fun Buffer<Double>.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) { override fun Buffer<Double>.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) {
DoubleBuffer(size) { -array[it] } DoubleBuffer(size) { -array[it] }
} else { } else {
DoubleBuffer(size) { -get(it) } DoubleBuffer(size) { -get(it) }
} }
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer { override fun add(left: Buffer<Double>, right: Buffer<Double>): DoubleBuffer {
require(b.size == a.size) { require(right.size == left.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " "The size of the first buffer ${left.size} should be the same as for second one: ${right.size} "
} }
return if (a is DoubleBuffer && b is DoubleBuffer) { return if (left is DoubleBuffer && right is DoubleBuffer) {
val aArray = a.array val aArray = left.array
val bArray = b.array val bArray = right.array
DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) DoubleBuffer(DoubleArray(left.size) { aArray[it] + bArray[it] })
} else DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] }) } else DoubleBuffer(DoubleArray(left.size) { left[it] + right[it] })
} }
override fun Buffer<Double>.plus(b: Buffer<Double>): DoubleBuffer = add(this, b) override fun Buffer<Double>.plus(other: Buffer<Double>): DoubleBuffer = add(this, other)
override fun Buffer<Double>.minus(b: Buffer<Double>): DoubleBuffer { override fun Buffer<Double>.minus(other: Buffer<Double>): DoubleBuffer {
require(b.size == this.size) { require(other.size == this.size) {
"The size of the first buffer ${this.size} should be the same as for second one: ${b.size} " "The size of the first buffer ${this.size} should be the same as for second one: ${other.size} "
} }
return if (this is DoubleBuffer && b is DoubleBuffer) { return if (this is DoubleBuffer && other is DoubleBuffer) {
val aArray = this.array val aArray = this.array
val bArray = b.array val bArray = other.array
DoubleBuffer(DoubleArray(this.size) { aArray[it] - bArray[it] }) DoubleBuffer(DoubleArray(this.size) { aArray[it] - bArray[it] })
} else DoubleBuffer(DoubleArray(this.size) { this[it] - b[it] }) } else DoubleBuffer(DoubleArray(this.size) { this[it] - other[it] })
} }
// //
@ -66,29 +67,29 @@ public abstract class DoubleBufferOps : ExtendedFieldOps<Buffer<Double>>, Norm<B
// } else RealBuffer(DoubleArray(a.size) { a[it] / kValue }) // } else RealBuffer(DoubleArray(a.size) { a[it] / kValue })
// } // }
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer { override fun multiply(left: Buffer<Double>, right: Buffer<Double>): DoubleBuffer {
require(b.size == a.size) { require(right.size == left.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " "The size of the first buffer ${left.size} should be the same as for second one: ${right.size} "
} }
return if (a is DoubleBuffer && b is DoubleBuffer) { return if (left is DoubleBuffer && right is DoubleBuffer) {
val aArray = a.array val aArray = left.array
val bArray = b.array val bArray = right.array
DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) DoubleBuffer(DoubleArray(left.size) { aArray[it] * bArray[it] })
} else } else
DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] }) DoubleBuffer(DoubleArray(left.size) { left[it] * right[it] })
} }
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer { override fun divide(left: Buffer<Double>, right: Buffer<Double>): DoubleBuffer {
require(b.size == a.size) { require(right.size == left.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " "The size of the first buffer ${left.size} should be the same as for second one: ${right.size} "
} }
return if (a is DoubleBuffer && b is DoubleBuffer) { return if (left is DoubleBuffer && right is DoubleBuffer) {
val aArray = a.array val aArray = left.array
val bArray = b.array val bArray = right.array
DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) DoubleBuffer(DoubleArray(left.size) { aArray[it] / bArray[it] })
} else DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] }) } else DoubleBuffer(DoubleArray(left.size) { left[it] / right[it] })
} }
override fun sin(arg: Buffer<Double>): DoubleBuffer = if (arg is DoubleBuffer) { override fun sin(arg: Buffer<Double>): DoubleBuffer = if (arg is DoubleBuffer) {

View File

@ -73,10 +73,10 @@ public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOp
else -> super<ExtendedField>.binaryOperationFunction(operation) else -> super<ExtendedField>.binaryOperationFunction(operation)
} }
override inline fun add(a: Double, b: Double): Double = a + b override inline fun add(left: Double, right: Double): Double = left + right
override inline fun multiply(a: Double, b: Double): Double = a * b override inline fun multiply(left: Double, right: Double): Double = left * right
override inline fun divide(a: Double, b: Double): Double = a / b override inline fun divide(left: Double, right: Double): Double = left / right
override inline fun scale(a: Double, value: Double): Double = a * value override inline fun scale(a: Double, value: Double): Double = a * value
@ -102,10 +102,10 @@ public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOp
override inline fun norm(arg: Double): Double = abs(arg) override inline fun norm(arg: Double): Double = abs(arg)
override inline fun Double.unaryMinus(): Double = -this override inline fun Double.unaryMinus(): Double = -this
override inline fun Double.plus(b: Double): Double = this + b override inline fun Double.plus(other: Double): Double = this + other
override inline fun Double.minus(b: Double): Double = this - b override inline fun Double.minus(other: Double): Double = this - other
override inline fun Double.times(b: Double): Double = this * b override inline fun Double.times(other: Double): Double = this * other
override inline fun Double.div(b: Double): Double = this / b override inline fun Double.div(other: Double): Double = this / other
} }
public val Double.Companion.algebra: DoubleField get() = DoubleField public val Double.Companion.algebra: DoubleField get() = DoubleField
@ -126,12 +126,12 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
else -> super.binaryOperationFunction(operation) else -> super.binaryOperationFunction(operation)
} }
override inline fun add(a: Float, b: Float): Float = a + b override inline fun add(left: Float, right: Float): Float = left + right
override fun scale(a: Float, value: Double): Float = a * value.toFloat() override fun scale(a: Float, value: Double): Float = a * value.toFloat()
override inline fun multiply(a: Float, b: Float): Float = a * b override inline fun multiply(left: Float, right: Float): Float = left * right
override inline fun divide(a: Float, b: Float): Float = a / b override inline fun divide(left: Float, right: Float): Float = left / right
override inline fun sin(arg: Float): Float = kotlin.math.sin(arg) override inline fun sin(arg: Float): Float = kotlin.math.sin(arg)
override inline fun cos(arg: Float): Float = kotlin.math.cos(arg) override inline fun cos(arg: Float): Float = kotlin.math.cos(arg)
@ -155,10 +155,10 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override inline fun norm(arg: Float): Float = abs(arg) override inline fun norm(arg: Float): Float = abs(arg)
override inline fun Float.unaryMinus(): Float = -this override inline fun Float.unaryMinus(): Float = -this
override inline fun Float.plus(b: Float): Float = this + b override inline fun Float.plus(other: Float): Float = this + other
override inline fun Float.minus(b: Float): Float = this - b override inline fun Float.minus(other: Float): Float = this - other
override inline fun Float.times(b: Float): Float = this * b override inline fun Float.times(other: Float): Float = this * other
override inline fun Float.div(b: Float): Float = this / b override inline fun Float.div(other: Float): Float = this / other
} }
public val Float.Companion.algebra: FloatField get() = FloatField public val Float.Companion.algebra: FloatField get() = FloatField
@ -175,14 +175,14 @@ public object IntRing : Ring<Int>, Norm<Int, Int>, NumericAlgebra<Int> {
get() = 1 get() = 1
override fun number(value: Number): Int = value.toInt() override fun number(value: Number): Int = value.toInt()
override inline fun add(a: Int, b: Int): Int = a + b override inline fun add(left: Int, right: Int): Int = left + right
override inline fun multiply(a: Int, b: Int): Int = a * b override inline fun multiply(left: Int, right: Int): Int = left * right
override inline fun norm(arg: Int): Int = abs(arg) override inline fun norm(arg: Int): Int = abs(arg)
override inline fun Int.unaryMinus(): Int = -this override inline fun Int.unaryMinus(): Int = -this
override inline fun Int.plus(b: Int): Int = this + b override inline fun Int.plus(other: Int): Int = this + other
override inline fun Int.minus(b: Int): Int = this - b override inline fun Int.minus(other: Int): Int = this - other
override inline fun Int.times(b: Int): Int = this * b override inline fun Int.times(other: Int): Int = this * other
} }
public val Int.Companion.algebra: IntRing get() = IntRing public val Int.Companion.algebra: IntRing get() = IntRing
@ -199,14 +199,14 @@ public object ShortRing : Ring<Short>, Norm<Short, Short>, NumericAlgebra<Short>
get() = 1 get() = 1
override fun number(value: Number): Short = value.toShort() override fun number(value: Number): Short = value.toShort()
override inline fun add(a: Short, b: Short): Short = (a + b).toShort() override inline fun add(left: Short, right: Short): Short = (left + right).toShort()
override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort() override inline fun multiply(left: Short, right: Short): Short = (left * right).toShort()
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
override inline fun Short.unaryMinus(): Short = (-this).toShort() override inline fun Short.unaryMinus(): Short = (-this).toShort()
override inline fun Short.plus(b: Short): Short = (this + b).toShort() override inline fun Short.plus(other: Short): Short = (this + other).toShort()
override inline fun Short.minus(b: Short): Short = (this - b).toShort() override inline fun Short.minus(other: Short): Short = (this - other).toShort()
override inline fun Short.times(b: Short): Short = (this * b).toShort() override inline fun Short.times(other: Short): Short = (this * other).toShort()
} }
public val Short.Companion.algebra: ShortRing get() = ShortRing public val Short.Companion.algebra: ShortRing get() = ShortRing
@ -223,14 +223,14 @@ public object ByteRing : Ring<Byte>, Norm<Byte, Byte>, NumericAlgebra<Byte> {
get() = 1 get() = 1
override fun number(value: Number): Byte = value.toByte() override fun number(value: Number): Byte = value.toByte()
override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() override inline fun add(left: Byte, right: Byte): Byte = (left + right).toByte()
override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte() override inline fun multiply(left: Byte, right: Byte): Byte = (left * right).toByte()
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
override inline fun Byte.unaryMinus(): Byte = (-this).toByte() override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte() override inline fun Byte.plus(other: Byte): Byte = (this + other).toByte()
override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte() override inline fun Byte.minus(other: Byte): Byte = (this - other).toByte()
override inline fun Byte.times(b: Byte): Byte = (this * b).toByte() override inline fun Byte.times(other: Byte): Byte = (this * other).toByte()
} }
public val Byte.Companion.algebra: ByteRing get() = ByteRing public val Byte.Companion.algebra: ByteRing get() = ByteRing
@ -247,14 +247,14 @@ public object LongRing : Ring<Long>, Norm<Long, Long>, NumericAlgebra<Long> {
get() = 1L get() = 1L
override fun number(value: Number): Long = value.toLong() override fun number(value: Number): Long = value.toLong()
override inline fun add(a: Long, b: Long): Long = a + b override inline fun add(left: Long, right: Long): Long = left + right
override inline fun multiply(a: Long, b: Long): Long = a * b override inline fun multiply(left: Long, right: Long): Long = left * right
override fun norm(arg: Long): Long = abs(arg) override fun norm(arg: Long): Long = abs(arg)
override inline fun Long.unaryMinus(): Long = (-this) override inline fun Long.unaryMinus(): Long = (-this)
override inline fun Long.plus(b: Long): Long = (this + b) override inline fun Long.plus(other: Long): Long = (this + other)
override inline fun Long.minus(b: Long): Long = (this - b) override inline fun Long.minus(other: Long): Long = (this - other)
override inline fun Long.times(b: Long): Long = (this * b) override inline fun Long.times(other: Long): Long = (this * other)
} }
public val Long.Companion.algebra: LongRing get() = LongRing public val Long.Companion.algebra: LongRing get() = LongRing

View File

@ -18,9 +18,9 @@ public object JBigIntegerField : Ring<BigInteger>, NumericAlgebra<BigInteger> {
override val one: BigInteger get() = BigInteger.ONE override val one: BigInteger get() = BigInteger.ONE
override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong()) override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong())
override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b) override fun add(left: BigInteger, right: BigInteger): BigInteger = left.add(right)
override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b) override operator fun BigInteger.minus(other: BigInteger): BigInteger = subtract(other)
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b) override fun multiply(left: BigInteger, right: BigInteger): BigInteger = left.multiply(right)
override operator fun BigInteger.unaryMinus(): BigInteger = negate() override operator fun BigInteger.unaryMinus(): BigInteger = negate()
} }
@ -39,15 +39,15 @@ public abstract class JBigDecimalFieldBase internal constructor(
override val one: BigDecimal override val one: BigDecimal
get() = BigDecimal.ONE get() = BigDecimal.ONE
override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b) override fun add(left: BigDecimal, right: BigDecimal): BigDecimal = left.add(right)
override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b) override operator fun BigDecimal.minus(other: BigDecimal): BigDecimal = subtract(other)
override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble()) override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble())
override fun scale(a: BigDecimal, value: Double): BigDecimal = override fun scale(a: BigDecimal, value: Double): BigDecimal =
a.multiply(value.toBigDecimal(mathContext), mathContext) a.multiply(value.toBigDecimal(mathContext), mathContext)
override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext) override fun multiply(left: BigDecimal, right: BigDecimal): BigDecimal = left.multiply(right, mathContext)
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext) override fun divide(left: BigDecimal, right: BigDecimal): BigDecimal = left.divide(right, mathContext)
override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext) override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext)
override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext) override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext)
override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext)

View File

@ -104,12 +104,12 @@ public class PolynomialSpace<T, C>(
Polynomial(coefficients.map { -it }) Polynomial(coefficients.map { -it })
} }
override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> { override fun add(left: Polynomial<T>, right: Polynomial<T>): Polynomial<T> {
val dim = max(a.coefficients.size, b.coefficients.size) val dim = max(left.coefficients.size, right.coefficients.size)
return ring { return ring {
Polynomial(List(dim) { index -> Polynomial(List(dim) { index ->
a.coefficients.getOrElse(index) { zero } + b.coefficients.getOrElse(index) { zero } left.coefficients.getOrElse(index) { zero } + right.coefficients.getOrElse(index) { zero }
}) })
} }
} }

View File

@ -47,7 +47,7 @@ public object Euclidean2DSpace : GeometrySpace<Vector2D>, ScaleOperations<Vector
override fun Vector2D.unaryMinus(): Vector2D = Vector2D(-x, -y) override fun Vector2D.unaryMinus(): Vector2D = Vector2D(-x, -y)
override fun Vector2D.distanceTo(other: Vector2D): Double = (this - other).norm() override fun Vector2D.distanceTo(other: Vector2D): Double = (this - other).norm()
override fun add(a: Vector2D, b: Vector2D): Vector2D = Vector2D(a.x + b.x, a.y + b.y) override fun add(left: Vector2D, right: Vector2D): Vector2D = Vector2D(left.x + right.x, left.y + right.y)
override fun scale(a: Vector2D, value: Double): Vector2D = Vector2D(a.x * value, a.y * value) override fun scale(a: Vector2D, value: Double): Vector2D = Vector2D(a.x * value, a.y * value)
override fun Vector2D.dot(other: Vector2D): Double = x * other.x + y * other.y override fun Vector2D.dot(other: Vector2D): Double = x * other.x + y * other.y
} }

View File

@ -47,8 +47,8 @@ public object Euclidean3DSpace : GeometrySpace<Vector3D>, ScaleOperations<Vector
override fun Vector3D.distanceTo(other: Vector3D): Double = (this - other).norm() override fun Vector3D.distanceTo(other: Vector3D): Double = (this - other).norm()
override fun add(a: Vector3D, b: Vector3D): Vector3D = override fun add(left: Vector3D, right: Vector3D): Vector3D =
Vector3D(a.x + b.x, a.y + b.y, a.z + b.z) Vector3D(left.x + right.x, left.y + right.y, left.z + right.z)
override fun scale(a: Vector3D, value: Double): Vector3D = override fun scale(a: Vector3D, value: Double): Vector3D =
Vector3D(a.x * value, a.y * value, a.z * value) Vector3D(a.x * value, a.y * value, a.z * value)

View File

@ -67,10 +67,10 @@ public interface IndexedHistogramSpace<T : Comparable<T>, V : Any>
public fun produce(builder: HistogramBuilder<T>.() -> Unit): IndexedHistogram<T, V> public fun produce(builder: HistogramBuilder<T>.() -> Unit): IndexedHistogram<T, V>
override fun add(a: IndexedHistogram<T, V>, b: IndexedHistogram<T, V>): IndexedHistogram<T, V> { override fun add(left: IndexedHistogram<T, V>, right: IndexedHistogram<T, V>): IndexedHistogram<T, V> {
require(a.context == this) { "Can't operate on a histogram produced by external space" } require(left.context == this) { "Can't operate on a histogram produced by external space" }
require(b.context == this) { "Can't operate on a histogram produced by external space" } require(right.context == this) { "Can't operate on a histogram produced by external space" }
return IndexedHistogram(this, histogramValueSpace { a.values + b.values }) return IndexedHistogram(this, histogramValueSpace { left.values + right.values })
} }
override fun scale(a: IndexedHistogram<T, V>, value: Double): IndexedHistogram<T, V> { override fun scale(a: IndexedHistogram<T, V>, value: Double): IndexedHistogram<T, V> {

View File

@ -88,20 +88,20 @@ public class TreeHistogramSpace(
TreeHistogramBuilder(binFactory).apply(block).build() TreeHistogramBuilder(binFactory).apply(block).build()
override fun add( override fun add(
a: UnivariateHistogram, left: UnivariateHistogram,
b: UnivariateHistogram, right: UnivariateHistogram,
): UnivariateHistogram { ): UnivariateHistogram {
// require(a.context == this) { "Histogram $a does not belong to this context" } // require(a.context == this) { "Histogram $a does not belong to this context" }
// require(b.context == this) { "Histogram $b does not belong to this context" } // require(b.context == this) { "Histogram $b does not belong to this context" }
val bins = TreeMap<Double, UnivariateBin>().apply { val bins = TreeMap<Double, UnivariateBin>().apply {
(a.bins.map { it.domain } union b.bins.map { it.domain }).forEach { def -> (left.bins.map { it.domain } union right.bins.map { it.domain }).forEach { def ->
put( put(
def.center, def.center,
UnivariateBin( UnivariateBin(
def, def,
value = (a[def.center]?.value ?: 0.0) + (b[def.center]?.value ?: 0.0), value = (left[def.center]?.value ?: 0.0) + (right[def.center]?.value ?: 0.0),
standardDeviation = (a[def.center]?.standardDeviation standardDeviation = (left[def.center]?.standardDeviation
?: 0.0) + (b[def.center]?.standardDeviation ?: 0.0) ?: 0.0) + (right[def.center]?.standardDeviation ?: 0.0)
) )
) )
} }

View File

@ -28,10 +28,10 @@ public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, S
else -> super<ExtendedField>.binaryOperationFunction(operation) else -> super<ExtendedField>.binaryOperationFunction(operation)
} }
override inline fun add(a: Double, b: Double): Double = a + b override inline fun add(left: Double, right: Double): Double = left + right
override inline fun multiply(a: Double, b: Double): Double = a * b override inline fun multiply(left: Double, right: Double): Double = left * right
override inline fun divide(a: Double, b: Double): Double = a / b override inline fun divide(left: Double, right: Double): Double = left / right
override inline fun scale(a: Double, value: Double): Double = a * value override inline fun scale(a: Double, value: Double): Double = a * value
@ -57,10 +57,10 @@ public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, S
override inline fun norm(arg: Double): Double = FastMath.abs(arg) override inline fun norm(arg: Double): Double = FastMath.abs(arg)
override inline fun Double.unaryMinus(): Double = -this override inline fun Double.unaryMinus(): Double = -this
override inline fun Double.plus(b: Double): Double = this + b override inline fun Double.plus(other: Double): Double = this + other
override inline fun Double.minus(b: Double): Double = this - b override inline fun Double.minus(other: Double): Double = this - other
override inline fun Double.times(b: Double): Double = this * b override inline fun Double.times(other: Double): Double = this * other
override inline fun Double.div(b: Double): Double = this / b override inline fun Double.div(other: Double): Double = this / other
} }
/** /**
@ -79,10 +79,10 @@ public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Doub
else -> super<ExtendedField>.binaryOperationFunction(operation) else -> super<ExtendedField>.binaryOperationFunction(operation)
} }
override inline fun add(a: Double, b: Double): Double = a + b override inline fun add(left: Double, right: Double): Double = left + right
override inline fun multiply(a: Double, b: Double): Double = a * b override inline fun multiply(left: Double, right: Double): Double = left * right
override inline fun divide(a: Double, b: Double): Double = a / b override inline fun divide(left: Double, right: Double): Double = left / right
override inline fun scale(a: Double, value: Double): Double = a * value override inline fun scale(a: Double, value: Double): Double = a * value
@ -108,8 +108,8 @@ public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Doub
override inline fun norm(arg: Double): Double = StrictFastMath.abs(arg) override inline fun norm(arg: Double): Double = StrictFastMath.abs(arg)
override inline fun Double.unaryMinus(): Double = -this override inline fun Double.unaryMinus(): Double = -this
override inline fun Double.plus(b: Double): Double = this + b override inline fun Double.plus(other: Double): Double = this + other
override inline fun Double.minus(b: Double): Double = this - b override inline fun Double.minus(other: Double): Double = this - other
override inline fun Double.times(b: Double): Double = this * b override inline fun Double.times(other: Double): Double = this * other
override inline fun Double.div(b: Double): Double = this / b override inline fun Double.div(other: Double): Double = this / other
} }

View File

@ -72,11 +72,11 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
*/ */
public sealed interface Nd4jArrayGroupOps<T, out S : Ring<T>> : GroupOpsND<T, S>, Nd4jArrayAlgebra<T, S> { public sealed interface Nd4jArrayGroupOps<T, out S : Ring<T>> : GroupOpsND<T, S>, Nd4jArrayAlgebra<T, S> {
override fun add(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> = override fun add(left: StructureND<T>, right: StructureND<T>): Nd4jArrayStructure<T> =
a.ndArray.add(b.ndArray).wrap() left.ndArray.add(right.ndArray).wrap()
override operator fun StructureND<T>.minus(b: StructureND<T>): Nd4jArrayStructure<T> = override operator fun StructureND<T>.minus(other: StructureND<T>): Nd4jArrayStructure<T> =
ndArray.sub(b.ndArray).wrap() ndArray.sub(other.ndArray).wrap()
override operator fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> = override operator fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> =
ndArray.neg().wrap() ndArray.neg().wrap()
@ -94,8 +94,8 @@ public sealed interface Nd4jArrayGroupOps<T, out S : Ring<T>> : GroupOpsND<T, S>
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public sealed interface Nd4jArrayRingOps<T, out R : Ring<T>> : RingOpsND<T, R>, Nd4jArrayGroupOps<T, R> { public sealed interface Nd4jArrayRingOps<T, out R : Ring<T>> : RingOpsND<T, R>, Nd4jArrayGroupOps<T, R> {
override fun multiply(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> = override fun multiply(left: StructureND<T>, right: StructureND<T>): Nd4jArrayStructure<T> =
a.ndArray.mul(b.ndArray).wrap() left.ndArray.mul(right.ndArray).wrap()
// //
// override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> { // override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> {
// check(this) // check(this)
@ -132,8 +132,8 @@ public sealed interface Nd4jArrayRingOps<T, out R : Ring<T>> : RingOpsND<T, R>,
*/ */
public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldOpsND<T, F>, Nd4jArrayRingOps<T, F> { public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldOpsND<T, F>, Nd4jArrayRingOps<T, F> {
override fun divide(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> = override fun divide(left: StructureND<T>, right: StructureND<T>): Nd4jArrayStructure<T> =
a.ndArray.div(b.ndArray).wrap() left.ndArray.div(right.ndArray).wrap()
public operator fun Number.div(b: StructureND<T>): Nd4jArrayStructure<T> = b.ndArray.rdiv(this).wrap() public operator fun Number.div(b: StructureND<T>): Nd4jArrayStructure<T> = b.ndArray.rdiv(this).wrap()

View File

@ -41,8 +41,8 @@ public class SamplerSpace<T : Any, out S>(public val algebra: S) : Group<Sampler
override val zero: Sampler<T> = ConstantSampler(algebra.zero) override val zero: Sampler<T> = ConstantSampler(algebra.zero)
override fun add(a: Sampler<T>, b: Sampler<T>): Sampler<T> = BasicSampler { generator -> override fun add(left: Sampler<T>, right: Sampler<T>): Sampler<T> = BasicSampler { generator ->
a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> algebra { aValue + bValue } } left.sample(generator).zip(right.sample(generator)) { aValue, bValue -> algebra { aValue + bValue } }
} }
override fun scale(a: Sampler<T>, value: Double): Sampler<T> = BasicSampler { generator -> override fun scale(a: Sampler<T>, value: Double): Sampler<T> = BasicSampler { generator ->

View File

@ -5,7 +5,7 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.RingOps
/** /**
* Algebra over a ring on [Tensor]. * Algebra over a ring on [Tensor].
@ -13,7 +13,7 @@ import space.kscience.kmath.operations.Algebra
* *
* @param T the type of items in the tensors. * @param T the type of items in the tensors.
*/ */
public interface TensorAlgebra<T> : Algebra<Tensor<T>> { public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
/** /**
* Returns a single tensor value of unit dimension if tensor shape equals to [1]. * Returns a single tensor value of unit dimension if tensor shape equals to [1].
* *
@ -53,7 +53,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
* @param other tensor to be added. * @param other tensor to be added.
* @return the sum of this tensor and [other]. * @return the sum of this tensor and [other].
*/ */
public operator fun Tensor<T>.plus(other: Tensor<T>): Tensor<T> override fun Tensor<T>.plus(other: Tensor<T>): Tensor<T>
/** /**
* Adds the scalar [value] to each element of this tensor. * Adds the scalar [value] to each element of this tensor.
@ -93,7 +93,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
* @param other tensor to be subtracted. * @param other tensor to be subtracted.
* @return the difference between this tensor and [other]. * @return the difference between this tensor and [other].
*/ */
public operator fun Tensor<T>.minus(other: Tensor<T>): Tensor<T> override fun Tensor<T>.minus(other: Tensor<T>): Tensor<T>
/** /**
* Subtracts the scalar [value] from each element of this tensor. * Subtracts the scalar [value] from each element of this tensor.
@ -134,7 +134,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
* @param other tensor to be multiplied. * @param other tensor to be multiplied.
* @return the product of this tensor and [other]. * @return the product of this tensor and [other].
*/ */
public operator fun Tensor<T>.times(other: Tensor<T>): Tensor<T> override fun Tensor<T>.times(other: Tensor<T>): Tensor<T>
/** /**
* Multiplies the scalar [value] by each element of this tensor. * Multiplies the scalar [value] by each element of this tensor.
@ -155,7 +155,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
* *
* @return tensor negation of the original tensor. * @return tensor negation of the original tensor.
*/ */
public operator fun Tensor<T>.unaryMinus(): Tensor<T> override fun Tensor<T>.unaryMinus(): Tensor<T>
/** /**
* Returns the tensor at index i * Returns the tensor at index i
@ -323,4 +323,8 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
* @return the index of maximum value of each row of the input tensor in the given dimension [dim]. * @return the index of maximum value of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
override fun add(left: Tensor<T>, right: Tensor<T>): Tensor<T> = left + right
override fun multiply(left: Tensor<T>, right: Tensor<T>): Tensor<T> = left * right
} }

View File

@ -27,7 +27,7 @@ internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> =
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) { internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this is BufferedTensor<T> -> this
is MutableBufferND<T> -> if (this.indexes == TensorLinearStructure(this.shape)) { is MutableBufferND<T> -> if (this.indexes == TensorLinearStructure(this.shape)) {
BufferedTensor(this.shape, this.mutableBuffer, 0) BufferedTensor(this.shape, this.buffer, 0)
} else { } else {
this.copyToBufferedTensor() this.copyToBufferedTensor()
} }

View File

@ -88,17 +88,17 @@ public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, Doubl
} }
}.asStructure() }.asStructure()
override fun add(a: StructureND<Double>, b: StructureND<Double>): ViktorStructureND = override fun add(left: StructureND<Double>, right: StructureND<Double>): ViktorStructureND =
(a.f64Buffer + b.f64Buffer).asStructure() (left.f64Buffer + right.f64Buffer).asStructure()
override fun scale(a: StructureND<Double>, value: Double): ViktorStructureND = override fun scale(a: StructureND<Double>, value: Double): ViktorStructureND =
(a.f64Buffer * value).asStructure() (a.f64Buffer * value).asStructure()
override inline fun StructureND<Double>.plus(b: StructureND<Double>): ViktorStructureND = override inline fun StructureND<Double>.plus(other: StructureND<Double>): ViktorStructureND =
(f64Buffer + b.f64Buffer).asStructure() (f64Buffer + other.f64Buffer).asStructure()
override inline fun StructureND<Double>.minus(b: StructureND<Double>): ViktorStructureND = override inline fun StructureND<Double>.minus(other: StructureND<Double>): ViktorStructureND =
(f64Buffer - b.f64Buffer).asStructure() (f64Buffer - other.f64Buffer).asStructure()
override inline fun StructureND<Double>.times(k: Number): ViktorStructureND = override inline fun StructureND<Double>.times(k: Number): ViktorStructureND =
(f64Buffer * k.toDouble()).asStructure() (f64Buffer * k.toDouble()).asStructure()