forked from kscience/kmath
Shapeless ND and Buffer algebras
This commit is contained in:
parent
d0354da80a
commit
688382eed6
@ -57,12 +57,12 @@ internal class NDFieldBenchmark {
|
|||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
// @Benchmark
|
||||||
fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) {
|
// fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) {
|
||||||
var res: StructureND<Double> = one(dim, dim)
|
// var res: StructureND<Double> = one(dim, dim)
|
||||||
repeat(n) { res += 1.0 }
|
// repeat(n) { res += 1.0 }
|
||||||
blackhole.consume(res)
|
// blackhole.consume(res)
|
||||||
}
|
// }
|
||||||
|
|
||||||
private companion object {
|
private companion object {
|
||||||
private const val dim = 1000
|
private const val dim = 1000
|
||||||
|
@ -9,6 +9,7 @@ import space.kscience.kmath.integration.gaussIntegrator
|
|||||||
import space.kscience.kmath.integration.integrate
|
import space.kscience.kmath.integration.integrate
|
||||||
import space.kscience.kmath.integration.value
|
import space.kscience.kmath.integration.value
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
|
import space.kscience.kmath.nd.produce
|
||||||
import space.kscience.kmath.nd.withNdAlgebra
|
import space.kscience.kmath.nd.withNdAlgebra
|
||||||
import space.kscience.kmath.operations.algebra
|
import space.kscience.kmath.operations.algebra
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
|
@ -12,6 +12,7 @@ import space.kscience.kmath.linear.transpose
|
|||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
import space.kscience.kmath.nd.as2D
|
import space.kscience.kmath.nd.as2D
|
||||||
import space.kscience.kmath.nd.ndAlgebra
|
import space.kscience.kmath.nd.ndAlgebra
|
||||||
|
import space.kscience.kmath.nd.produce
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
|
@ -8,10 +8,8 @@ package space.kscience.kmath.structures
|
|||||||
import kotlinx.coroutines.DelicateCoroutinesApi
|
import kotlinx.coroutines.DelicateCoroutinesApi
|
||||||
import kotlinx.coroutines.GlobalScope
|
import kotlinx.coroutines.GlobalScope
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.nd.autoNdAlgebra
|
import space.kscience.kmath.nd4j.nd4j
|
||||||
import space.kscience.kmath.nd.ndAlgebra
|
|
||||||
import space.kscience.kmath.nd4j.Nd4jArrayField
|
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.viktor.ViktorFieldND
|
import space.kscience.kmath.viktor.ViktorFieldND
|
||||||
@ -31,15 +29,17 @@ fun main() {
|
|||||||
Nd4j.zeros(0)
|
Nd4j.zeros(0)
|
||||||
val dim = 1000
|
val dim = 1000
|
||||||
val n = 1000
|
val n = 1000
|
||||||
|
val shape = Shape(dim, dim)
|
||||||
|
|
||||||
|
|
||||||
// automatically build context most suited for given type.
|
// automatically build context most suited for given type.
|
||||||
val autoField = DoubleField.autoNdAlgebra(dim, dim)
|
val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
|
||||||
// specialized nd-field for Double. It works as generic Double field as well.
|
// specialized nd-field for Double. It works as generic Double field as well.
|
||||||
val realField = DoubleField.ndAlgebra(dim, dim)
|
val realField = DoubleField.ndAlgebra
|
||||||
//A generic boxing field. It should be used for objects, not primitives.
|
//A generic boxing field. It should be used for objects, not primitives.
|
||||||
val boxingField = DoubleField.ndAlgebra(Buffer.Companion::boxing, dim, dim)
|
val boxingField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing)
|
||||||
// Nd4j specialized field.
|
// Nd4j specialized field.
|
||||||
val nd4jField = Nd4jArrayField.real(dim, dim)
|
val nd4jField = DoubleField.nd4j
|
||||||
//viktor field
|
//viktor field
|
||||||
val viktorField = ViktorFieldND(dim, dim)
|
val viktorField = ViktorFieldND(dim, dim)
|
||||||
//parallel processing based on Java Streams
|
//parallel processing based on Java Streams
|
||||||
@ -47,21 +47,21 @@ fun main() {
|
|||||||
|
|
||||||
measureAndPrint("Boxing addition") {
|
measureAndPrint("Boxing addition") {
|
||||||
boxingField {
|
boxingField {
|
||||||
var res: StructureND<Double> = one
|
var res: StructureND<Double> = one(shape)
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
measureAndPrint("Specialized addition") {
|
measureAndPrint("Specialized addition") {
|
||||||
realField {
|
realField {
|
||||||
var res: StructureND<Double> = one
|
var res: StructureND<Double> = one(shape)
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
measureAndPrint("Nd4j specialized addition") {
|
measureAndPrint("Nd4j specialized addition") {
|
||||||
nd4jField {
|
nd4jField {
|
||||||
var res: StructureND<Double> = one
|
var res: StructureND<Double> = one(shape)
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -82,13 +82,13 @@ fun main() {
|
|||||||
|
|
||||||
measureAndPrint("Automatic field addition") {
|
measureAndPrint("Automatic field addition") {
|
||||||
autoField {
|
autoField {
|
||||||
var res: StructureND<Double> = one
|
var res: StructureND<Double> = one(shape)
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
measureAndPrint("Lazy addition") {
|
measureAndPrint("Lazy addition") {
|
||||||
val res = realField.one.mapAsync(GlobalScope) {
|
val res = realField.one(shape).mapAsync(GlobalScope) {
|
||||||
var c = 0.0
|
var c = 0.0
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
c += 1.0
|
c += 1.0
|
||||||
|
@ -22,12 +22,12 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
|
|||||||
|
|
||||||
private val strides = DefaultStrides(shape)
|
private val strides = DefaultStrides(shape)
|
||||||
override val elementAlgebra: DoubleField get() = DoubleField
|
override val elementAlgebra: DoubleField get() = DoubleField
|
||||||
override val zero: BufferND<Double> by lazy { produce { zero } }
|
override val zero: BufferND<Double> by lazy { produce(shape) { zero } }
|
||||||
override val one: BufferND<Double> by lazy { produce { one } }
|
override val one: BufferND<Double> by lazy { produce(shape) { one } }
|
||||||
|
|
||||||
override fun number(value: Number): BufferND<Double> {
|
override fun number(value: Number): BufferND<Double> {
|
||||||
val d = value.toDouble() // minimize conversions
|
val d = value.toDouble() // minimize conversions
|
||||||
return produce { d }
|
return produce(shape) { d }
|
||||||
}
|
}
|
||||||
|
|
||||||
private val StructureND<Double>.buffer: DoubleBuffer
|
private val StructureND<Double>.buffer: DoubleBuffer
|
||||||
@ -40,7 +40,7 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
|
|||||||
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
|
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
|
override fun produce(shape: Shape, initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
|
||||||
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
|
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
|
||||||
val index = strides.index(offset)
|
val index = strides.index(offset)
|
||||||
DoubleField.initializer(index)
|
DoubleField.initializer(index)
|
||||||
@ -70,12 +70,12 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun zip(
|
override fun zip(
|
||||||
a: StructureND<Double>,
|
left: StructureND<Double>,
|
||||||
b: StructureND<Double>,
|
right: StructureND<Double>,
|
||||||
transform: DoubleField.(Double, Double) -> Double,
|
transform: DoubleField.(Double, Double) -> Double,
|
||||||
): BufferND<Double> {
|
): BufferND<Double> {
|
||||||
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
|
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
|
||||||
DoubleField.transform(a.buffer.array[offset], b.buffer.array[offset])
|
DoubleField.transform(left.buffer.array[offset], right.buffer.array[offset])
|
||||||
}.toArray()
|
}.toArray()
|
||||||
return BufferND(strides, array.asBuffer())
|
return BufferND(strides, array.asBuffer())
|
||||||
}
|
}
|
||||||
|
@ -70,12 +70,12 @@ public class DerivativeStructureField(
|
|||||||
|
|
||||||
override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate()
|
override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate()
|
||||||
|
|
||||||
override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
override fun add(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.add(right)
|
||||||
|
|
||||||
override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value)
|
override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value)
|
||||||
|
|
||||||
override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b)
|
override fun multiply(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.multiply(right)
|
||||||
override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
|
override fun divide(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.divide(right)
|
||||||
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
|
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
|
||||||
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
|
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
|
||||||
override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
|
override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
|
||||||
|
@ -77,33 +77,33 @@ public object ComplexField :
|
|||||||
|
|
||||||
override fun scale(a: Complex, value: Double): Complex = Complex(a.re * value, a.im * value)
|
override fun scale(a: Complex, value: Double): Complex = Complex(a.re * value, a.im * value)
|
||||||
|
|
||||||
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im)
|
override fun add(left: Complex, right: Complex): Complex = Complex(left.re + right.re, left.im + right.im)
|
||||||
// override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble())
|
// override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble())
|
||||||
|
|
||||||
override fun multiply(a: Complex, b: Complex): Complex =
|
override fun multiply(left: Complex, right: Complex): Complex =
|
||||||
Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
|
Complex(left.re * right.re - left.im * right.im, left.re * right.im + left.im * right.re)
|
||||||
|
|
||||||
override fun divide(a: Complex, b: Complex): Complex = when {
|
override fun divide(left: Complex, right: Complex): Complex = when {
|
||||||
abs(b.im) < abs(b.re) -> {
|
abs(right.im) < abs(right.re) -> {
|
||||||
val wr = b.im / b.re
|
val wr = right.im / right.re
|
||||||
val wd = b.re + wr * b.im
|
val wd = right.re + wr * right.im
|
||||||
|
|
||||||
if (wd.isNaN() || wd == 0.0)
|
if (wd.isNaN() || wd == 0.0)
|
||||||
throw ArithmeticException("Division by zero or infinity")
|
throw ArithmeticException("Division by zero or infinity")
|
||||||
else
|
else
|
||||||
Complex((a.re + a.im * wr) / wd, (a.im - a.re * wr) / wd)
|
Complex((left.re + left.im * wr) / wd, (left.im - left.re * wr) / wd)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.im == 0.0 -> throw ArithmeticException("Division by zero")
|
right.im == 0.0 -> throw ArithmeticException("Division by zero")
|
||||||
|
|
||||||
else -> {
|
else -> {
|
||||||
val wr = b.re / b.im
|
val wr = right.re / right.im
|
||||||
val wd = b.im + wr * b.re
|
val wd = right.im + wr * right.re
|
||||||
|
|
||||||
if (wd.isNaN() || wd == 0.0)
|
if (wd.isNaN() || wd == 0.0)
|
||||||
throw ArithmeticException("Division by zero or infinity")
|
throw ArithmeticException("Division by zero or infinity")
|
||||||
else
|
else
|
||||||
Complex((a.re * wr + a.im) / wd, (a.im * wr - a.re) / wd)
|
Complex((left.re * wr + left.im) / wd, (left.im * wr - left.re) / wd)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,27 +63,27 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
|
|||||||
*/
|
*/
|
||||||
public val k: Quaternion = Quaternion(0, 0, 0, 1)
|
public val k: Quaternion = Quaternion(0, 0, 0, 1)
|
||||||
|
|
||||||
override fun add(a: Quaternion, b: Quaternion): Quaternion =
|
override fun add(left: Quaternion, right: Quaternion): Quaternion =
|
||||||
Quaternion(a.w + b.w, a.x + b.x, a.y + b.y, a.z + b.z)
|
Quaternion(left.w + right.w, left.x + right.x, left.y + right.y, left.z + right.z)
|
||||||
|
|
||||||
override fun scale(a: Quaternion, value: Double): Quaternion =
|
override fun scale(a: Quaternion, value: Double): Quaternion =
|
||||||
Quaternion(a.w * value, a.x * value, a.y * value, a.z * value)
|
Quaternion(a.w * value, a.x * value, a.y * value, a.z * value)
|
||||||
|
|
||||||
override fun multiply(a: Quaternion, b: Quaternion): Quaternion = Quaternion(
|
override fun multiply(left: Quaternion, right: Quaternion): Quaternion = Quaternion(
|
||||||
a.w * b.w - a.x * b.x - a.y * b.y - a.z * b.z,
|
left.w * right.w - left.x * right.x - left.y * right.y - left.z * right.z,
|
||||||
a.w * b.x + a.x * b.w + a.y * b.z - a.z * b.y,
|
left.w * right.x + left.x * right.w + left.y * right.z - left.z * right.y,
|
||||||
a.w * b.y - a.x * b.z + a.y * b.w + a.z * b.x,
|
left.w * right.y - left.x * right.z + left.y * right.w + left.z * right.x,
|
||||||
a.w * b.z + a.x * b.y - a.y * b.x + a.z * b.w,
|
left.w * right.z + left.x * right.y - left.y * right.x + left.z * right.w,
|
||||||
)
|
)
|
||||||
|
|
||||||
override fun divide(a: Quaternion, b: Quaternion): Quaternion {
|
override fun divide(left: Quaternion, right: Quaternion): Quaternion {
|
||||||
val s = b.w * b.w + b.x * b.x + b.y * b.y + b.z * b.z
|
val s = right.w * right.w + right.x * right.x + right.y * right.y + right.z * right.z
|
||||||
|
|
||||||
return Quaternion(
|
return Quaternion(
|
||||||
(b.w * a.w + b.x * a.x + b.y * a.y + b.z * a.z) / s,
|
(right.w * left.w + right.x * left.x + right.y * left.y + right.z * left.z) / s,
|
||||||
(b.w * a.x - b.x * a.w - b.y * a.z + b.z * a.y) / s,
|
(right.w * left.x - right.x * left.w - right.y * left.z + right.z * left.y) / s,
|
||||||
(b.w * a.y + b.x * a.z - b.y * a.w - b.z * a.x) / s,
|
(right.w * left.y + right.x * left.z - right.y * left.w - right.z * left.x) / s,
|
||||||
(b.w * a.z - b.x * a.y + b.y * a.x - b.z * a.w) / s,
|
(right.w * left.z - right.x * left.y + right.y * left.x - right.z * left.w) / s,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,8 +57,8 @@ public open class FunctionalExpressionGroup<T, out A : Group<T>>(
|
|||||||
/**
|
/**
|
||||||
* Builds an Expression of addition of two another expressions.
|
* Builds an Expression of addition of two another expressions.
|
||||||
*/
|
*/
|
||||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
override fun add(left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
binaryOperation(GroupOps.PLUS_OPERATION, a, b)
|
binaryOperation(GroupOps.PLUS_OPERATION, left, right)
|
||||||
|
|
||||||
// /**
|
// /**
|
||||||
// * Builds an Expression of multiplication of expression by number.
|
// * Builds an Expression of multiplication of expression by number.
|
||||||
@ -88,8 +88,8 @@ public open class FunctionalExpressionRing<T, out A : Ring<T>>(
|
|||||||
/**
|
/**
|
||||||
* Builds an Expression of multiplication of two expressions.
|
* Builds an Expression of multiplication of two expressions.
|
||||||
*/
|
*/
|
||||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
override fun multiply(left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
binaryOperationFunction(RingOps.TIMES_OPERATION)(a, b)
|
binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right)
|
||||||
|
|
||||||
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
||||||
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
||||||
@ -107,8 +107,8 @@ public open class FunctionalExpressionField<T, out A : Field<T>>(
|
|||||||
/**
|
/**
|
||||||
* Builds an Expression of division an expression by another one.
|
* Builds an Expression of division an expression by another one.
|
||||||
*/
|
*/
|
||||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
override fun divide(left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
binaryOperationFunction(FieldOps.DIV_OPERATION)(a, b)
|
binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right)
|
||||||
|
|
||||||
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
||||||
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
||||||
|
@ -31,15 +31,15 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
|
|||||||
|
|
||||||
override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value)
|
override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value)
|
||||||
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
||||||
override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(GroupOps.PLUS_OPERATION)(a, b)
|
override fun add(left: MST, right: MST): MST.Binary = binaryOperationFunction(GroupOps.PLUS_OPERATION)(left, right)
|
||||||
override operator fun MST.unaryPlus(): MST.Unary =
|
override operator fun MST.unaryPlus(): MST.Unary =
|
||||||
unaryOperationFunction(GroupOps.PLUS_OPERATION)(this)
|
unaryOperationFunction(GroupOps.PLUS_OPERATION)(this)
|
||||||
|
|
||||||
override operator fun MST.unaryMinus(): MST.Unary =
|
override operator fun MST.unaryMinus(): MST.Unary =
|
||||||
unaryOperationFunction(GroupOps.MINUS_OPERATION)(this)
|
unaryOperationFunction(GroupOps.MINUS_OPERATION)(this)
|
||||||
|
|
||||||
override operator fun MST.minus(b: MST): MST.Binary =
|
override operator fun MST.minus(other: MST): MST.Binary =
|
||||||
binaryOperationFunction(GroupOps.MINUS_OPERATION)(this, b)
|
binaryOperationFunction(GroupOps.MINUS_OPERATION)(this, other)
|
||||||
|
|
||||||
override fun scale(a: MST, value: Double): MST.Binary =
|
override fun scale(a: MST, value: Double): MST.Binary =
|
||||||
binaryOperationFunction(RingOps.TIMES_OPERATION)(a, number(value))
|
binaryOperationFunction(RingOps.TIMES_OPERATION)(a, number(value))
|
||||||
@ -62,17 +62,17 @@ public object MstRing : Ring<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
|
|||||||
|
|
||||||
override fun number(value: Number): MST.Numeric = MstGroup.number(value)
|
override fun number(value: Number): MST.Numeric = MstGroup.number(value)
|
||||||
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
||||||
override fun add(a: MST, b: MST): MST.Binary = MstGroup.add(a, b)
|
override fun add(left: MST, right: MST): MST.Binary = MstGroup.add(left, right)
|
||||||
|
|
||||||
override fun scale(a: MST, value: Double): MST.Binary =
|
override fun scale(a: MST, value: Double): MST.Binary =
|
||||||
MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
|
MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
|
||||||
|
|
||||||
override fun multiply(a: MST, b: MST): MST.Binary =
|
override fun multiply(left: MST, right: MST): MST.Binary =
|
||||||
binaryOperationFunction(RingOps.TIMES_OPERATION)(a, b)
|
binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right)
|
||||||
|
|
||||||
override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus }
|
override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus }
|
||||||
override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus }
|
override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus }
|
||||||
override operator fun MST.minus(b: MST): MST.Binary = MstGroup { this@minus - b }
|
override operator fun MST.minus(other: MST): MST.Binary = MstGroup { this@minus - other }
|
||||||
|
|
||||||
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||||
MstGroup.binaryOperationFunction(operation)
|
MstGroup.binaryOperationFunction(operation)
|
||||||
@ -92,18 +92,18 @@ public object MstField : Field<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
|
|||||||
|
|
||||||
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
||||||
override fun number(value: Number): MST.Numeric = MstRing.number(value)
|
override fun number(value: Number): MST.Numeric = MstRing.number(value)
|
||||||
override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
|
override fun add(left: MST, right: MST): MST.Binary = MstRing.add(left, right)
|
||||||
|
|
||||||
override fun scale(a: MST, value: Double): MST.Binary =
|
override fun scale(a: MST, value: Double): MST.Binary =
|
||||||
MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
|
MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
|
||||||
|
|
||||||
override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
|
override fun multiply(left: MST, right: MST): MST.Binary = MstRing.multiply(left, right)
|
||||||
override fun divide(a: MST, b: MST): MST.Binary =
|
override fun divide(left: MST, right: MST): MST.Binary =
|
||||||
binaryOperationFunction(FieldOps.DIV_OPERATION)(a, b)
|
binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right)
|
||||||
|
|
||||||
override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus }
|
override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus }
|
||||||
override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus }
|
override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus }
|
||||||
override operator fun MST.minus(b: MST): MST.Binary = MstRing { this@minus - b }
|
override operator fun MST.minus(other: MST): MST.Binary = MstRing { this@minus - other }
|
||||||
|
|
||||||
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||||
MstRing.binaryOperationFunction(operation)
|
MstRing.binaryOperationFunction(operation)
|
||||||
@ -134,17 +134,17 @@ public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
|
|||||||
override fun asinh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ASINH_OPERATION)(arg)
|
override fun asinh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ASINH_OPERATION)(arg)
|
||||||
override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ACOSH_OPERATION)(arg)
|
override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ACOSH_OPERATION)(arg)
|
||||||
override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ATANH_OPERATION)(arg)
|
override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ATANH_OPERATION)(arg)
|
||||||
override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
|
override fun add(left: MST, right: MST): MST.Binary = MstField.add(left, right)
|
||||||
override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg)
|
override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg)
|
||||||
|
|
||||||
override fun scale(a: MST, value: Double): MST =
|
override fun scale(a: MST, value: Double): MST =
|
||||||
binaryOperation(GroupOps.PLUS_OPERATION, a, number(value))
|
binaryOperation(GroupOps.PLUS_OPERATION, a, number(value))
|
||||||
|
|
||||||
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
|
override fun multiply(left: MST, right: MST): MST.Binary = MstField.multiply(left, right)
|
||||||
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
|
override fun divide(left: MST, right: MST): MST.Binary = MstField.divide(left, right)
|
||||||
override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus }
|
override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus }
|
||||||
override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus }
|
override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus }
|
||||||
override operator fun MST.minus(b: MST): MST.Binary = MstField { this@minus - b }
|
override operator fun MST.minus(other: MST): MST.Binary = MstField { this@minus - other }
|
||||||
|
|
||||||
override fun power(arg: MST, pow: Number): MST.Binary =
|
override fun power(arg: MST, pow: Number): MST.Binary =
|
||||||
binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow))
|
binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow))
|
||||||
|
@ -168,22 +168,22 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
|
|
||||||
// Basic math (+, -, *, /)
|
// Basic math (+, -, *, /)
|
||||||
|
|
||||||
override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
override fun add(left: AutoDiffValue<T>, right: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { a.value + b.value }) { z ->
|
derive(const { left.value + right.value }) { z ->
|
||||||
a.d += z.d
|
left.d += z.d
|
||||||
b.d += z.d
|
right.d += z.d
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
override fun multiply(left: AutoDiffValue<T>, right: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { a.value * b.value }) { z ->
|
derive(const { left.value * right.value }) { z ->
|
||||||
a.d += z.d * b.value
|
left.d += z.d * right.value
|
||||||
b.d += z.d * a.value
|
right.d += z.d * left.value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
override fun divide(left: AutoDiffValue<T>, right: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { a.value / b.value }) { z ->
|
derive(const { left.value / right.value }) { z ->
|
||||||
a.d += z.d / b.value
|
left.d += z.d / right.value
|
||||||
b.d -= z.d * a.value / (b.value * b.value)
|
right.d -= z.d * left.value / (right.value * right.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun scale(a: AutoDiffValue<T>, value: Double): AutoDiffValue<T> =
|
override fun scale(a: AutoDiffValue<T>, value: Double): AutoDiffValue<T> =
|
||||||
|
@ -100,12 +100,12 @@ public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>,
|
|||||||
/**
|
/**
|
||||||
* Element-wise addition.
|
* Element-wise addition.
|
||||||
*
|
*
|
||||||
* @param a the augend.
|
* @param left the augend.
|
||||||
* @param b the addend.
|
* @param right the addend.
|
||||||
* @return the sum.
|
* @return the sum.
|
||||||
*/
|
*/
|
||||||
override fun add(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
|
override fun add(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
|
||||||
zip(a, b) { aValue, bValue -> add(aValue, bValue) }
|
zip(left, right) { aValue, bValue -> add(aValue, bValue) }
|
||||||
|
|
||||||
// TODO move to extensions after KEEP-176
|
// TODO move to extensions after KEEP-176
|
||||||
|
|
||||||
@ -134,7 +134,7 @@ public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>,
|
|||||||
* @param arg the addend.
|
* @param arg the addend.
|
||||||
* @return the sum.
|
* @return the sum.
|
||||||
*/
|
*/
|
||||||
public operator fun T.plus(arg: StructureND<T>): StructureND<T> = arg.map { value -> add(this@plus, value) }
|
public operator fun T.plus(arg: StructureND<T>): StructureND<T> = arg + this
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subtracts an ND structure from an element of it.
|
* Subtracts an ND structure from an element of it.
|
||||||
@ -162,12 +162,12 @@ public interface RingOpsND<T, out A : RingOps<T>> : RingOps<StructureND<T>>, Gro
|
|||||||
/**
|
/**
|
||||||
* Element-wise multiplication.
|
* Element-wise multiplication.
|
||||||
*
|
*
|
||||||
* @param a the multiplicand.
|
* @param left the multiplicand.
|
||||||
* @param b the multiplier.
|
* @param right the multiplier.
|
||||||
* @return the product.
|
* @return the product.
|
||||||
*/
|
*/
|
||||||
override fun multiply(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
|
override fun multiply(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
|
||||||
zip(a, b) { aValue, bValue -> multiply(aValue, bValue) }
|
zip(left, right) { aValue, bValue -> multiply(aValue, bValue) }
|
||||||
|
|
||||||
//TODO move to extensions after KEEP-176
|
//TODO move to extensions after KEEP-176
|
||||||
|
|
||||||
@ -208,12 +208,12 @@ public interface FieldOpsND<T, out A : Field<T>> : FieldOps<StructureND<T>>, Rin
|
|||||||
/**
|
/**
|
||||||
* Element-wise division.
|
* Element-wise division.
|
||||||
*
|
*
|
||||||
* @param a the dividend.
|
* @param left the dividend.
|
||||||
* @param b the divisor.
|
* @param right the divisor.
|
||||||
* @return the quotient.
|
* @return the quotient.
|
||||||
*/
|
*/
|
||||||
override fun divide(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
|
override fun divide(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
|
||||||
zip(a, b) { aValue, bValue -> divide(aValue, bValue) }
|
zip(left, right) { aValue, bValue -> divide(aValue, bValue) }
|
||||||
|
|
||||||
//TODO move to extensions after https://github.com/Kotlin/KEEP/blob/master/proposals/context-receivers.md
|
//TODO move to extensions after https://github.com/Kotlin/KEEP/blob/master/proposals/context-receivers.md
|
||||||
/**
|
/**
|
||||||
|
@ -73,7 +73,7 @@ internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.zipInline(
|
|||||||
r: BufferND<T>,
|
r: BufferND<T>,
|
||||||
crossinline block: A.(l: T, r: T) -> T
|
crossinline block: A.(l: T, r: T) -> T
|
||||||
): BufferND<T> {
|
): BufferND<T> {
|
||||||
require(l.indexes == r.indexes)
|
require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" }
|
||||||
val indexes = l.indexes
|
val indexes = l.indexes
|
||||||
return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block))
|
return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block))
|
||||||
}
|
}
|
||||||
@ -114,6 +114,10 @@ public fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.produce(
|
|||||||
initializer: A.(IntArray) -> T
|
initializer: A.(IntArray) -> T
|
||||||
): BufferND<T> = produce(shape, initializer)
|
): BufferND<T> = produce(shape, initializer)
|
||||||
|
|
||||||
|
public fun <T, EA : Algebra<T>, A> A.produce(
|
||||||
|
initializer: EA.(IntArray) -> T
|
||||||
|
): BufferND<T> where A : BufferAlgebraND<T, EA>, A : WithShape = produce(shape, initializer)
|
||||||
|
|
||||||
//// group factories
|
//// group factories
|
||||||
//public fun <T, A : Group<T>> A.ndAlgebra(
|
//public fun <T, A : Group<T>> A.ndAlgebra(
|
||||||
// bufferAlgebra: BufferAlgebra<T, A>,
|
// bufferAlgebra: BufferAlgebra<T, A>,
|
||||||
|
@ -20,7 +20,7 @@ import space.kscience.kmath.structures.MutableBufferFactory
|
|||||||
*/
|
*/
|
||||||
public open class BufferND<out T>(
|
public open class BufferND<out T>(
|
||||||
public val indexes: ShapeIndex,
|
public val indexes: ShapeIndex,
|
||||||
public val buffer: Buffer<T>,
|
public open val buffer: Buffer<T>,
|
||||||
) : StructureND<T> {
|
) : StructureND<T> {
|
||||||
|
|
||||||
override operator fun get(index: IntArray): T = buffer[indexes.offset(index)]
|
override operator fun get(index: IntArray): T = buffer[indexes.offset(index)]
|
||||||
@ -55,14 +55,14 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
|
|||||||
*
|
*
|
||||||
* @param T the type of items.
|
* @param T the type of items.
|
||||||
* @param strides The strides to access elements of [MutableBuffer] by linear indices.
|
* @param strides The strides to access elements of [MutableBuffer] by linear indices.
|
||||||
* @param mutableBuffer The underlying buffer.
|
* @param buffer The underlying buffer.
|
||||||
*/
|
*/
|
||||||
public class MutableBufferND<T>(
|
public class MutableBufferND<T>(
|
||||||
strides: ShapeIndex,
|
strides: ShapeIndex,
|
||||||
public val mutableBuffer: MutableBuffer<T>,
|
override val buffer: MutableBuffer<T>,
|
||||||
) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) {
|
) : MutableStructureND<T>, BufferND<T>(strides, buffer) {
|
||||||
override fun set(index: IntArray, value: T) {
|
override fun set(index: IntArray, value: T) {
|
||||||
mutableBuffer[indexes.offset(index)] = value
|
buffer[indexes.offset(index)] = value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,7 +74,7 @@ public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
|
|||||||
crossinline transform: (T) -> R,
|
crossinline transform: (T) -> R,
|
||||||
): MutableBufferND<R> {
|
): MutableBufferND<R> {
|
||||||
return if (this is MutableBufferND<T>)
|
return if (this is MutableBufferND<T>)
|
||||||
MutableBufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(mutableBuffer[it]) })
|
MutableBufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) })
|
||||||
else {
|
else {
|
||||||
val strides = DefaultStrides(shape)
|
val strides = DefaultStrides(shape)
|
||||||
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
||||||
|
@ -10,42 +10,132 @@ import space.kscience.kmath.operations.*
|
|||||||
import space.kscience.kmath.structures.DoubleBuffer
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
|
import kotlin.math.pow
|
||||||
|
|
||||||
|
public class DoubleBufferND(
|
||||||
|
indexes: ShapeIndex,
|
||||||
|
override val buffer: DoubleBuffer,
|
||||||
|
) : BufferND<Double>(indexes, buffer)
|
||||||
|
|
||||||
|
|
||||||
public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(DoubleField.bufferAlgebra),
|
public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(DoubleField.bufferAlgebra),
|
||||||
ScaleOperations<StructureND<Double>>, ExtendedFieldOps<StructureND<Double>> {
|
ScaleOperations<StructureND<Double>>, ExtendedFieldOps<StructureND<Double>> {
|
||||||
|
|
||||||
override fun StructureND<Double>.toBufferND(): BufferND<Double> = when (this) {
|
override fun StructureND<Double>.toBufferND(): DoubleBufferND = when (this) {
|
||||||
is BufferND -> this
|
is DoubleBufferND -> this
|
||||||
else -> {
|
else -> {
|
||||||
val indexer = indexerBuilder(shape)
|
val indexer = indexerBuilder(shape)
|
||||||
BufferND(indexer, DoubleBuffer(indexer.linearSize) { offset -> get(indexer.index(offset)) })
|
DoubleBufferND(indexer, DoubleBuffer(indexer.linearSize) { offset -> get(indexer.index(offset)) })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO do specialization
|
private inline fun mapInline(
|
||||||
|
arg: DoubleBufferND,
|
||||||
|
transform: (Double) -> Double
|
||||||
|
): DoubleBufferND {
|
||||||
|
val indexes = arg.indexes
|
||||||
|
val array = arg.buffer.array
|
||||||
|
return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { transform(array[it]) })
|
||||||
|
}
|
||||||
|
|
||||||
override fun scale(a: StructureND<Double>, value: Double): BufferND<Double> =
|
private inline fun zipInline(
|
||||||
|
l: DoubleBufferND,
|
||||||
|
r: DoubleBufferND,
|
||||||
|
block: (l: Double, r: Double) -> Double
|
||||||
|
): DoubleBufferND {
|
||||||
|
require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" }
|
||||||
|
val indexes = l.indexes
|
||||||
|
val lArray = l.buffer.array
|
||||||
|
val rArray = r.buffer.array
|
||||||
|
return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { block(lArray[it], rArray[it]) })
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.map(transform: DoubleField.(Double) -> Double): BufferND<Double> =
|
||||||
|
mapInline(toBufferND()) { DoubleField.transform(it) }
|
||||||
|
|
||||||
|
|
||||||
|
override fun zip(
|
||||||
|
left: StructureND<Double>,
|
||||||
|
right: StructureND<Double>,
|
||||||
|
transform: DoubleField.(Double, Double) -> Double
|
||||||
|
): BufferND<Double> = zipInline(left.toBufferND(), right.toBufferND()) { l, r -> DoubleField.transform(l, r) }
|
||||||
|
|
||||||
|
override fun produce(shape: Shape, initializer: DoubleField.(IntArray) -> Double): DoubleBufferND {
|
||||||
|
val indexer = indexerBuilder(shape)
|
||||||
|
return DoubleBufferND(
|
||||||
|
indexer,
|
||||||
|
DoubleBuffer(indexer.linearSize) { offset ->
|
||||||
|
elementAlgebra.initializer(indexer.index(offset))
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun add(left: StructureND<Double>, right: StructureND<Double>): DoubleBufferND =
|
||||||
|
zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l + r }
|
||||||
|
|
||||||
|
override fun multiply(left: StructureND<Double>, right: StructureND<Double>): DoubleBufferND =
|
||||||
|
zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l * r }
|
||||||
|
|
||||||
|
override fun StructureND<Double>.unaryMinus(): DoubleBufferND = mapInline(toBufferND()) { -it }
|
||||||
|
|
||||||
|
override fun StructureND<Double>.div(other: StructureND<Double>): DoubleBufferND =
|
||||||
|
zipInline(toBufferND(), other.toBufferND()) { l, r -> l / r }
|
||||||
|
|
||||||
|
override fun StructureND<Double>.plus(arg: Double): DoubleBufferND = mapInline(toBufferND()) { it + arg }
|
||||||
|
|
||||||
|
override fun StructureND<Double>.minus(arg: Double): StructureND<Double> = mapInline(toBufferND()) { it - arg }
|
||||||
|
|
||||||
|
override fun Double.plus(arg: StructureND<Double>): StructureND<Double> = arg + this
|
||||||
|
|
||||||
|
override fun Double.minus(arg: StructureND<Double>): StructureND<Double> = mapInline(arg.toBufferND()) { this - it }
|
||||||
|
|
||||||
|
override fun scale(a: StructureND<Double>, value: Double): DoubleBufferND =
|
||||||
mapInline(a.toBufferND()) { it * value }
|
mapInline(a.toBufferND()) { it * value }
|
||||||
|
|
||||||
override fun power(arg: StructureND<Double>, pow: Number): BufferND<Double> =
|
override fun power(arg: StructureND<Double>, pow: Number): DoubleBufferND =
|
||||||
mapInline(arg.toBufferND()) { power(it, pow) }
|
mapInline(arg.toBufferND()) { it.pow(pow.toDouble()) }
|
||||||
|
|
||||||
override fun exp(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { exp(it) }
|
override fun exp(arg: StructureND<Double>): DoubleBufferND =
|
||||||
override fun ln(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { ln(it) }
|
mapInline(arg.toBufferND()) { kotlin.math.exp(it) }
|
||||||
|
|
||||||
override fun sin(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { sin(it) }
|
override fun ln(arg: StructureND<Double>): DoubleBufferND =
|
||||||
override fun cos(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { cos(it) }
|
mapInline(arg.toBufferND()) { kotlin.math.ln(it) }
|
||||||
override fun tan(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { tan(it) }
|
|
||||||
override fun asin(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { asin(it) }
|
|
||||||
override fun acos(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { acos(it) }
|
|
||||||
override fun atan(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { atan(it) }
|
|
||||||
|
|
||||||
override fun sinh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { sinh(it) }
|
override fun sin(arg: StructureND<Double>): DoubleBufferND =
|
||||||
override fun cosh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { cosh(it) }
|
mapInline(arg.toBufferND()) { kotlin.math.sin(it) }
|
||||||
override fun tanh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { tanh(it) }
|
|
||||||
override fun asinh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { asinh(it) }
|
override fun cos(arg: StructureND<Double>): DoubleBufferND =
|
||||||
override fun acosh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { acosh(it) }
|
mapInline(arg.toBufferND()) { kotlin.math.cos(it) }
|
||||||
override fun atanh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { atanh(it) }
|
|
||||||
|
override fun tan(arg: StructureND<Double>): DoubleBufferND =
|
||||||
|
mapInline(arg.toBufferND()) { kotlin.math.tan(it) }
|
||||||
|
|
||||||
|
override fun asin(arg: StructureND<Double>): DoubleBufferND =
|
||||||
|
mapInline(arg.toBufferND()) { kotlin.math.asin(it) }
|
||||||
|
|
||||||
|
override fun acos(arg: StructureND<Double>): DoubleBufferND =
|
||||||
|
mapInline(arg.toBufferND()) { kotlin.math.acos(it) }
|
||||||
|
|
||||||
|
override fun atan(arg: StructureND<Double>): DoubleBufferND =
|
||||||
|
mapInline(arg.toBufferND()) { kotlin.math.atan(it) }
|
||||||
|
|
||||||
|
override fun sinh(arg: StructureND<Double>): DoubleBufferND =
|
||||||
|
mapInline(arg.toBufferND()) { kotlin.math.sinh(it) }
|
||||||
|
|
||||||
|
override fun cosh(arg: StructureND<Double>): DoubleBufferND =
|
||||||
|
mapInline(arg.toBufferND()) { kotlin.math.cosh(it) }
|
||||||
|
|
||||||
|
override fun tanh(arg: StructureND<Double>): DoubleBufferND =
|
||||||
|
mapInline(arg.toBufferND()) { kotlin.math.tanh(it) }
|
||||||
|
|
||||||
|
override fun asinh(arg: StructureND<Double>): DoubleBufferND =
|
||||||
|
mapInline(arg.toBufferND()) { kotlin.math.asinh(it) }
|
||||||
|
|
||||||
|
override fun acosh(arg: StructureND<Double>): DoubleBufferND =
|
||||||
|
mapInline(arg.toBufferND()) { kotlin.math.acosh(it) }
|
||||||
|
|
||||||
|
override fun atanh(arg: StructureND<Double>): DoubleBufferND =
|
||||||
|
mapInline(arg.toBufferND()) { kotlin.math.atanh(it) }
|
||||||
|
|
||||||
public companion object : DoubleFieldOpsND()
|
public companion object : DoubleFieldOpsND()
|
||||||
}
|
}
|
||||||
@ -54,7 +144,7 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(D
|
|||||||
public class DoubleFieldND(override val shape: Shape) :
|
public class DoubleFieldND(override val shape: Shape) :
|
||||||
DoubleFieldOpsND(), FieldND<Double, DoubleField>, NumbersAddOps<StructureND<Double>> {
|
DoubleFieldOpsND(), FieldND<Double, DoubleField>, NumbersAddOps<StructureND<Double>> {
|
||||||
|
|
||||||
override fun number(value: Number): BufferND<Double> {
|
override fun number(value: Number): DoubleBufferND {
|
||||||
val d = value.toDouble() // minimize conversions
|
val d = value.toDouble() // minimize conversions
|
||||||
return produce(shape) { d }
|
return produce(shape) { d }
|
||||||
}
|
}
|
||||||
|
@ -121,11 +121,11 @@ public interface GroupOps<T> : Algebra<T> {
|
|||||||
/**
|
/**
|
||||||
* Addition of two elements.
|
* Addition of two elements.
|
||||||
*
|
*
|
||||||
* @param a the augend.
|
* @param left the augend.
|
||||||
* @param b the addend.
|
* @param right the addend.
|
||||||
* @return the sum.
|
* @return the sum.
|
||||||
*/
|
*/
|
||||||
public fun add(a: T, b: T): T
|
public fun add(left: T, right: T): T
|
||||||
|
|
||||||
// Operations to be performed in this context. Could be moved to extensions in case of KEEP-176.
|
// Operations to be performed in this context. Could be moved to extensions in case of KEEP-176.
|
||||||
|
|
||||||
@ -149,19 +149,19 @@ public interface GroupOps<T> : Algebra<T> {
|
|||||||
* Addition of two elements.
|
* Addition of two elements.
|
||||||
*
|
*
|
||||||
* @receiver the augend.
|
* @receiver the augend.
|
||||||
* @param b the addend.
|
* @param other the addend.
|
||||||
* @return the sum.
|
* @return the sum.
|
||||||
*/
|
*/
|
||||||
public operator fun T.plus(b: T): T = add(this, b)
|
public operator fun T.plus(other: T): T = add(this, other)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subtraction of two elements.
|
* Subtraction of two elements.
|
||||||
*
|
*
|
||||||
* @receiver the minuend.
|
* @receiver the minuend.
|
||||||
* @param b the subtrahend.
|
* @param other the subtrahend.
|
||||||
* @return the difference.
|
* @return the difference.
|
||||||
*/
|
*/
|
||||||
public operator fun T.minus(b: T): T = add(this, -b)
|
public operator fun T.minus(other: T): T = add(this, -other)
|
||||||
// Dynamic dispatch of operations
|
// Dynamic dispatch of operations
|
||||||
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
|
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
|
||||||
PLUS_OPERATION -> { arg -> +arg }
|
PLUS_OPERATION -> { arg -> +arg }
|
||||||
@ -210,18 +210,18 @@ public interface RingOps<T> : GroupOps<T> {
|
|||||||
/**
|
/**
|
||||||
* Multiplies two elements.
|
* Multiplies two elements.
|
||||||
*
|
*
|
||||||
* @param a the multiplier.
|
* @param left the multiplier.
|
||||||
* @param b the multiplicand.
|
* @param right the multiplicand.
|
||||||
*/
|
*/
|
||||||
public fun multiply(a: T, b: T): T
|
public fun multiply(left: T, right: T): T
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Multiplies this element by scalar.
|
* Multiplies this element by scalar.
|
||||||
*
|
*
|
||||||
* @receiver the multiplier.
|
* @receiver the multiplier.
|
||||||
* @param b the multiplicand.
|
* @param other the multiplicand.
|
||||||
*/
|
*/
|
||||||
public operator fun T.times(b: T): T = multiply(this, b)
|
public operator fun T.times(other: T): T = multiply(this, other)
|
||||||
|
|
||||||
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
|
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
|
||||||
TIMES_OPERATION -> ::multiply
|
TIMES_OPERATION -> ::multiply
|
||||||
@ -260,20 +260,20 @@ public interface FieldOps<T> : RingOps<T> {
|
|||||||
/**
|
/**
|
||||||
* Division of two elements.
|
* Division of two elements.
|
||||||
*
|
*
|
||||||
* @param a the dividend.
|
* @param left the dividend.
|
||||||
* @param b the divisor.
|
* @param right the divisor.
|
||||||
* @return the quotient.
|
* @return the quotient.
|
||||||
*/
|
*/
|
||||||
public fun divide(a: T, b: T): T
|
public fun divide(left: T, right: T): T
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Division of two elements.
|
* Division of two elements.
|
||||||
*
|
*
|
||||||
* @receiver the dividend.
|
* @receiver the dividend.
|
||||||
* @param b the divisor.
|
* @param other the divisor.
|
||||||
* @return the quotient.
|
* @return the quotient.
|
||||||
*/
|
*/
|
||||||
public operator fun T.div(b: T): T = divide(this, b)
|
public operator fun T.div(other: T): T = divide(this, other)
|
||||||
|
|
||||||
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
|
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
|
||||||
DIV_OPERATION -> ::divide
|
DIV_OPERATION -> ::divide
|
||||||
|
@ -34,10 +34,10 @@ public object BigIntField : Field<BigInt>, NumbersAddOps<BigInt>, ScaleOperation
|
|||||||
|
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
||||||
override fun BigInt.unaryMinus(): BigInt = -this
|
override fun BigInt.unaryMinus(): BigInt = -this
|
||||||
override fun add(a: BigInt, b: BigInt): BigInt = a.plus(b)
|
override fun add(left: BigInt, right: BigInt): BigInt = left.plus(right)
|
||||||
override fun scale(a: BigInt, value: Double): BigInt = a.times(number(value))
|
override fun scale(a: BigInt, value: Double): BigInt = a.times(number(value))
|
||||||
override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b)
|
override fun multiply(left: BigInt, right: BigInt): BigInt = left.times(right)
|
||||||
override fun divide(a: BigInt, b: BigInt): BigInt = a.div(b)
|
override fun divide(left: BigInt, right: BigInt): BigInt = left.div(right)
|
||||||
|
|
||||||
public operator fun String.unaryPlus(): BigInt = this.parseBigInteger() ?: error("Can't parse $this as big integer")
|
public operator fun String.unaryPlus(): BigInt = this.parseBigInteger() ?: error("Can't parse $this as big integer")
|
||||||
public operator fun String.unaryMinus(): BigInt =
|
public operator fun String.unaryMinus(): BigInt =
|
||||||
|
@ -134,8 +134,8 @@ public open class BufferRingOps<T, A: Ring<T>>(
|
|||||||
override val bufferFactory: BufferFactory<T>,
|
override val bufferFactory: BufferFactory<T>,
|
||||||
) : BufferAlgebra<T, A>, RingOps<Buffer<T>>{
|
) : BufferAlgebra<T, A>, RingOps<Buffer<T>>{
|
||||||
|
|
||||||
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l + r }
|
override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r }
|
||||||
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l * r }
|
override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r }
|
||||||
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
|
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
|
||||||
|
|
||||||
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> =
|
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> =
|
||||||
@ -153,9 +153,9 @@ public open class BufferFieldOps<T, A : Field<T>>(
|
|||||||
bufferFactory: BufferFactory<T>,
|
bufferFactory: BufferFactory<T>,
|
||||||
) : BufferRingOps<T, A>(elementAlgebra, bufferFactory), BufferAlgebra<T, A>, FieldOps<Buffer<T>>, ScaleOperations<Buffer<T>> {
|
) : BufferRingOps<T, A>(elementAlgebra, bufferFactory), BufferAlgebra<T, A>, FieldOps<Buffer<T>>, ScaleOperations<Buffer<T>> {
|
||||||
|
|
||||||
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l + r }
|
override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r }
|
||||||
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l * r }
|
override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r }
|
||||||
override fun divide(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l / r }
|
override fun divide(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l / r }
|
||||||
|
|
||||||
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = a.map { scale(it, value) }
|
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = a.map { scale(it, value) }
|
||||||
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
|
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
|
||||||
|
@ -15,36 +15,37 @@ import kotlin.math.*
|
|||||||
* [ExtendedFieldOps] over [DoubleBuffer].
|
* [ExtendedFieldOps] over [DoubleBuffer].
|
||||||
*/
|
*/
|
||||||
public abstract class DoubleBufferOps : ExtendedFieldOps<Buffer<Double>>, Norm<Buffer<Double>, Double> {
|
public abstract class DoubleBufferOps : ExtendedFieldOps<Buffer<Double>>, Norm<Buffer<Double>, Double> {
|
||||||
|
|
||||||
override fun Buffer<Double>.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) {
|
override fun Buffer<Double>.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) {
|
||||||
DoubleBuffer(size) { -array[it] }
|
DoubleBuffer(size) { -array[it] }
|
||||||
} else {
|
} else {
|
||||||
DoubleBuffer(size) { -get(it) }
|
DoubleBuffer(size) { -get(it) }
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
override fun add(left: Buffer<Double>, right: Buffer<Double>): DoubleBuffer {
|
||||||
require(b.size == a.size) {
|
require(right.size == left.size) {
|
||||||
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
|
"The size of the first buffer ${left.size} should be the same as for second one: ${right.size} "
|
||||||
}
|
}
|
||||||
|
|
||||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
return if (left is DoubleBuffer && right is DoubleBuffer) {
|
||||||
val aArray = a.array
|
val aArray = left.array
|
||||||
val bArray = b.array
|
val bArray = right.array
|
||||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] })
|
DoubleBuffer(DoubleArray(left.size) { aArray[it] + bArray[it] })
|
||||||
} else DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] })
|
} else DoubleBuffer(DoubleArray(left.size) { left[it] + right[it] })
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Buffer<Double>.plus(b: Buffer<Double>): DoubleBuffer = add(this, b)
|
override fun Buffer<Double>.plus(other: Buffer<Double>): DoubleBuffer = add(this, other)
|
||||||
|
|
||||||
override fun Buffer<Double>.minus(b: Buffer<Double>): DoubleBuffer {
|
override fun Buffer<Double>.minus(other: Buffer<Double>): DoubleBuffer {
|
||||||
require(b.size == this.size) {
|
require(other.size == this.size) {
|
||||||
"The size of the first buffer ${this.size} should be the same as for second one: ${b.size} "
|
"The size of the first buffer ${this.size} should be the same as for second one: ${other.size} "
|
||||||
}
|
}
|
||||||
|
|
||||||
return if (this is DoubleBuffer && b is DoubleBuffer) {
|
return if (this is DoubleBuffer && other is DoubleBuffer) {
|
||||||
val aArray = this.array
|
val aArray = this.array
|
||||||
val bArray = b.array
|
val bArray = other.array
|
||||||
DoubleBuffer(DoubleArray(this.size) { aArray[it] - bArray[it] })
|
DoubleBuffer(DoubleArray(this.size) { aArray[it] - bArray[it] })
|
||||||
} else DoubleBuffer(DoubleArray(this.size) { this[it] - b[it] })
|
} else DoubleBuffer(DoubleArray(this.size) { this[it] - other[it] })
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -66,29 +67,29 @@ public abstract class DoubleBufferOps : ExtendedFieldOps<Buffer<Double>>, Norm<B
|
|||||||
// } else RealBuffer(DoubleArray(a.size) { a[it] / kValue })
|
// } else RealBuffer(DoubleArray(a.size) { a[it] / kValue })
|
||||||
// }
|
// }
|
||||||
|
|
||||||
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
override fun multiply(left: Buffer<Double>, right: Buffer<Double>): DoubleBuffer {
|
||||||
require(b.size == a.size) {
|
require(right.size == left.size) {
|
||||||
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
|
"The size of the first buffer ${left.size} should be the same as for second one: ${right.size} "
|
||||||
}
|
}
|
||||||
|
|
||||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
return if (left is DoubleBuffer && right is DoubleBuffer) {
|
||||||
val aArray = a.array
|
val aArray = left.array
|
||||||
val bArray = b.array
|
val bArray = right.array
|
||||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] })
|
DoubleBuffer(DoubleArray(left.size) { aArray[it] * bArray[it] })
|
||||||
} else
|
} else
|
||||||
DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] })
|
DoubleBuffer(DoubleArray(left.size) { left[it] * right[it] })
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
override fun divide(left: Buffer<Double>, right: Buffer<Double>): DoubleBuffer {
|
||||||
require(b.size == a.size) {
|
require(right.size == left.size) {
|
||||||
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
|
"The size of the first buffer ${left.size} should be the same as for second one: ${right.size} "
|
||||||
}
|
}
|
||||||
|
|
||||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
return if (left is DoubleBuffer && right is DoubleBuffer) {
|
||||||
val aArray = a.array
|
val aArray = left.array
|
||||||
val bArray = b.array
|
val bArray = right.array
|
||||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] })
|
DoubleBuffer(DoubleArray(left.size) { aArray[it] / bArray[it] })
|
||||||
} else DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] })
|
} else DoubleBuffer(DoubleArray(left.size) { left[it] / right[it] })
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun sin(arg: Buffer<Double>): DoubleBuffer = if (arg is DoubleBuffer) {
|
override fun sin(arg: Buffer<Double>): DoubleBuffer = if (arg is DoubleBuffer) {
|
||||||
|
@ -73,10 +73,10 @@ public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOp
|
|||||||
else -> super<ExtendedField>.binaryOperationFunction(operation)
|
else -> super<ExtendedField>.binaryOperationFunction(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
override inline fun add(a: Double, b: Double): Double = a + b
|
override inline fun add(left: Double, right: Double): Double = left + right
|
||||||
|
|
||||||
override inline fun multiply(a: Double, b: Double): Double = a * b
|
override inline fun multiply(left: Double, right: Double): Double = left * right
|
||||||
override inline fun divide(a: Double, b: Double): Double = a / b
|
override inline fun divide(left: Double, right: Double): Double = left / right
|
||||||
|
|
||||||
override inline fun scale(a: Double, value: Double): Double = a * value
|
override inline fun scale(a: Double, value: Double): Double = a * value
|
||||||
|
|
||||||
@ -102,10 +102,10 @@ public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOp
|
|||||||
override inline fun norm(arg: Double): Double = abs(arg)
|
override inline fun norm(arg: Double): Double = abs(arg)
|
||||||
|
|
||||||
override inline fun Double.unaryMinus(): Double = -this
|
override inline fun Double.unaryMinus(): Double = -this
|
||||||
override inline fun Double.plus(b: Double): Double = this + b
|
override inline fun Double.plus(other: Double): Double = this + other
|
||||||
override inline fun Double.minus(b: Double): Double = this - b
|
override inline fun Double.minus(other: Double): Double = this - other
|
||||||
override inline fun Double.times(b: Double): Double = this * b
|
override inline fun Double.times(other: Double): Double = this * other
|
||||||
override inline fun Double.div(b: Double): Double = this / b
|
override inline fun Double.div(other: Double): Double = this / other
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Double.Companion.algebra: DoubleField get() = DoubleField
|
public val Double.Companion.algebra: DoubleField get() = DoubleField
|
||||||
@ -126,12 +126,12 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
|||||||
else -> super.binaryOperationFunction(operation)
|
else -> super.binaryOperationFunction(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
override inline fun add(a: Float, b: Float): Float = a + b
|
override inline fun add(left: Float, right: Float): Float = left + right
|
||||||
override fun scale(a: Float, value: Double): Float = a * value.toFloat()
|
override fun scale(a: Float, value: Double): Float = a * value.toFloat()
|
||||||
|
|
||||||
override inline fun multiply(a: Float, b: Float): Float = a * b
|
override inline fun multiply(left: Float, right: Float): Float = left * right
|
||||||
|
|
||||||
override inline fun divide(a: Float, b: Float): Float = a / b
|
override inline fun divide(left: Float, right: Float): Float = left / right
|
||||||
|
|
||||||
override inline fun sin(arg: Float): Float = kotlin.math.sin(arg)
|
override inline fun sin(arg: Float): Float = kotlin.math.sin(arg)
|
||||||
override inline fun cos(arg: Float): Float = kotlin.math.cos(arg)
|
override inline fun cos(arg: Float): Float = kotlin.math.cos(arg)
|
||||||
@ -155,10 +155,10 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
|||||||
override inline fun norm(arg: Float): Float = abs(arg)
|
override inline fun norm(arg: Float): Float = abs(arg)
|
||||||
|
|
||||||
override inline fun Float.unaryMinus(): Float = -this
|
override inline fun Float.unaryMinus(): Float = -this
|
||||||
override inline fun Float.plus(b: Float): Float = this + b
|
override inline fun Float.plus(other: Float): Float = this + other
|
||||||
override inline fun Float.minus(b: Float): Float = this - b
|
override inline fun Float.minus(other: Float): Float = this - other
|
||||||
override inline fun Float.times(b: Float): Float = this * b
|
override inline fun Float.times(other: Float): Float = this * other
|
||||||
override inline fun Float.div(b: Float): Float = this / b
|
override inline fun Float.div(other: Float): Float = this / other
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Float.Companion.algebra: FloatField get() = FloatField
|
public val Float.Companion.algebra: FloatField get() = FloatField
|
||||||
@ -175,14 +175,14 @@ public object IntRing : Ring<Int>, Norm<Int, Int>, NumericAlgebra<Int> {
|
|||||||
get() = 1
|
get() = 1
|
||||||
|
|
||||||
override fun number(value: Number): Int = value.toInt()
|
override fun number(value: Number): Int = value.toInt()
|
||||||
override inline fun add(a: Int, b: Int): Int = a + b
|
override inline fun add(left: Int, right: Int): Int = left + right
|
||||||
override inline fun multiply(a: Int, b: Int): Int = a * b
|
override inline fun multiply(left: Int, right: Int): Int = left * right
|
||||||
override inline fun norm(arg: Int): Int = abs(arg)
|
override inline fun norm(arg: Int): Int = abs(arg)
|
||||||
|
|
||||||
override inline fun Int.unaryMinus(): Int = -this
|
override inline fun Int.unaryMinus(): Int = -this
|
||||||
override inline fun Int.plus(b: Int): Int = this + b
|
override inline fun Int.plus(other: Int): Int = this + other
|
||||||
override inline fun Int.minus(b: Int): Int = this - b
|
override inline fun Int.minus(other: Int): Int = this - other
|
||||||
override inline fun Int.times(b: Int): Int = this * b
|
override inline fun Int.times(other: Int): Int = this * other
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Int.Companion.algebra: IntRing get() = IntRing
|
public val Int.Companion.algebra: IntRing get() = IntRing
|
||||||
@ -199,14 +199,14 @@ public object ShortRing : Ring<Short>, Norm<Short, Short>, NumericAlgebra<Short>
|
|||||||
get() = 1
|
get() = 1
|
||||||
|
|
||||||
override fun number(value: Number): Short = value.toShort()
|
override fun number(value: Number): Short = value.toShort()
|
||||||
override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
|
override inline fun add(left: Short, right: Short): Short = (left + right).toShort()
|
||||||
override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort()
|
override inline fun multiply(left: Short, right: Short): Short = (left * right).toShort()
|
||||||
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
|
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
|
||||||
|
|
||||||
override inline fun Short.unaryMinus(): Short = (-this).toShort()
|
override inline fun Short.unaryMinus(): Short = (-this).toShort()
|
||||||
override inline fun Short.plus(b: Short): Short = (this + b).toShort()
|
override inline fun Short.plus(other: Short): Short = (this + other).toShort()
|
||||||
override inline fun Short.minus(b: Short): Short = (this - b).toShort()
|
override inline fun Short.minus(other: Short): Short = (this - other).toShort()
|
||||||
override inline fun Short.times(b: Short): Short = (this * b).toShort()
|
override inline fun Short.times(other: Short): Short = (this * other).toShort()
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Short.Companion.algebra: ShortRing get() = ShortRing
|
public val Short.Companion.algebra: ShortRing get() = ShortRing
|
||||||
@ -223,14 +223,14 @@ public object ByteRing : Ring<Byte>, Norm<Byte, Byte>, NumericAlgebra<Byte> {
|
|||||||
get() = 1
|
get() = 1
|
||||||
|
|
||||||
override fun number(value: Number): Byte = value.toByte()
|
override fun number(value: Number): Byte = value.toByte()
|
||||||
override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
|
override inline fun add(left: Byte, right: Byte): Byte = (left + right).toByte()
|
||||||
override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
|
override inline fun multiply(left: Byte, right: Byte): Byte = (left * right).toByte()
|
||||||
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
|
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
|
||||||
|
|
||||||
override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
|
override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
|
||||||
override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte()
|
override inline fun Byte.plus(other: Byte): Byte = (this + other).toByte()
|
||||||
override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte()
|
override inline fun Byte.minus(other: Byte): Byte = (this - other).toByte()
|
||||||
override inline fun Byte.times(b: Byte): Byte = (this * b).toByte()
|
override inline fun Byte.times(other: Byte): Byte = (this * other).toByte()
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Byte.Companion.algebra: ByteRing get() = ByteRing
|
public val Byte.Companion.algebra: ByteRing get() = ByteRing
|
||||||
@ -247,14 +247,14 @@ public object LongRing : Ring<Long>, Norm<Long, Long>, NumericAlgebra<Long> {
|
|||||||
get() = 1L
|
get() = 1L
|
||||||
|
|
||||||
override fun number(value: Number): Long = value.toLong()
|
override fun number(value: Number): Long = value.toLong()
|
||||||
override inline fun add(a: Long, b: Long): Long = a + b
|
override inline fun add(left: Long, right: Long): Long = left + right
|
||||||
override inline fun multiply(a: Long, b: Long): Long = a * b
|
override inline fun multiply(left: Long, right: Long): Long = left * right
|
||||||
override fun norm(arg: Long): Long = abs(arg)
|
override fun norm(arg: Long): Long = abs(arg)
|
||||||
|
|
||||||
override inline fun Long.unaryMinus(): Long = (-this)
|
override inline fun Long.unaryMinus(): Long = (-this)
|
||||||
override inline fun Long.plus(b: Long): Long = (this + b)
|
override inline fun Long.plus(other: Long): Long = (this + other)
|
||||||
override inline fun Long.minus(b: Long): Long = (this - b)
|
override inline fun Long.minus(other: Long): Long = (this - other)
|
||||||
override inline fun Long.times(b: Long): Long = (this * b)
|
override inline fun Long.times(other: Long): Long = (this * other)
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Long.Companion.algebra: LongRing get() = LongRing
|
public val Long.Companion.algebra: LongRing get() = LongRing
|
||||||
|
@ -18,9 +18,9 @@ public object JBigIntegerField : Ring<BigInteger>, NumericAlgebra<BigInteger> {
|
|||||||
override val one: BigInteger get() = BigInteger.ONE
|
override val one: BigInteger get() = BigInteger.ONE
|
||||||
|
|
||||||
override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong())
|
override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong())
|
||||||
override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b)
|
override fun add(left: BigInteger, right: BigInteger): BigInteger = left.add(right)
|
||||||
override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b)
|
override operator fun BigInteger.minus(other: BigInteger): BigInteger = subtract(other)
|
||||||
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
|
override fun multiply(left: BigInteger, right: BigInteger): BigInteger = left.multiply(right)
|
||||||
|
|
||||||
override operator fun BigInteger.unaryMinus(): BigInteger = negate()
|
override operator fun BigInteger.unaryMinus(): BigInteger = negate()
|
||||||
}
|
}
|
||||||
@ -39,15 +39,15 @@ public abstract class JBigDecimalFieldBase internal constructor(
|
|||||||
override val one: BigDecimal
|
override val one: BigDecimal
|
||||||
get() = BigDecimal.ONE
|
get() = BigDecimal.ONE
|
||||||
|
|
||||||
override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b)
|
override fun add(left: BigDecimal, right: BigDecimal): BigDecimal = left.add(right)
|
||||||
override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b)
|
override operator fun BigDecimal.minus(other: BigDecimal): BigDecimal = subtract(other)
|
||||||
override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble())
|
override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble())
|
||||||
|
|
||||||
override fun scale(a: BigDecimal, value: Double): BigDecimal =
|
override fun scale(a: BigDecimal, value: Double): BigDecimal =
|
||||||
a.multiply(value.toBigDecimal(mathContext), mathContext)
|
a.multiply(value.toBigDecimal(mathContext), mathContext)
|
||||||
|
|
||||||
override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext)
|
override fun multiply(left: BigDecimal, right: BigDecimal): BigDecimal = left.multiply(right, mathContext)
|
||||||
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext)
|
override fun divide(left: BigDecimal, right: BigDecimal): BigDecimal = left.divide(right, mathContext)
|
||||||
override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext)
|
override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext)
|
||||||
override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext)
|
override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext)
|
||||||
override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext)
|
override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext)
|
||||||
|
@ -104,12 +104,12 @@ public class PolynomialSpace<T, C>(
|
|||||||
Polynomial(coefficients.map { -it })
|
Polynomial(coefficients.map { -it })
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> {
|
override fun add(left: Polynomial<T>, right: Polynomial<T>): Polynomial<T> {
|
||||||
val dim = max(a.coefficients.size, b.coefficients.size)
|
val dim = max(left.coefficients.size, right.coefficients.size)
|
||||||
|
|
||||||
return ring {
|
return ring {
|
||||||
Polynomial(List(dim) { index ->
|
Polynomial(List(dim) { index ->
|
||||||
a.coefficients.getOrElse(index) { zero } + b.coefficients.getOrElse(index) { zero }
|
left.coefficients.getOrElse(index) { zero } + right.coefficients.getOrElse(index) { zero }
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -47,7 +47,7 @@ public object Euclidean2DSpace : GeometrySpace<Vector2D>, ScaleOperations<Vector
|
|||||||
override fun Vector2D.unaryMinus(): Vector2D = Vector2D(-x, -y)
|
override fun Vector2D.unaryMinus(): Vector2D = Vector2D(-x, -y)
|
||||||
|
|
||||||
override fun Vector2D.distanceTo(other: Vector2D): Double = (this - other).norm()
|
override fun Vector2D.distanceTo(other: Vector2D): Double = (this - other).norm()
|
||||||
override fun add(a: Vector2D, b: Vector2D): Vector2D = Vector2D(a.x + b.x, a.y + b.y)
|
override fun add(left: Vector2D, right: Vector2D): Vector2D = Vector2D(left.x + right.x, left.y + right.y)
|
||||||
override fun scale(a: Vector2D, value: Double): Vector2D = Vector2D(a.x * value, a.y * value)
|
override fun scale(a: Vector2D, value: Double): Vector2D = Vector2D(a.x * value, a.y * value)
|
||||||
override fun Vector2D.dot(other: Vector2D): Double = x * other.x + y * other.y
|
override fun Vector2D.dot(other: Vector2D): Double = x * other.x + y * other.y
|
||||||
}
|
}
|
||||||
|
@ -47,8 +47,8 @@ public object Euclidean3DSpace : GeometrySpace<Vector3D>, ScaleOperations<Vector
|
|||||||
|
|
||||||
override fun Vector3D.distanceTo(other: Vector3D): Double = (this - other).norm()
|
override fun Vector3D.distanceTo(other: Vector3D): Double = (this - other).norm()
|
||||||
|
|
||||||
override fun add(a: Vector3D, b: Vector3D): Vector3D =
|
override fun add(left: Vector3D, right: Vector3D): Vector3D =
|
||||||
Vector3D(a.x + b.x, a.y + b.y, a.z + b.z)
|
Vector3D(left.x + right.x, left.y + right.y, left.z + right.z)
|
||||||
|
|
||||||
override fun scale(a: Vector3D, value: Double): Vector3D =
|
override fun scale(a: Vector3D, value: Double): Vector3D =
|
||||||
Vector3D(a.x * value, a.y * value, a.z * value)
|
Vector3D(a.x * value, a.y * value, a.z * value)
|
||||||
|
@ -67,10 +67,10 @@ public interface IndexedHistogramSpace<T : Comparable<T>, V : Any>
|
|||||||
|
|
||||||
public fun produce(builder: HistogramBuilder<T>.() -> Unit): IndexedHistogram<T, V>
|
public fun produce(builder: HistogramBuilder<T>.() -> Unit): IndexedHistogram<T, V>
|
||||||
|
|
||||||
override fun add(a: IndexedHistogram<T, V>, b: IndexedHistogram<T, V>): IndexedHistogram<T, V> {
|
override fun add(left: IndexedHistogram<T, V>, right: IndexedHistogram<T, V>): IndexedHistogram<T, V> {
|
||||||
require(a.context == this) { "Can't operate on a histogram produced by external space" }
|
require(left.context == this) { "Can't operate on a histogram produced by external space" }
|
||||||
require(b.context == this) { "Can't operate on a histogram produced by external space" }
|
require(right.context == this) { "Can't operate on a histogram produced by external space" }
|
||||||
return IndexedHistogram(this, histogramValueSpace { a.values + b.values })
|
return IndexedHistogram(this, histogramValueSpace { left.values + right.values })
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun scale(a: IndexedHistogram<T, V>, value: Double): IndexedHistogram<T, V> {
|
override fun scale(a: IndexedHistogram<T, V>, value: Double): IndexedHistogram<T, V> {
|
||||||
|
@ -88,20 +88,20 @@ public class TreeHistogramSpace(
|
|||||||
TreeHistogramBuilder(binFactory).apply(block).build()
|
TreeHistogramBuilder(binFactory).apply(block).build()
|
||||||
|
|
||||||
override fun add(
|
override fun add(
|
||||||
a: UnivariateHistogram,
|
left: UnivariateHistogram,
|
||||||
b: UnivariateHistogram,
|
right: UnivariateHistogram,
|
||||||
): UnivariateHistogram {
|
): UnivariateHistogram {
|
||||||
// require(a.context == this) { "Histogram $a does not belong to this context" }
|
// require(a.context == this) { "Histogram $a does not belong to this context" }
|
||||||
// require(b.context == this) { "Histogram $b does not belong to this context" }
|
// require(b.context == this) { "Histogram $b does not belong to this context" }
|
||||||
val bins = TreeMap<Double, UnivariateBin>().apply {
|
val bins = TreeMap<Double, UnivariateBin>().apply {
|
||||||
(a.bins.map { it.domain } union b.bins.map { it.domain }).forEach { def ->
|
(left.bins.map { it.domain } union right.bins.map { it.domain }).forEach { def ->
|
||||||
put(
|
put(
|
||||||
def.center,
|
def.center,
|
||||||
UnivariateBin(
|
UnivariateBin(
|
||||||
def,
|
def,
|
||||||
value = (a[def.center]?.value ?: 0.0) + (b[def.center]?.value ?: 0.0),
|
value = (left[def.center]?.value ?: 0.0) + (right[def.center]?.value ?: 0.0),
|
||||||
standardDeviation = (a[def.center]?.standardDeviation
|
standardDeviation = (left[def.center]?.standardDeviation
|
||||||
?: 0.0) + (b[def.center]?.standardDeviation ?: 0.0)
|
?: 0.0) + (right[def.center]?.standardDeviation ?: 0.0)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -28,10 +28,10 @@ public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, S
|
|||||||
else -> super<ExtendedField>.binaryOperationFunction(operation)
|
else -> super<ExtendedField>.binaryOperationFunction(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
override inline fun add(a: Double, b: Double): Double = a + b
|
override inline fun add(left: Double, right: Double): Double = left + right
|
||||||
|
|
||||||
override inline fun multiply(a: Double, b: Double): Double = a * b
|
override inline fun multiply(left: Double, right: Double): Double = left * right
|
||||||
override inline fun divide(a: Double, b: Double): Double = a / b
|
override inline fun divide(left: Double, right: Double): Double = left / right
|
||||||
|
|
||||||
override inline fun scale(a: Double, value: Double): Double = a * value
|
override inline fun scale(a: Double, value: Double): Double = a * value
|
||||||
|
|
||||||
@ -57,10 +57,10 @@ public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, S
|
|||||||
override inline fun norm(arg: Double): Double = FastMath.abs(arg)
|
override inline fun norm(arg: Double): Double = FastMath.abs(arg)
|
||||||
|
|
||||||
override inline fun Double.unaryMinus(): Double = -this
|
override inline fun Double.unaryMinus(): Double = -this
|
||||||
override inline fun Double.plus(b: Double): Double = this + b
|
override inline fun Double.plus(other: Double): Double = this + other
|
||||||
override inline fun Double.minus(b: Double): Double = this - b
|
override inline fun Double.minus(other: Double): Double = this - other
|
||||||
override inline fun Double.times(b: Double): Double = this * b
|
override inline fun Double.times(other: Double): Double = this * other
|
||||||
override inline fun Double.div(b: Double): Double = this / b
|
override inline fun Double.div(other: Double): Double = this / other
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -79,10 +79,10 @@ public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Doub
|
|||||||
else -> super<ExtendedField>.binaryOperationFunction(operation)
|
else -> super<ExtendedField>.binaryOperationFunction(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
override inline fun add(a: Double, b: Double): Double = a + b
|
override inline fun add(left: Double, right: Double): Double = left + right
|
||||||
|
|
||||||
override inline fun multiply(a: Double, b: Double): Double = a * b
|
override inline fun multiply(left: Double, right: Double): Double = left * right
|
||||||
override inline fun divide(a: Double, b: Double): Double = a / b
|
override inline fun divide(left: Double, right: Double): Double = left / right
|
||||||
|
|
||||||
override inline fun scale(a: Double, value: Double): Double = a * value
|
override inline fun scale(a: Double, value: Double): Double = a * value
|
||||||
|
|
||||||
@ -108,8 +108,8 @@ public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Doub
|
|||||||
override inline fun norm(arg: Double): Double = StrictFastMath.abs(arg)
|
override inline fun norm(arg: Double): Double = StrictFastMath.abs(arg)
|
||||||
|
|
||||||
override inline fun Double.unaryMinus(): Double = -this
|
override inline fun Double.unaryMinus(): Double = -this
|
||||||
override inline fun Double.plus(b: Double): Double = this + b
|
override inline fun Double.plus(other: Double): Double = this + other
|
||||||
override inline fun Double.minus(b: Double): Double = this - b
|
override inline fun Double.minus(other: Double): Double = this - other
|
||||||
override inline fun Double.times(b: Double): Double = this * b
|
override inline fun Double.times(other: Double): Double = this * other
|
||||||
override inline fun Double.div(b: Double): Double = this / b
|
override inline fun Double.div(other: Double): Double = this / other
|
||||||
}
|
}
|
||||||
|
@ -72,11 +72,11 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
|
|||||||
*/
|
*/
|
||||||
public sealed interface Nd4jArrayGroupOps<T, out S : Ring<T>> : GroupOpsND<T, S>, Nd4jArrayAlgebra<T, S> {
|
public sealed interface Nd4jArrayGroupOps<T, out S : Ring<T>> : GroupOpsND<T, S>, Nd4jArrayAlgebra<T, S> {
|
||||||
|
|
||||||
override fun add(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
|
override fun add(left: StructureND<T>, right: StructureND<T>): Nd4jArrayStructure<T> =
|
||||||
a.ndArray.add(b.ndArray).wrap()
|
left.ndArray.add(right.ndArray).wrap()
|
||||||
|
|
||||||
override operator fun StructureND<T>.minus(b: StructureND<T>): Nd4jArrayStructure<T> =
|
override operator fun StructureND<T>.minus(other: StructureND<T>): Nd4jArrayStructure<T> =
|
||||||
ndArray.sub(b.ndArray).wrap()
|
ndArray.sub(other.ndArray).wrap()
|
||||||
|
|
||||||
override operator fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> =
|
override operator fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> =
|
||||||
ndArray.neg().wrap()
|
ndArray.neg().wrap()
|
||||||
@ -94,8 +94,8 @@ public sealed interface Nd4jArrayGroupOps<T, out S : Ring<T>> : GroupOpsND<T, S>
|
|||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public sealed interface Nd4jArrayRingOps<T, out R : Ring<T>> : RingOpsND<T, R>, Nd4jArrayGroupOps<T, R> {
|
public sealed interface Nd4jArrayRingOps<T, out R : Ring<T>> : RingOpsND<T, R>, Nd4jArrayGroupOps<T, R> {
|
||||||
|
|
||||||
override fun multiply(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
|
override fun multiply(left: StructureND<T>, right: StructureND<T>): Nd4jArrayStructure<T> =
|
||||||
a.ndArray.mul(b.ndArray).wrap()
|
left.ndArray.mul(right.ndArray).wrap()
|
||||||
//
|
//
|
||||||
// override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> {
|
// override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> {
|
||||||
// check(this)
|
// check(this)
|
||||||
@ -132,8 +132,8 @@ public sealed interface Nd4jArrayRingOps<T, out R : Ring<T>> : RingOpsND<T, R>,
|
|||||||
*/
|
*/
|
||||||
public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldOpsND<T, F>, Nd4jArrayRingOps<T, F> {
|
public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldOpsND<T, F>, Nd4jArrayRingOps<T, F> {
|
||||||
|
|
||||||
override fun divide(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
|
override fun divide(left: StructureND<T>, right: StructureND<T>): Nd4jArrayStructure<T> =
|
||||||
a.ndArray.div(b.ndArray).wrap()
|
left.ndArray.div(right.ndArray).wrap()
|
||||||
|
|
||||||
public operator fun Number.div(b: StructureND<T>): Nd4jArrayStructure<T> = b.ndArray.rdiv(this).wrap()
|
public operator fun Number.div(b: StructureND<T>): Nd4jArrayStructure<T> = b.ndArray.rdiv(this).wrap()
|
||||||
|
|
||||||
|
@ -41,8 +41,8 @@ public class SamplerSpace<T : Any, out S>(public val algebra: S) : Group<Sampler
|
|||||||
|
|
||||||
override val zero: Sampler<T> = ConstantSampler(algebra.zero)
|
override val zero: Sampler<T> = ConstantSampler(algebra.zero)
|
||||||
|
|
||||||
override fun add(a: Sampler<T>, b: Sampler<T>): Sampler<T> = BasicSampler { generator ->
|
override fun add(left: Sampler<T>, right: Sampler<T>): Sampler<T> = BasicSampler { generator ->
|
||||||
a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> algebra { aValue + bValue } }
|
left.sample(generator).zip(right.sample(generator)) { aValue, bValue -> algebra { aValue + bValue } }
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun scale(a: Sampler<T>, value: Double): Sampler<T> = BasicSampler { generator ->
|
override fun scale(a: Sampler<T>, value: Double): Sampler<T> = BasicSampler { generator ->
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.tensors.api
|
package space.kscience.kmath.tensors.api
|
||||||
|
|
||||||
import space.kscience.kmath.operations.Algebra
|
import space.kscience.kmath.operations.RingOps
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Algebra over a ring on [Tensor].
|
* Algebra over a ring on [Tensor].
|
||||||
@ -13,7 +13,7 @@ import space.kscience.kmath.operations.Algebra
|
|||||||
*
|
*
|
||||||
* @param T the type of items in the tensors.
|
* @param T the type of items in the tensors.
|
||||||
*/
|
*/
|
||||||
public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
|
||||||
/**
|
/**
|
||||||
* Returns a single tensor value of unit dimension if tensor shape equals to [1].
|
* Returns a single tensor value of unit dimension if tensor shape equals to [1].
|
||||||
*
|
*
|
||||||
@ -53,7 +53,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
|||||||
* @param other tensor to be added.
|
* @param other tensor to be added.
|
||||||
* @return the sum of this tensor and [other].
|
* @return the sum of this tensor and [other].
|
||||||
*/
|
*/
|
||||||
public operator fun Tensor<T>.plus(other: Tensor<T>): Tensor<T>
|
override fun Tensor<T>.plus(other: Tensor<T>): Tensor<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Adds the scalar [value] to each element of this tensor.
|
* Adds the scalar [value] to each element of this tensor.
|
||||||
@ -93,7 +93,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
|||||||
* @param other tensor to be subtracted.
|
* @param other tensor to be subtracted.
|
||||||
* @return the difference between this tensor and [other].
|
* @return the difference between this tensor and [other].
|
||||||
*/
|
*/
|
||||||
public operator fun Tensor<T>.minus(other: Tensor<T>): Tensor<T>
|
override fun Tensor<T>.minus(other: Tensor<T>): Tensor<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subtracts the scalar [value] from each element of this tensor.
|
* Subtracts the scalar [value] from each element of this tensor.
|
||||||
@ -134,7 +134,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
|||||||
* @param other tensor to be multiplied.
|
* @param other tensor to be multiplied.
|
||||||
* @return the product of this tensor and [other].
|
* @return the product of this tensor and [other].
|
||||||
*/
|
*/
|
||||||
public operator fun Tensor<T>.times(other: Tensor<T>): Tensor<T>
|
override fun Tensor<T>.times(other: Tensor<T>): Tensor<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Multiplies the scalar [value] by each element of this tensor.
|
* Multiplies the scalar [value] by each element of this tensor.
|
||||||
@ -155,7 +155,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
|||||||
*
|
*
|
||||||
* @return tensor negation of the original tensor.
|
* @return tensor negation of the original tensor.
|
||||||
*/
|
*/
|
||||||
public operator fun Tensor<T>.unaryMinus(): Tensor<T>
|
override fun Tensor<T>.unaryMinus(): Tensor<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the tensor at index i
|
* Returns the tensor at index i
|
||||||
@ -323,4 +323,8 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
|||||||
* @return the index of maximum value of each row of the input tensor in the given dimension [dim].
|
* @return the index of maximum value of each row of the input tensor in the given dimension [dim].
|
||||||
*/
|
*/
|
||||||
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
|
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
|
||||||
|
|
||||||
|
override fun add(left: Tensor<T>, right: Tensor<T>): Tensor<T> = left + right
|
||||||
|
|
||||||
|
override fun multiply(left: Tensor<T>, right: Tensor<T>): Tensor<T> = left * right
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> =
|
|||||||
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
|
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
|
||||||
is BufferedTensor<T> -> this
|
is BufferedTensor<T> -> this
|
||||||
is MutableBufferND<T> -> if (this.indexes == TensorLinearStructure(this.shape)) {
|
is MutableBufferND<T> -> if (this.indexes == TensorLinearStructure(this.shape)) {
|
||||||
BufferedTensor(this.shape, this.mutableBuffer, 0)
|
BufferedTensor(this.shape, this.buffer, 0)
|
||||||
} else {
|
} else {
|
||||||
this.copyToBufferedTensor()
|
this.copyToBufferedTensor()
|
||||||
}
|
}
|
||||||
|
@ -88,17 +88,17 @@ public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, Doubl
|
|||||||
}
|
}
|
||||||
}.asStructure()
|
}.asStructure()
|
||||||
|
|
||||||
override fun add(a: StructureND<Double>, b: StructureND<Double>): ViktorStructureND =
|
override fun add(left: StructureND<Double>, right: StructureND<Double>): ViktorStructureND =
|
||||||
(a.f64Buffer + b.f64Buffer).asStructure()
|
(left.f64Buffer + right.f64Buffer).asStructure()
|
||||||
|
|
||||||
override fun scale(a: StructureND<Double>, value: Double): ViktorStructureND =
|
override fun scale(a: StructureND<Double>, value: Double): ViktorStructureND =
|
||||||
(a.f64Buffer * value).asStructure()
|
(a.f64Buffer * value).asStructure()
|
||||||
|
|
||||||
override inline fun StructureND<Double>.plus(b: StructureND<Double>): ViktorStructureND =
|
override inline fun StructureND<Double>.plus(other: StructureND<Double>): ViktorStructureND =
|
||||||
(f64Buffer + b.f64Buffer).asStructure()
|
(f64Buffer + other.f64Buffer).asStructure()
|
||||||
|
|
||||||
override inline fun StructureND<Double>.minus(b: StructureND<Double>): ViktorStructureND =
|
override inline fun StructureND<Double>.minus(other: StructureND<Double>): ViktorStructureND =
|
||||||
(f64Buffer - b.f64Buffer).asStructure()
|
(f64Buffer - other.f64Buffer).asStructure()
|
||||||
|
|
||||||
override inline fun StructureND<Double>.times(k: Number): ViktorStructureND =
|
override inline fun StructureND<Double>.times(k: Number): ViktorStructureND =
|
||||||
(f64Buffer * k.toDouble()).asStructure()
|
(f64Buffer * k.toDouble()).asStructure()
|
||||||
|
Loading…
Reference in New Issue
Block a user