Specify type explicitly EVERYWHERE in kmath-core, add newlines at ends of files, refactor minor problems, improve documentation

This commit is contained in:
Iaroslav 2020-08-05 03:58:00 +07:00
parent 9fded79af0
commit ae7aefeb6a
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
59 changed files with 722 additions and 429 deletions

View File

@ -3,13 +3,18 @@ package scientifik.kmath.domains
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
/** /**
* A simple geometric domain * A simple geometric domain.
*
* @param T the type of element of this domain.
*/ */
interface Domain<T : Any> { interface Domain<T : Any> {
/**
* Checks if the specified point is contained in this domain.
*/
operator fun contains(point: Point<T>): Boolean operator fun contains(point: Point<T>): Boolean
/** /**
* Number of hyperspace dimensions * Number of hyperspace dimensions.
*/ */
val dimension: Int val dimension: Int
} }

View File

@ -42,13 +42,14 @@ class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBu
override fun getUpperBound(num: Int): Double? = upper[num] override fun getUpperBound(num: Int): Double? = upper[num]
override fun nearestInDomain(point: Point<Double>): Point<Double> { override fun nearestInDomain(point: Point<Double>): Point<Double> {
val res: DoubleArray = DoubleArray(point.size) { i -> val res = DoubleArray(point.size) { i ->
when { when {
point[i] < lower[i] -> lower[i] point[i] < lower[i] -> lower[i]
point[i] > upper[i] -> upper[i] point[i] > upper[i] -> upper[i]
else -> point[i] else -> point[i]
} }
} }
return RealBuffer(*res) return RealBuffer(*res)
} }
@ -64,4 +65,4 @@ class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBu
} }
return res return res
} }
} }

View File

@ -22,8 +22,7 @@ import scientifik.kmath.linear.Point
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
interface RealDomain: Domain<Double> { interface RealDomain : Domain<Double> {
fun nearestInDomain(point: Point<Double>): Point<Double> fun nearestInDomain(point: Point<Double>): Point<Double>
/** /**
@ -61,5 +60,4 @@ interface RealDomain: Domain<Double> {
* @return * @return
*/ */
fun volume(): Double fun volume(): Double
}
}

View File

@ -18,7 +18,6 @@ package scientifik.kmath.domains
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
class UnconstrainedDomain(override val dimension: Int) : RealDomain { class UnconstrainedDomain(override val dimension: Int) : RealDomain {
override operator fun contains(point: Point<Double>): Boolean = true override operator fun contains(point: Point<Double>): Boolean = true
override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY
@ -32,5 +31,4 @@ class UnconstrainedDomain(override val dimension: Int) : RealDomain {
override fun nearestInDomain(point: Point<Double>): Point<Double> = point override fun nearestInDomain(point: Point<Double>): Point<Double> = point
override fun volume(): Double = Double.POSITIVE_INFINITY override fun volume(): Double = Double.POSITIVE_INFINITY
}
}

View File

@ -4,7 +4,6 @@ import scientifik.kmath.linear.Point
import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asBuffer
inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : RealDomain { inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : RealDomain {
operator fun contains(d: Double): Boolean = range.contains(d) operator fun contains(d: Double): Boolean = range.contains(d)
override operator fun contains(point: Point<Double>): Boolean { override operator fun contains(point: Point<Double>): Boolean {
@ -15,10 +14,10 @@ inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : Rea
override fun nearestInDomain(point: Point<Double>): Point<Double> { override fun nearestInDomain(point: Point<Double>): Point<Double> {
require(point.size == 1) require(point.size == 1)
val value = point[0] val value = point[0]
return when{ return when {
value in range -> point value in range -> point
value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer() value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer()
else -> doubleArrayOf(range.start).asBuffer() else -> doubleArrayOf(range.start).asBuffer()
} }
} }
@ -45,4 +44,4 @@ inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : Rea
override fun volume(): Double = range.endInclusive - range.start override fun volume(): Double = range.endInclusive - range.start
override val dimension: Int get() = 1 override val dimension: Int get() = 1
} }

View File

@ -14,9 +14,10 @@ interface Expression<T> {
/** /**
* Create simple lazily evaluated expression inside given algebra * Create simple lazily evaluated expression inside given algebra
*/ */
fun <T> Algebra<T>.expression(block: Algebra<T>.(arguments: Map<String, T>) -> T): Expression<T> = object: Expression<T> { fun <T> Algebra<T>.expression(block: Algebra<T>.(arguments: Map<String, T>) -> T): Expression<T> =
override fun invoke(arguments: Map<String, T>): T = block(arguments) object : Expression<T> {
} override fun invoke(arguments: Map<String, T>): T = block(arguments)
}
operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs)) operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
@ -33,4 +34,4 @@ interface ExpressionAlgebra<T, E> : Algebra<E> {
* A constant expression which does not depend on arguments * A constant expression which does not depend on arguments
*/ */
fun const(value: T): E fun const(value: T): E
} }

View File

@ -143,4 +143,4 @@ inline fun <T, A : Ring<T>> A.expressionInRing(block: FunctionalExpressionRing<T
FunctionalExpressionRing(this).block() FunctionalExpressionRing(this).block()
inline fun <T, A : Field<T>> A.expressionInField(block: FunctionalExpressionField<T, A>.() -> Expression<T>): Expression<T> = inline fun <T, A : Field<T>> A.expressionInField(block: FunctionalExpressionField<T, A>.() -> Expression<T>): Expression<T> =
FunctionalExpressionField(this).block() FunctionalExpressionField(this).block()

View File

@ -19,22 +19,20 @@ class BufferMatrixContext<T : Any, R : Ring<T>>(
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer) override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
companion object { companion object
}
} }
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
object RealMatrixContext : GenericMatrixContext<Double, RealField> { object RealMatrixContext : GenericMatrixContext<Double, RealField> {
override val elementContext get() = RealField override val elementContext: RealField get() = RealField
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> { override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
return BufferMatrix(rows, columns, buffer) return BufferMatrix(rows, columns, buffer)
} }
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = RealBuffer(size,initializer) override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = RealBuffer(size, initializer)
} }
class BufferMatrix<T : Any>( class BufferMatrix<T : Any>(
@ -52,7 +50,7 @@ class BufferMatrix<T : Any>(
override val shape: IntArray get() = intArrayOf(rowNum, colNum) override val shape: IntArray get() = intArrayOf(rowNum, colNum)
override fun suggestFeature(vararg features: MatrixFeature) = override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
BufferMatrix(rowNum, colNum, buffer, this.features + features) BufferMatrix(rowNum, colNum, buffer, this.features + features)
override fun get(index: IntArray): T = get(index[0], index[1]) override fun get(index: IntArray): T = get(index[0], index[1])
@ -84,8 +82,8 @@ class BufferMatrix<T : Any>(
override fun toString(): String { override fun toString(): String {
return if (rowNum <= 5 && colNum <= 5) { return if (rowNum <= 5 && colNum <= 5) {
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" + "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" +
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer ->
it.asSequence().joinToString(separator = "\t") { it.toString() } buffer.asSequence().joinToString(separator = "\t") { it.toString() }
} }
} else { } else {
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)" "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
@ -121,4 +119,4 @@ infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Do
val buffer = RealBuffer(array) val buffer = RealBuffer(array)
return BufferMatrix(rowNum, other.colNum, buffer) return BufferMatrix(rowNum, other.colNum, buffer)
} }

View File

@ -23,12 +23,10 @@ interface FeaturedMatrix<T : Any> : Matrix<T> {
*/ */
fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T>
companion object { companion object
}
} }
fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> =
MatrixContext.real.produce(rows, columns, initializer) MatrixContext.real.produce(rows, columns, initializer)
/** /**
@ -41,7 +39,7 @@ fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T
return BufferMatrix(size, size, buffer) return BufferMatrix(size, size, buffer)
} }
val Matrix<*>.features get() = (this as? FeaturedMatrix)?.features?: emptySet() val Matrix<*>.features: Set<MatrixFeature> get() = (this as? FeaturedMatrix)?.features ?: emptySet()
/** /**
* Check if matrix has the given feature class * Check if matrix has the given feature class
@ -68,7 +66,7 @@ fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: In
* A virtual matrix of zeroes * A virtual matrix of zeroes
*/ */
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): FeaturedMatrix<T> = fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): FeaturedMatrix<T> =
VirtualMatrix<T>(rows, columns) { _, _ -> elementContext.zero } VirtualMatrix(rows, columns) { _, _ -> elementContext.zero }
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature
@ -83,4 +81,4 @@ fun <T : Any> Matrix<T>.transpose(): Matrix<T> {
) { i, j -> get(j, i) } ) { i, j -> get(j, i) }
} }
infix fun Matrix<Double>.dot(other: Matrix<Double>): Matrix<Double> = with(MatrixContext.real) { dot(other) } infix fun Matrix<Double>.dot(other: Matrix<Double>): Matrix<Double> = with(MatrixContext.real) { dot(other) }

View File

@ -18,7 +18,7 @@ class LUPDecomposition<T : Any>(
private val even: Boolean private val even: Boolean
) : LUPDecompositionFeature<T>, DeterminantFeature<T> { ) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
val elementContext get() = context.elementContext val elementContext: Field<T> get() = context.elementContext
/** /**
* Returns the matrix L of the decomposition. * Returns the matrix L of the decomposition.
@ -67,7 +67,7 @@ class LUPDecomposition<T : Any>(
} }
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T) = fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T): T =
if (value > elementContext.zero) value else with(elementContext) { -value } if (value > elementContext.zero) value else with(elementContext) { -value }
@ -169,9 +169,10 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup( inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
matrix: Matrix<T>, matrix: Matrix<T>,
noinline checkSingular: (T) -> Boolean noinline checkSingular: (T) -> Boolean
) = lup(T::class, matrix, checkSingular) ): LUPDecomposition<T> = lup(T::class, matrix, checkSingular)
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>) = lup(Double::class, matrix) { it < 1e-11 } fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>): LUPDecomposition<Double> =
lup(Double::class, matrix) { it < 1e-11 }
fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> { fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> {
@ -185,7 +186,7 @@ fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Mat
// Apply permutations to b // Apply permutations to b
val bp = create { _, _ -> zero } val bp = create { _, _ -> zero }
for (row in 0 until pivot.size) { for (row in pivot.indices) {
val bpRow = bp.row(row) val bpRow = bp.row(row)
val pRow = pivot[row] val pRow = pivot[row]
for (col in 0 until matrix.colNum) { for (col in 0 until matrix.colNum) {
@ -194,7 +195,7 @@ fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Mat
} }
// Solve LY = b // Solve LY = b
for (col in 0 until pivot.size) { for (col in pivot.indices) {
val bpCol = bp.row(col) val bpCol = bp.row(col)
for (i in col + 1 until pivot.size) { for (i in col + 1 until pivot.size) {
val bpI = bp.row(i) val bpI = bp.row(i)
@ -225,7 +226,7 @@ fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Mat
} }
} }
inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>) = solve(T::class, matrix) inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>): Matrix<T> = solve(T::class, matrix)
/** /**
* Solve a linear equation **a*x = b** * Solve a linear equation **a*x = b**
@ -240,13 +241,12 @@ inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.
return decomposition.solve(T::class, b) return decomposition.solve(T::class, b)
} }
fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>) = fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> = solve(a, b) { it < 1e-11 }
solve(a, b) { it < 1e-11 }
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse( inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse(
matrix: Matrix<T>, matrix: Matrix<T>,
noinline checkSingular: (T) -> Boolean noinline checkSingular: (T) -> Boolean
) = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular) ): Matrix<T> = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
fun RealMatrixContext.inverse(matrix: Matrix<Double>) = fun RealMatrixContext.inverse(matrix: Matrix<Double>): Matrix<Double> =
solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 } solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 }

View File

@ -25,4 +25,4 @@ fun <T : Any> Matrix<T>.asPoint(): Point<T> =
error("Can't convert matrix with more than one column to vector") error("Can't convert matrix with more than one column to vector")
} }
fun <T : Any> Point<T>.asMatrix() = VirtualMatrix(size, 1) { i, _ -> get(i) } fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) }

View File

@ -29,7 +29,7 @@ interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
/** /**
* Non-boxing double matrix * Non-boxing double matrix
*/ */
val real = RealMatrixContext val real: RealMatrixContext = RealMatrixContext
/** /**
* A structured matrix with custom buffer * A structured matrix with custom buffer
@ -82,12 +82,12 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
} }
} }
override operator fun Matrix<T>.unaryMinus() = override operator fun Matrix<T>.unaryMinus(): Matrix<T> =
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } } produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } }
override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> { override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
if (a.rowNum != b.rowNum || a.colNum != b.colNum) error("Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]") if (a.rowNum != b.rowNum || a.colNum != b.colNum) error("Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]")
return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a.get(i, j) + b[i, j] } } return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] + b[i, j] } }
} }
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> { override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
@ -96,7 +96,7 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
} }
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> = override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a.get(i, j) * k } } produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] * k } }
operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this

View File

@ -1,7 +1,7 @@
package scientifik.kmath.linear package scientifik.kmath.linear
/** /**
* A marker interface representing some matrix feature like diagonal, sparce, zero, etc. Features used to optimize matrix * A marker interface representing some matrix feature like diagonal, sparse, zero, etc. Features used to optimize matrix
* operations performance in some cases. * operations performance in some cases.
*/ */
interface MatrixFeature interface MatrixFeature
@ -36,19 +36,19 @@ interface DeterminantFeature<T : Any> : MatrixFeature {
} }
@Suppress("FunctionName") @Suppress("FunctionName")
fun <T: Any> DeterminantFeature(determinant: T) = object: DeterminantFeature<T>{ fun <T : Any> DeterminantFeature(determinant: T): DeterminantFeature<T> = object : DeterminantFeature<T> {
override val determinant: T = determinant override val determinant: T = determinant
} }
/** /**
* Lower triangular matrix * Lower triangular matrix
*/ */
object LFeature: MatrixFeature object LFeature : MatrixFeature
/** /**
* Upper triangular feature * Upper triangular feature
*/ */
object UFeature: MatrixFeature object UFeature : MatrixFeature
/** /**
* TODO add documentation * TODO add documentation
@ -59,4 +59,4 @@ interface LUPDecompositionFeature<T : Any> : MatrixFeature {
val p: FeaturedMatrix<T> val p: FeaturedMatrix<T>
} }
//TODO add sparse matrix feature //TODO add sparse matrix feature

View File

@ -54,7 +54,7 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
size: Int, size: Int,
space: S, space: S,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
) = BufferVectorSpace(size, space, bufferFactory) ): BufferVectorSpace<T, S> = BufferVectorSpace(size, space, bufferFactory)
/** /**
* Automatic buffered vector, unboxed if it is possible * Automatic buffered vector, unboxed if it is possible
@ -70,6 +70,6 @@ class BufferVectorSpace<T : Any, S : Space<T>>(
override val space: S, override val space: S,
val bufferFactory: BufferFactory<T> val bufferFactory: BufferFactory<T>
) : VectorSpace<T, S> { ) : VectorSpace<T, S> {
override fun produce(initializer: (Int) -> T) = bufferFactory(size, initializer) override fun produce(initializer: (Int) -> T): Buffer<T> = bufferFactory(size, initializer)
//override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer)) //override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer))
} }

View File

@ -20,7 +20,7 @@ class VirtualMatrix<T : Any>(
override fun get(i: Int, j: Int): T = generator(i, j) override fun get(i: Int, j: Int): T = generator(i, j)
override fun suggestFeature(vararg features: MatrixFeature) = override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix<T> =
VirtualMatrix(rowNum, colNum, this.features + features, generator) VirtualMatrix(rowNum, colNum, this.features + features, generator)
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
@ -56,4 +56,4 @@ class VirtualMatrix<T : Any>(
} }
} }
} }
} }

View File

@ -22,12 +22,12 @@ class DerivationResult<T : Any>(
val deriv: Map<Variable<T>, T>, val deriv: Map<Variable<T>, T>,
val context: Field<T> val context: Field<T>
) : Variable<T>(value) { ) : Variable<T>(value) {
fun deriv(variable: Variable<T>) = deriv[variable] ?: context.zero fun deriv(variable: Variable<T>): T = deriv[variable] ?: context.zero
/** /**
* compute divergence * compute divergence
*/ */
fun div() = context.run { sum(deriv.values) } fun div(): T = context.run { sum(deriv.values) }
/** /**
* Compute a gradient for variables in given order * Compute a gradient for variables in given order
@ -53,7 +53,7 @@ class DerivationResult<T : Any>(
* ``` * ```
*/ */
fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> = fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> =
AutoDiffContext<T, F>(this).run { AutoDiffContext(this).run {
val result = body() val result = body()
result.d = context.one// computing derivative w.r.t result result.d = context.one// computing derivative w.r.t result
runBackwardPass() runBackwardPass()
@ -86,7 +86,7 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
abstract fun variable(value: T): Variable<T> abstract fun variable(value: T): Variable<T>
inline fun variable(block: F.() -> T) = variable(context.block()) inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
// Overloads for Double constants // Overloads for Double constants
@ -236,4 +236,4 @@ fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Var
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> = fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
derive(variable { cos(x.value) }) { z -> derive(variable { cos(x.value) }) { z ->
x.d -= z.d * sin(x.value) x.d -= z.d * sin(x.value)
} }

View File

@ -43,4 +43,4 @@ fun ClosedFloatingPointRange<Double>.toSequenceWithPoints(numPoints: Int): Seque
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray { fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
if (numPoints < 2) error("Can't create generic grid with less than two points") if (numPoints < 2) error("Can't create generic grid with less than two points")
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
} }

View File

@ -1,14 +1,15 @@
package scientifik.kmath.misc package scientifik.kmath.misc
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.operations.invoke
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
/** /**
* Generic cumulative operation on iterator * Generic cumulative operation on iterator.
* @param T type of initial iterable *
* @param R type of resulting iterable * @param T the type of initial iterable.
* @param initial lazy evaluated * @param R the type of resulting iterable.
* @param initial lazy evaluated.
*/ */
fun <T, R> Iterator<T>.cumulative(initial: R, operation: (R, T) -> R): Iterator<R> = object : Iterator<R> { fun <T, R> Iterator<T>.cumulative(initial: R, operation: (R, T) -> R): Iterator<R> = object : Iterator<R> {
var state: R = initial var state: R = initial
@ -36,41 +37,41 @@ fun <T, R> List<T>.cumulative(initial: R, operation: (R, T) -> R): List<R> =
/** /**
* Cumulative sum with custom space * Cumulative sum with custom space
*/ */
fun <T> Iterable<T>.cumulativeSum(space: Space<T>) = with(space) { fun <T> Iterable<T>.cumulativeSum(space: Space<T>): Iterable<T> = space {
cumulative(zero) { element: T, sum: T -> sum + element } cumulative(zero) { element: T, sum: T -> sum + element }
} }
@JvmName("cumulativeSumOfDouble") @JvmName("cumulativeSumOfDouble")
fun Iterable<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element } fun Iterable<Double>.cumulativeSum(): Iterable<Double> = this.cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt") @JvmName("cumulativeSumOfInt")
fun Iterable<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element } fun Iterable<Int>.cumulativeSum(): Iterable<Int> = this.cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong") @JvmName("cumulativeSumOfLong")
fun Iterable<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element } fun Iterable<Long>.cumulativeSum(): Iterable<Long> = this.cumulative(0L) { element, sum -> sum + element }
fun <T> Sequence<T>.cumulativeSum(space: Space<T>) = with(space) { fun <T> Sequence<T>.cumulativeSum(space: Space<T>): Sequence<T> = with(space) {
cumulative(zero) { element: T, sum: T -> sum + element } cumulative(zero) { element: T, sum: T -> sum + element }
} }
@JvmName("cumulativeSumOfDouble") @JvmName("cumulativeSumOfDouble")
fun Sequence<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element } fun Sequence<Double>.cumulativeSum(): Sequence<Double> = this.cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt") @JvmName("cumulativeSumOfInt")
fun Sequence<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element } fun Sequence<Int>.cumulativeSum(): Sequence<Int> = this.cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong") @JvmName("cumulativeSumOfLong")
fun Sequence<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element } fun Sequence<Long>.cumulativeSum(): Sequence<Long> = this.cumulative(0L) { element, sum -> sum + element }
fun <T> List<T>.cumulativeSum(space: Space<T>) = with(space) { fun <T> List<T>.cumulativeSum(space: Space<T>): List<T> = with(space) {
cumulative(zero) { element: T, sum: T -> sum + element } cumulative(zero) { element: T, sum: T -> sum + element }
} }
@JvmName("cumulativeSumOfDouble") @JvmName("cumulativeSumOfDouble")
fun List<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element } fun List<Double>.cumulativeSum(): List<Double> = this.cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt") @JvmName("cumulativeSumOfInt")
fun List<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element } fun List<Int>.cumulativeSum(): List<Int> = this.cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong") @JvmName("cumulativeSumOfLong")
fun List<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element } fun List<Long>.cumulativeSum(): List<Long> = this.cumulative(0L) { element, sum -> sum + element }

View File

@ -1,10 +1,15 @@
package scientifik.kmath.operations package scientifik.kmath.operations
/**
* Stub for DSL the [Algebra] is.
*/
@DslMarker @DslMarker
annotation class KMathContext annotation class KMathContext
/** /**
* Marker interface for any algebra * Represents an algebraic structure.
*
* @param T the type of element of this structure.
*/ */
interface Algebra<T> { interface Algebra<T> {
/** /**
@ -24,50 +29,122 @@ interface Algebra<T> {
} }
/** /**
* An algebra with numeric representation of members * An algebraic structure where elements can have numeric representation.
*
* @param T the type of element of this structure.
*/ */
interface NumericAlgebra<T> : Algebra<T> { interface NumericAlgebra<T> : Algebra<T> {
/** /**
* Wrap a number * Wraps a number.
*/ */
fun number(value: Number): T fun number(value: Number): T
/**
* Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number].
*/
fun leftSideNumberOperation(operation: String, left: Number, right: T): T = fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
binaryOperation(operation, number(left), right) binaryOperation(operation, number(left), right)
/**
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number].
*/
fun rightSideNumberOperation(operation: String, left: T, right: Number): T = fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
leftSideNumberOperation(operation, right, left) leftSideNumberOperation(operation, right, left)
} }
/** /**
* Call a block with an [Algebra] as receiver * Call a block with an [Algebra] as receiver.
*/ */
inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = run(block) inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = run(block)
/** /**
* Space-like operations without neutral element * Represents semigroup, i.e. algebraic structure with associative binary operation called "addition".
*
* In KMath groups are called spaces, and also define multiplication of element by [Number].
*
* @param T the type of element of this semigroup.
*/ */
interface SpaceOperations<T> : Algebra<T> { interface SpaceOperations<T> : Algebra<T> {
/** /**
* Addition operation for two context elements * Addition of two elements.
*
* @param a the addend.
* @param b the augend.
* @return the sum.
*/ */
fun add(a: T, b: T): T fun add(a: T, b: T): T
/** /**
* Multiplication operation for context element and real number * Multiplication of element by scalar.
*
* @param a the multiplier.
* @param k the multiplicand.
* @return the produce.
*/ */
fun multiply(a: T, k: Number): T fun multiply(a: T, k: Number): T
//Operation to be performed in this context. Could be moved to extensions in case of KEEP-176 // Operations to be performed in this context. Could be moved to extensions in case of KEEP-176
/**
* The negation of this element.
*
* @receiver this value.
* @return the additive inverse of this value.
*/
operator fun T.unaryMinus(): T = multiply(this, -1.0) operator fun T.unaryMinus(): T = multiply(this, -1.0)
/**
* Returns this value.
*
* @receiver this value.
* @return this value.
*/
operator fun T.unaryPlus(): T = this operator fun T.unaryPlus(): T = this
/**
* Addition of two elements.
*
* @receiver the addend.
* @param b the augend.
* @return the sum.
*/
operator fun T.plus(b: T): T = add(this, b) operator fun T.plus(b: T): T = add(this, b)
/**
* Subtraction of two elements.
*
* @receiver the minuend.
* @param b the subtrahend.
* @return the difference.
*/
operator fun T.minus(b: T): T = add(this, -b) operator fun T.minus(b: T): T = add(this, -b)
operator fun T.times(k: Number) = multiply(this, k.toDouble())
operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble()) /**
operator fun Number.times(b: T) = b * this * Multiplication of this element by a scalar.
*
* @receiver the multiplier.
* @param k the multiplicand.
* @return the product.
*/
operator fun T.times(k: Number): T = multiply(this, k.toDouble())
/**
* Division of this element by scalar.
*
* @receiver the dividend.
* @param k the divisor.
* @return the quotient.
*/
operator fun T.div(k: Number): T = multiply(this, 1.0 / k.toDouble())
/**
* Multiplication of this number by element.
*
* @receiver the multiplier.
* @param b the multiplicand.
* @return the product.
*/
operator fun Number.times(b: T): T = b * this
override fun unaryOperation(operation: String, arg: T): T = when (operation) { override fun unaryOperation(operation: String, arg: T): T = when (operation) {
PLUS_OPERATION -> arg PLUS_OPERATION -> arg
@ -82,37 +159,56 @@ interface SpaceOperations<T> : Algebra<T> {
} }
companion object { companion object {
const val PLUS_OPERATION = "+" /**
const val MINUS_OPERATION = "-" * The identifier of addition.
const val NOT_OPERATION = "!" */
const val PLUS_OPERATION: String = "+"
/**
* The identifier of subtraction (and negation).
*/
const val MINUS_OPERATION: String = "-"
const val NOT_OPERATION: String = "!"
} }
} }
/** /**
* A general interface representing linear context of some kind. * Represents group, i.e. algebraic structure with associative binary operation called "addition" and its neutral
* The context defines sum operation for its elements and multiplication by real value. * element.
* One must note that in some cases context is a singleton class, but in some cases it
* works as a context for operations inside it.
* *
* TODO do we need non-commutative context? * In KMath groups are called spaces, and also define multiplication of element by [Number].
*
* @param T the type of element of this group.
*/ */
interface Space<T> : SpaceOperations<T> { interface Space<T> : SpaceOperations<T> {
/** /**
* Neutral element for sum operation * The neutral element of addition.
*/ */
val zero: T val zero: T
} }
/** /**
* Operations on ring without multiplication neutral element * Represents semiring, i.e. algebraic structure with two associative binary operations called "addition" and
* "multiplication".
*
* @param T the type of element of this semiring.
*/ */
interface RingOperations<T> : SpaceOperations<T> { interface RingOperations<T> : SpaceOperations<T> {
/** /**
* Multiplication for two field elements * Multiplies two elements.
*
* @param a the multiplier.
* @param b the multiplicand.
*/ */
fun multiply(a: T, b: T): T fun multiply(a: T, b: T): T
/**
* Multiplies this element by scalar.
*
* @receiver the multiplier.
* @param b the multiplicand.
*/
operator fun T.times(b: T): T = multiply(this, b) operator fun T.times(b: T): T = multiply(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
@ -121,12 +217,18 @@ interface RingOperations<T> : SpaceOperations<T> {
} }
companion object { companion object {
const val TIMES_OPERATION = "*" /**
* The identifier of multiplication.
*/
const val TIMES_OPERATION: String = "*"
} }
} }
/** /**
* The same as {@link Space} but with additional multiplication operation * Represents ring, i.e. algebraic structure with two associative binary operations called "addition" and
* "multiplication" and their neutral elements.
*
* @param T the type of element of this ring.
*/ */
interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> { interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
/** /**
@ -150,20 +252,64 @@ interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
else -> super.rightSideNumberOperation(operation, left, right) else -> super.rightSideNumberOperation(operation, left, right)
} }
/**
* Addition of element and scalar.
*
* @receiver the addend.
* @param b the augend.
*/
operator fun T.plus(b: Number): T = this + number(b)
operator fun T.plus(b: Number) = this.plus(number(b)) /**
operator fun Number.plus(b: T) = b + this * Addition of scalar and element.
*
* @receiver the addend.
* @param b the augend.
*/
operator fun Number.plus(b: T): T = b + this
operator fun T.minus(b: Number) = this.minus(number(b)) /**
operator fun Number.minus(b: T) = -b + this * Subtraction of element from number.
*
* @receiver the minuend.
* @param b the subtrahend.
* @receiver the difference.
*/
operator fun T.minus(b: Number): T = this - number(b)
/**
* Subtraction of number from element.
*
* @receiver the minuend.
* @param b the subtrahend.
* @receiver the difference.
*/
operator fun Number.minus(b: T): T = -b + this
} }
/** /**
* All ring operations but without neutral elements * Represents semifield, i.e. algebraic structure with three operations: associative "addition" and "multiplication",
* and "division".
*
* @param T the type of element of this semifield.
*/ */
interface FieldOperations<T> : RingOperations<T> { interface FieldOperations<T> : RingOperations<T> {
/**
* Division of two elements.
*
* @param a the dividend.
* @param b the divisor.
* @return the quotient.
*/
fun divide(a: T, b: T): T fun divide(a: T, b: T): T
/**
* Division of two elements.
*
* @receiver the dividend.
* @param b the divisor.
* @return the quotient.
*/
operator fun T.div(b: T): T = divide(this, b) operator fun T.div(b: T): T = divide(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
@ -172,13 +318,26 @@ interface FieldOperations<T> : RingOperations<T> {
} }
companion object { companion object {
const val DIV_OPERATION = "/" /**
* The identifier of division.
*/
const val DIV_OPERATION: String = "/"
} }
} }
/** /**
* Four operations algebra * Represents field, i.e. algebraic structure with three operations: associative "addition" and "multiplication",
* and "division" and their neutral elements.
*
* @param T the type of element of this semifield.
*/ */
interface Field<T> : Ring<T>, FieldOperations<T> { interface Field<T> : Ring<T>, FieldOperations<T> {
operator fun Number.div(b: T) = this * divide(one, b) /**
* Division of element by scalar.
*
* @receiver the dividend.
* @param b the divisor.
* @return the quotient.
*/
operator fun Number.div(b: T): T = this * divide(one, b)
} }

View File

@ -2,13 +2,12 @@ package scientifik.kmath.operations
/** /**
* The generic mathematics elements which is able to store its context * The generic mathematics elements which is able to store its context
* @param T the type of space operation results *
* @param I self type of the element. Needed for static type checking * @param C the type of mathematical context for this element.
* @param C the type of mathematical context for this element
*/ */
interface MathElement<C> { interface MathElement<C> {
/** /**
* The context this element belongs to * The context this element belongs to.
*/ */
val context: C val context: C
} }
@ -25,18 +24,17 @@ interface MathWrapper<T, I> {
* @param S the type of space * @param S the type of space
*/ */
interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement<S>, MathWrapper<T, I> { interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement<S>, MathWrapper<T, I> {
operator fun plus(b: T): I = context.add(unwrap(), b).wrap()
operator fun plus(b: T) = context.add(unwrap(), b).wrap() operator fun minus(b: T): I = context.add(unwrap(), context.multiply(b, -1.0)).wrap()
operator fun minus(b: T) = context.add(unwrap(), context.multiply(b, -1.0)).wrap() operator fun times(k: Number): I = context.multiply(unwrap(), k.toDouble()).wrap()
operator fun times(k: Number) = context.multiply(unwrap(), k.toDouble()).wrap() operator fun div(k: Number): I = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap()
operator fun div(k: Number) = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap()
} }
/** /**
* Ring element * Ring element
*/ */
interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> { interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> {
operator fun times(b: T) = context.multiply(unwrap(), b).wrap() operator fun times(b: T): I = context.multiply(unwrap(), b).wrap()
} }
/** /**
@ -44,5 +42,5 @@ interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T
*/ */
interface FieldElement<T, I : FieldElement<T, I, F>, F : Field<T>> : RingElement<T, I, F> { interface FieldElement<T, I : FieldElement<T, I, F>, F : Field<T>> : RingElement<T, I, F> {
override val context: F override val context: F
operator fun div(b: T) = context.divide(unwrap(), b).wrap() operator fun div(b: T): I = context.divide(unwrap(), b).wrap()
} }

View File

@ -12,4 +12,4 @@ fun <T> RingOperations<T>.power(arg: T, power: Int): T {
res *= arg res *= arg
} }
return res return res
} }

View File

@ -194,8 +194,8 @@ class BigInt internal constructor(
} }
infix fun or(other: BigInt): BigInt { infix fun or(other: BigInt): BigInt {
if (this == ZERO) return other; if (this == ZERO) return other
if (other == ZERO) return this; if (other == ZERO) return this
val resSize = max(this.magnitude.size, other.magnitude.size) val resSize = max(this.magnitude.size, other.magnitude.size)
val newMagnitude: Magnitude = Magnitude(resSize) val newMagnitude: Magnitude = Magnitude(resSize)
for (i in 0 until resSize) { for (i in 0 until resSize) {
@ -210,7 +210,7 @@ class BigInt internal constructor(
} }
infix fun and(other: BigInt): BigInt { infix fun and(other: BigInt): BigInt {
if ((this == ZERO) or (other == ZERO)) return ZERO; if ((this == ZERO) or (other == ZERO)) return ZERO
val resSize = min(this.magnitude.size, other.magnitude.size) val resSize = min(this.magnitude.size, other.magnitude.size)
val newMagnitude: Magnitude = Magnitude(resSize) val newMagnitude: Magnitude = Magnitude(resSize)
for (i in 0 until resSize) { for (i in 0 until resSize) {
@ -260,7 +260,7 @@ class BigInt internal constructor(
} }
companion object { companion object {
const val BASE = 0xffffffffUL const val BASE: ULong = 0xffffffffUL
const val BASE_SIZE: Int = 32 const val BASE_SIZE: Int = 32
val ZERO: BigInt = BigInt(0, uintArrayOf()) val ZERO: BigInt = BigInt(0, uintArrayOf())
val ONE: BigInt = BigInt(1, uintArrayOf(1u)) val ONE: BigInt = BigInt(1, uintArrayOf(1u))
@ -394,12 +394,12 @@ fun abs(x: BigInt): BigInt = x.abs()
/** /**
* Convert this [Int] to [BigInt] * Convert this [Int] to [BigInt]
*/ */
fun Int.toBigInt() = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt())) fun Int.toBigInt(): BigInt = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt()))
/** /**
* Convert this [Long] to [BigInt] * Convert this [Long] to [BigInt]
*/ */
fun Long.toBigInt() = BigInt( fun Long.toBigInt(): BigInt = BigInt(
sign.toByte(), stripLeadingZeros( sign.toByte(), stripLeadingZeros(
uintArrayOf( uintArrayOf(
(kotlin.math.abs(this).toULong() and BASE).toUInt(), (kotlin.math.abs(this).toULong() and BASE).toUInt(),
@ -411,17 +411,17 @@ fun Long.toBigInt() = BigInt(
/** /**
* Convert UInt to [BigInt] * Convert UInt to [BigInt]
*/ */
fun UInt.toBigInt() = BigInt(1, uintArrayOf(this)) fun UInt.toBigInt(): BigInt = BigInt(1, uintArrayOf(this))
/** /**
* Convert ULong to [BigInt] * Convert ULong to [BigInt]
*/ */
fun ULong.toBigInt() = BigInt( fun ULong.toBigInt(): BigInt = BigInt(
1, 1,
stripLeadingZeros( stripLeadingZeros(
uintArrayOf( uintArrayOf(
(this and BigInt.BASE).toUInt(), (this and BASE).toUInt(),
((this shr BigInt.BASE_SIZE) and BigInt.BASE).toUInt() ((this shr BASE_SIZE) and BASE).toUInt()
) )
) )
) )
@ -434,7 +434,7 @@ fun UIntArray.toBigInt(sign: Byte): BigInt {
return BigInt(sign, this.copyOf()) return BigInt(sign, this.copyOf())
} }
val hexChToInt = hashMapOf( val hexChToInt: MutableMap<Char, Int> = hashMapOf(
'0' to 0, '1' to 1, '2' to 2, '3' to 3, '0' to 0, '1' to 1, '2' to 2, '3' to 3,
'4' to 4, '5' to 5, '6' to 6, '7' to 7, '4' to 4, '5' to 5, '6' to 6, '7' to 7,
'8' to 8, '9' to 9, 'A' to 10, 'B' to 11, '8' to 8, '9' to 9, 'A' to 10, 'B' to 11,
@ -497,4 +497,4 @@ fun NDElement.Companion.bigInt(
vararg shape: Int, vararg shape: Int,
initializer: BigIntField.(IntArray) -> BigInt initializer: BigIntField.(IntArray) -> BigInt
): BufferedNDRingElement<BigInt, BigIntField> = ): BufferedNDRingElement<BigInt, BigIntField> =
NDAlgebra.bigInt(*shape).produce(initializer) NDAlgebra.bigInt(*shape).produce(initializer)

View File

@ -18,7 +18,7 @@ object ComplexField : ExtendedField<Complex> {
override val one: Complex = Complex(1.0, 0.0) override val one: Complex = Complex(1.0, 0.0)
val i = Complex(0.0, 1.0) val i: Complex = Complex(0.0, 1.0)
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im) override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im)
@ -45,15 +45,15 @@ object ComplexField : ExtendedField<Complex> {
override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re) override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re)
operator fun Double.plus(c: Complex) = add(this.toComplex(), c) operator fun Double.plus(c: Complex): Complex = add(this.toComplex(), c)
operator fun Double.minus(c: Complex) = add(this.toComplex(), -c) operator fun Double.minus(c: Complex): Complex = add(this.toComplex(), -c)
operator fun Complex.plus(d: Double) = d + this operator fun Complex.plus(d: Double): Complex = d + this
operator fun Complex.minus(d: Double) = add(this, -d.toComplex()) operator fun Complex.minus(d: Double): Complex = add(this, -d.toComplex())
operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this) operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this)
override fun symbol(value: String): Complex = if (value == "i") { override fun symbol(value: String): Complex = if (value == "i") {
i i
@ -104,7 +104,7 @@ val Complex.r: Double get() = sqrt(re * re + im * im)
*/ */
val Complex.theta: Double get() = atan(im / re) val Complex.theta: Double get() = atan(im / re)
fun Double.toComplex() = Complex(this, 0.0) fun Double.toComplex(): Complex = Complex(this, 0.0)
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> { inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
return MemoryBuffer.create(Complex, size, init) return MemoryBuffer.create(Complex, size, init)

View File

@ -1,6 +1,5 @@
package scientifik.kmath.operations package scientifik.kmath.operations
import scientifik.kmath.operations.RealField.pow
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.pow as kpow import kotlin.math.pow as kpow
@ -38,6 +37,8 @@ interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
/** /**
* Real field element wrapping double. * Real field element wrapping double.
* *
* @property value the [Double] value wrapped by this [Real].
*
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586 * TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
*/ */
inline class Real(val value: Double) : FieldElement<Double, Real, RealField> { inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
@ -45,7 +46,7 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
override fun Double.wrap(): Real = Real(value) override fun Double.wrap(): Real = Real(value)
override val context get() = RealField override val context: RealField get() = RealField
companion object companion object
} }
@ -56,36 +57,36 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object RealField : ExtendedField<Double>, Norm<Double, Double> { object RealField : ExtendedField<Double>, Norm<Double, Double> {
override val zero: Double = 0.0 override val zero: Double = 0.0
override inline fun add(a: Double, b: Double) = a + b override inline fun add(a: Double, b: Double): Double = a + b
override inline fun multiply(a: Double, b: Double) = a * b override inline fun multiply(a: Double, b: Double): Double = a * b
override inline fun multiply(a: Double, k: Number) = a * k.toDouble() override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
override val one: Double = 1.0 override val one: Double = 1.0
override inline fun divide(a: Double, b: Double) = a / b override inline fun divide(a: Double, b: Double): Double = a / b
override inline fun sin(arg: Double) = kotlin.math.sin(arg) override inline fun sin(arg: Double): Double = kotlin.math.sin(arg)
override inline fun cos(arg: Double) = kotlin.math.cos(arg) override inline fun cos(arg: Double): Double = kotlin.math.cos(arg)
override inline fun tan(arg: Double): Double = kotlin.math.tan(arg) override inline fun tan(arg: Double): Double = kotlin.math.tan(arg)
override inline fun acos(arg: Double): Double = kotlin.math.acos(arg) override inline fun acos(arg: Double): Double = kotlin.math.acos(arg)
override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) override inline fun asin(arg: Double): Double = kotlin.math.asin(arg)
override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) override inline fun atan(arg: Double): Double = kotlin.math.atan(arg)
override inline fun power(arg: Double, pow: Number) = arg.kpow(pow.toDouble()) override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble())
override inline fun exp(arg: Double) = kotlin.math.exp(arg) override inline fun exp(arg: Double): Double = kotlin.math.exp(arg)
override inline fun ln(arg: Double) = kotlin.math.ln(arg) override inline fun ln(arg: Double): Double = kotlin.math.ln(arg)
override inline fun norm(arg: Double) = abs(arg) override inline fun norm(arg: Double): Double = abs(arg)
override inline fun Double.unaryMinus() = -this override inline fun Double.unaryMinus(): Double = -this
override inline fun Double.plus(b: Double) = this + b override inline fun Double.plus(b: Double): Double = this + b
override inline fun Double.minus(b: Double) = this - b override inline fun Double.minus(b: Double): Double = this - b
override inline fun Double.times(b: Double) = this * b override inline fun Double.times(b: Double): Double = this * b
override inline fun Double.div(b: Double) = this / b override inline fun Double.div(b: Double): Double = this / b
override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) { override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
PowerOperations.POW_OPERATION -> left pow right PowerOperations.POW_OPERATION -> left pow right
@ -96,36 +97,36 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object FloatField : ExtendedField<Float>, Norm<Float, Float> { object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override val zero: Float = 0f override val zero: Float = 0f
override inline fun add(a: Float, b: Float) = a + b override inline fun add(a: Float, b: Float): Float = a + b
override inline fun multiply(a: Float, b: Float) = a * b override inline fun multiply(a: Float, b: Float): Float = a * b
override inline fun multiply(a: Float, k: Number) = a * k.toFloat() override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat()
override val one: Float = 1f override val one: Float = 1f
override inline fun divide(a: Float, b: Float) = a / b override inline fun divide(a: Float, b: Float): Float = a / b
override inline fun sin(arg: Float) = kotlin.math.sin(arg) override inline fun sin(arg: Float): Float = kotlin.math.sin(arg)
override inline fun cos(arg: Float) = kotlin.math.cos(arg) override inline fun cos(arg: Float): Float = kotlin.math.cos(arg)
override inline fun tan(arg: Float) = kotlin.math.tan(arg) override inline fun tan(arg: Float): Float = kotlin.math.tan(arg)
override inline fun acos(arg: Float) = kotlin.math.acos(arg) override inline fun acos(arg: Float): Float = kotlin.math.acos(arg)
override inline fun asin(arg: Float) = kotlin.math.asin(arg) override inline fun asin(arg: Float): Float = kotlin.math.asin(arg)
override inline fun atan(arg: Float) = kotlin.math.atan(arg) override inline fun atan(arg: Float): Float = kotlin.math.atan(arg)
override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat()) override inline fun power(arg: Float, pow: Number): Float = arg.pow(pow.toFloat())
override inline fun exp(arg: Float) = kotlin.math.exp(arg) override inline fun exp(arg: Float): Float = kotlin.math.exp(arg)
override inline fun ln(arg: Float) = kotlin.math.ln(arg) override inline fun ln(arg: Float): Float = kotlin.math.ln(arg)
override inline fun norm(arg: Float) = abs(arg) override inline fun norm(arg: Float): Float = abs(arg)
override inline fun Float.unaryMinus() = -this override inline fun Float.unaryMinus(): Float = -this
override inline fun Float.plus(b: Float) = this + b override inline fun Float.plus(b: Float): Float = this + b
override inline fun Float.minus(b: Float) = this - b override inline fun Float.minus(b: Float): Float = this - b
override inline fun Float.times(b: Float) = this * b override inline fun Float.times(b: Float): Float = this * b
override inline fun Float.div(b: Float) = this / b override inline fun Float.div(b: Float): Float = this / b
} }
/** /**
@ -134,14 +135,14 @@ object FloatField : ExtendedField<Float>, Norm<Float, Float> {
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object IntRing : Ring<Int>, Norm<Int, Int> { object IntRing : Ring<Int>, Norm<Int, Int> {
override val zero: Int = 0 override val zero: Int = 0
override inline fun add(a: Int, b: Int) = a + b override inline fun add(a: Int, b: Int): Int = a + b
override inline fun multiply(a: Int, b: Int) = a * b override inline fun multiply(a: Int, b: Int): Int = a * b
override inline fun multiply(a: Int, k: Number) = k.toInt() * a override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a
override val one: Int = 1 override val one: Int = 1
override inline fun norm(arg: Int) = abs(arg) override inline fun norm(arg: Int): Int = abs(arg)
override inline fun Int.unaryMinus() = -this override inline fun Int.unaryMinus(): Int = -this
override inline fun Int.plus(b: Int): Int = this + b override inline fun Int.plus(b: Int): Int = this + b
@ -156,20 +157,20 @@ object IntRing : Ring<Int>, Norm<Int, Int> {
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object ShortRing : Ring<Short>, Norm<Short, Short> { object ShortRing : Ring<Short>, Norm<Short, Short> {
override val zero: Short = 0 override val zero: Short = 0
override inline fun add(a: Short, b: Short) = (a + b).toShort() override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
override inline fun multiply(a: Short, b: Short) = (a * b).toShort() override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort()
override inline fun multiply(a: Short, k: Number) = (a * k.toShort()).toShort() override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort()
override val one: Short = 1 override val one: Short = 1
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
override inline fun Short.unaryMinus() = (-this).toShort() override inline fun Short.unaryMinus(): Short = (-this).toShort()
override inline fun Short.plus(b: Short) = (this + b).toShort() override inline fun Short.plus(b: Short): Short = (this + b).toShort()
override inline fun Short.minus(b: Short) = (this - b).toShort() override inline fun Short.minus(b: Short): Short = (this - b).toShort()
override inline fun Short.times(b: Short) = (this * b).toShort() override inline fun Short.times(b: Short): Short = (this * b).toShort()
} }
/** /**
@ -178,20 +179,20 @@ object ShortRing : Ring<Short>, Norm<Short, Short> {
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object ByteRing : Ring<Byte>, Norm<Byte, Byte> { object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
override val zero: Byte = 0 override val zero: Byte = 0
override inline fun add(a: Byte, b: Byte) = (a + b).toByte() override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
override inline fun multiply(a: Byte, b: Byte) = (a * b).toByte() override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
override inline fun multiply(a: Byte, k: Number) = (a * k.toByte()).toByte() override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte()
override val one: Byte = 1 override val one: Byte = 1
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
override inline fun Byte.unaryMinus() = (-this).toByte() override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
override inline fun Byte.plus(b: Byte) = (this + b).toByte() override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte()
override inline fun Byte.minus(b: Byte) = (this - b).toByte() override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte()
override inline fun Byte.times(b: Byte) = (this * b).toByte() override inline fun Byte.times(b: Byte): Byte = (this * b).toByte()
} }
/** /**
@ -200,18 +201,18 @@ object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object LongRing : Ring<Long>, Norm<Long, Long> { object LongRing : Ring<Long>, Norm<Long, Long> {
override val zero: Long = 0 override val zero: Long = 0
override inline fun add(a: Long, b: Long) = (a + b) override inline fun add(a: Long, b: Long): Long = (a + b)
override inline fun multiply(a: Long, b: Long) = (a * b) override inline fun multiply(a: Long, b: Long): Long = (a * b)
override inline fun multiply(a: Long, k: Number) = a * k.toLong() override inline fun multiply(a: Long, k: Number): Long = a * k.toLong()
override val one: Long = 1 override val one: Long = 1
override fun norm(arg: Long): Long = abs(arg) override fun norm(arg: Long): Long = abs(arg)
override inline fun Long.unaryMinus() = (-this) override inline fun Long.unaryMinus(): Long = (-this)
override inline fun Long.plus(b: Long) = (this + b) override inline fun Long.plus(b: Long): Long = (this + b)
override inline fun Long.minus(b: Long) = (this - b) override inline fun Long.minus(b: Long): Long = (this - b)
override inline fun Long.times(b: Long) = (this * b) override inline fun Long.times(b: Long): Long = (this * b)
} }

View File

@ -1,84 +1,210 @@
package scientifik.kmath.operations package scientifik.kmath.operations
/* Trigonometric operations */
/** /**
* A container for trigonometric operations for specific type. Trigonometric operations are limited to fields. * A container for trigonometric operations for specific type. They are limited to semifields.
* *
* The operations are not exposed to class directly to avoid method bloat but instead are declared in the field. * The operations are not exposed to class directly to avoid method bloat but instead are declared in the field.
* It also allows to override behavior for optional operations * It also allows to override behavior for optional operations.
*
*/ */
interface TrigonometricOperations<T> : FieldOperations<T> { interface TrigonometricOperations<T> : FieldOperations<T> {
/**
* Computes the sine of [arg].
*/
fun sin(arg: T): T fun sin(arg: T): T
/**
* Computes the cosine of [arg].
*/
fun cos(arg: T): T fun cos(arg: T): T
/**
* Computes the tangent of [arg].
*/
fun tan(arg: T): T fun tan(arg: T): T
companion object { companion object {
const val SIN_OPERATION = "sin" /**
const val COS_OPERATION = "cos" * The identifier of sine.
const val TAN_OPERATION = "tan" */
const val SIN_OPERATION: String = "sin"
/**
* The identifier of cosine.
*/
const val COS_OPERATION: String = "cos"
/**
* The identifier of tangent.
*/
const val TAN_OPERATION: String = "tan"
} }
} }
/**
* A container for inverse trigonometric operations for specific type. They are limited to semifields.
*
* The operations are not exposed to class directly to avoid method bloat but instead are declared in the field.
* It also allows to override behavior for optional operations.
*/
interface InverseTrigonometricOperations<T> : TrigonometricOperations<T> { interface InverseTrigonometricOperations<T> : TrigonometricOperations<T> {
/**
* Computes the inverse sine of [arg].
*/
fun asin(arg: T): T fun asin(arg: T): T
/**
* Computes the inverse cosine of [arg].
*/
fun acos(arg: T): T fun acos(arg: T): T
/**
* Computes the inverse tangent of [arg].
*/
fun atan(arg: T): T fun atan(arg: T): T
companion object { companion object {
const val ASIN_OPERATION = "asin" /**
const val ACOS_OPERATION = "acos" * The identifier of inverse sine.
const val ATAN_OPERATION = "atan" */
const val ASIN_OPERATION: String = "asin"
/**
* The identifier of inverse cosine.
*/
const val ACOS_OPERATION: String = "acos"
/**
* The identifier of inverse tangent.
*/
const val ATAN_OPERATION: String = "atan"
} }
} }
fun <T : MathElement<out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
fun <T : MathElement<out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
fun <T : MathElement<out TrigonometricOperations<T>>> tan(arg: T): T = arg.context.tan(arg)
fun <T : MathElement<out InverseTrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg)
fun <T : MathElement<out InverseTrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg)
fun <T : MathElement<out InverseTrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg)
/* Power and roots */
/** /**
* A context extension to include power operations like square roots, etc * Computes the sine of [arg].
*/
fun <T : MathElement<out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
/**
* Computes the cosine of [arg].
*/
fun <T : MathElement<out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
/**
* Computes the tangent of [arg].
*/
fun <T : MathElement<out TrigonometricOperations<T>>> tan(arg: T): T = arg.context.tan(arg)
/**
* Computes the inverse sine of [arg].
*/
fun <T : MathElement<out InverseTrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg)
/**
* Computes the inverse cosine of [arg].
*/
fun <T : MathElement<out InverseTrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg)
/**
* Computes the inverse tangent of [arg].
*/
fun <T : MathElement<out InverseTrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg)
/**
* A context extension to include power operations based on exponentiation.
*/ */
interface PowerOperations<T> : Algebra<T> { interface PowerOperations<T> : Algebra<T> {
/**
* Raises [arg] to the power [pow].
*/
fun power(arg: T, pow: Number): T fun power(arg: T, pow: Number): T
fun sqrt(arg: T) = power(arg, 0.5)
infix fun T.pow(pow: Number) = power(this, pow) /**
* Computes the square root of the value [arg].
*/
fun sqrt(arg: T): T = power(arg, 0.5)
/**
* Raises this value to the power [pow].
*/
infix fun T.pow(pow: Number): T = power(this, pow)
companion object { companion object {
const val POW_OPERATION = "pow" /**
const val SQRT_OPERATION = "sqrt" * The identifier of exponentiation.
*/
const val POW_OPERATION: String = "pow"
/**
* The identifier of square root.
*/
const val SQRT_OPERATION: String = "sqrt"
} }
} }
/**
* Raises [arg] to the power [pow].
*/
infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power) infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
/**
* Computes the square root of the value [arg].
*/
fun <T : MathElement<out PowerOperations<T>>> sqrt(arg: T): T = arg pow 0.5 fun <T : MathElement<out PowerOperations<T>>> sqrt(arg: T): T = arg pow 0.5
/**
* Computes the square of the value [arg].
*/
fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0 fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
/* Exponential */ /**
* A container for operations related to `exp` and `ln` functions.
*/
interface ExponentialOperations<T> : Algebra<T> { interface ExponentialOperations<T> : Algebra<T> {
/**
* Computes Euler's number `e` raised to the power of the value [arg].
*/
fun exp(arg: T): T fun exp(arg: T): T
/**
* Computes the natural logarithm (base `e`) of the value [arg].
*/
fun ln(arg: T): T fun ln(arg: T): T
companion object { companion object {
const val EXP_OPERATION = "exp" /**
const val LN_OPERATION = "ln" * The identifier of exponential function.
*/
const val EXP_OPERATION: String = "exp"
/**
* The identifier of natural logarithm.
*/
const val LN_OPERATION: String = "ln"
} }
} }
/**
* The identifier of exponential function.
*/
fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg) fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg)
/**
* The identifier of natural logarithm.
*/
fun <T : MathElement<out ExponentialOperations<T>>> ln(arg: T): T = arg.context.ln(arg) fun <T : MathElement<out ExponentialOperations<T>>> ln(arg: T): T = arg.context.ln(arg)
/**
* A container for norm functional on element.
*/
interface Norm<in T : Any, out R> { interface Norm<in T : Any, out R> {
/**
* Computes the norm of [arg] (i.e. absolute value or vector length).
*/
fun norm(arg: T): R fun norm(arg: T): R
} }
/**
* Computes the norm of [arg] (i.e. absolute value or vector length).
*/
fun <T : MathElement<out Norm<T, R>>, R> norm(arg: T): R = arg.context.norm(arg) fun <T : MathElement<out Norm<T, R>>, R> norm(arg: T): R = arg.context.norm(arg)

View File

@ -19,10 +19,10 @@ class BoxingNDField<T, F : Field<T>>(
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
} }
override val zero by lazy { produce { zero } } override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
override val one by lazy { produce { one } } override val one: BufferedNDFieldElement<T, F> by lazy { produce { one } }
override fun produce(initializer: F.(IntArray) -> T) = override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
BufferedNDFieldElement( BufferedNDFieldElement(
this, this,
buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }) buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) })
@ -79,4 +79,4 @@ inline fun <T : Any, F : Field<T>, R> F.nd(
): R { ): R {
val ndfield: BoxingNDField<T, F> = NDField.boxing(this, *shape, bufferFactory = bufferFactory) val ndfield: BoxingNDField<T, F> = NDField.boxing(this, *shape, bufferFactory = bufferFactory)
return ndfield.action() return ndfield.action()
} }

View File

@ -18,10 +18,10 @@ class BoxingNDRing<T, R : Ring<T>>(
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
} }
override val zero by lazy { produce { zero } } override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } }
override val one by lazy { produce { one } } override val one: BufferedNDRingElement<T, R> by lazy { produce { one } }
override fun produce(initializer: R.(IntArray) -> T) = override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement<T, R> =
BufferedNDRingElement( BufferedNDRingElement(
this, this,
buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }) buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) })
@ -69,4 +69,4 @@ class BoxingNDRing<T, R : Ring<T>>(
override fun NDBuffer<T>.toElement(): RingElement<NDBuffer<T>, *, out BufferedNDRing<T, R>> = override fun NDBuffer<T>.toElement(): RingElement<NDBuffer<T>, *, out BufferedNDRing<T, R>> =
BufferedNDRingElement(this@BoxingNDRing, buffer) BufferedNDRingElement(this@BoxingNDRing, buffer)
} }

View File

@ -7,16 +7,16 @@ import kotlin.reflect.KClass
*/ */
class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) { class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) {
operator fun Buffer<T>.get(i: Int, j: Int) = get(i + colNum * j) operator fun Buffer<T>.get(i: Int, j: Int): T = get(i + colNum * j)
operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) { operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
set(i + colNum * j, value) set(i + colNum * j, value)
} }
inline fun create(init: (i: Int, j: Int) -> T) = inline fun create(init: (i: Int, j: Int) -> T): MutableBuffer<T> =
MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) } MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
fun create(mat: Structure2D<T>) = create { i, j -> mat[i, j] } fun create(mat: Structure2D<T>): MutableBuffer<T> = create { i, j -> mat[i, j] }
//TODO optimize wrapper //TODO optimize wrapper
fun MutableBuffer<T>.collect(): Structure2D<T> = fun MutableBuffer<T>.collect(): Structure2D<T> =
@ -41,5 +41,5 @@ class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum
/** /**
* Get row * Get row
*/ */
fun MutableBuffer<T>.row(i: Int) = Row(this, i) fun MutableBuffer<T>.row(i: Int): Row = Row(this, i)
} }

View File

@ -2,7 +2,7 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.* import scientifik.kmath.operations.*
interface BufferedNDAlgebra<T, C>: NDAlgebra<T, C, NDBuffer<T>>{ interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
val strides: Strides val strides: Strides
override fun check(vararg elements: NDBuffer<T>) { override fun check(vararg elements: NDBuffer<T>) {
@ -30,7 +30,7 @@ interface BufferedNDAlgebra<T, C>: NDAlgebra<T, C, NDBuffer<T>>{
} }
interface BufferedNDSpace<T, S : Space<T>> : NDSpace<T, S, NDBuffer<T>>, BufferedNDAlgebra<T,S> { interface BufferedNDSpace<T, S : Space<T>> : NDSpace<T, S, NDBuffer<T>>, BufferedNDAlgebra<T, S> {
override fun NDBuffer<T>.toElement(): SpaceElement<NDBuffer<T>, *, out BufferedNDSpace<T, S>> override fun NDBuffer<T>.toElement(): SpaceElement<NDBuffer<T>, *, out BufferedNDSpace<T, S>>
} }

View File

@ -8,7 +8,7 @@ import scientifik.kmath.operations.*
abstract class BufferedNDElement<T, C> : NDBuffer<T>(), NDElement<T, C, NDBuffer<T>> { abstract class BufferedNDElement<T, C> : NDBuffer<T>(), NDElement<T, C, NDBuffer<T>> {
abstract override val context: BufferedNDAlgebra<T, C> abstract override val context: BufferedNDAlgebra<T, C>
override val strides get() = context.strides override val strides: Strides get() = context.strides
override val shape: IntArray get() = context.shape override val shape: IntArray get() = context.shape
} }
@ -54,9 +54,9 @@ class BufferedNDFieldElement<T, F : Field<T>>(
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array. Just like in numpy.
*/ */
operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedNDElement<T, F>) = operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedNDElement<T, F>): MathElement<out BufferedNDAlgebra<T, F>> =
ndElement.context.run { map(ndElement) { invoke(it) }.toElement() } ndElement.context.run { map(ndElement) { invoke(it) }.toElement() }
/* plus and minus */ /* plus and minus */
@ -64,13 +64,13 @@ operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedN
/** /**
* Summation operation for [BufferedNDElement] and single element * Summation operation for [BufferedNDElement] and single element
*/ */
operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.plus(arg: T) = operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.plus(arg: T): NDElement<T, F, NDBuffer<T>> =
context.map(this) { it + arg }.wrap() context.map(this) { it + arg }.wrap()
/** /**
* Subtraction operation between [BufferedNDElement] and single element * Subtraction operation between [BufferedNDElement] and single element
*/ */
operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T) = operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T): NDElement<T, F, NDBuffer<T>> =
context.map(this) { it - arg }.wrap() context.map(this) { it - arg }.wrap()
/* prod and div */ /* prod and div */
@ -78,11 +78,11 @@ operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T) =
/** /**
* Product operation for [BufferedNDElement] and single element * Product operation for [BufferedNDElement] and single element
*/ */
operator fun <T : Any, F : Ring<T>> BufferedNDElement<T, F>.times(arg: T) = operator fun <T : Any, F : Ring<T>> BufferedNDElement<T, F>.times(arg: T): NDElement<T, F, NDBuffer<T>> =
context.map(this) { it * arg }.wrap() context.map(this) { it * arg }.wrap()
/** /**
* Division operation between [BufferedNDElement] and single element * Division operation between [BufferedNDElement] and single element
*/ */
operator fun <T : Any, F : Field<T>> BufferedNDElement<T, F>.div(arg: T) = operator fun <T : Any, F : Field<T>> BufferedNDElement<T, F>.div(arg: T): NDElement<T, F, NDBuffer<T>> =
context.map(this) { it / arg }.wrap() context.map(this) { it / arg }.wrap()

View File

@ -15,22 +15,22 @@ typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T>
interface Buffer<T> { interface Buffer<T> {
/** /**
* The size of the buffer * The size of this buffer.
*/ */
val size: Int val size: Int
/** /**
* Get element at given index * Gets element at given index.
*/ */
operator fun get(index: Int): T operator fun get(index: Int): T
/** /**
* Iterate over all elements * Iterates over all elements.
*/ */
operator fun iterator(): Iterator<T> operator fun iterator(): Iterator<T>
/** /**
* Check content eqiality with another buffer * Checks content equality with another buffer.
*/ */
fun contentEquals(other: Buffer<*>): Boolean = fun contentEquals(other: Buffer<*>): Boolean =
asSequence().mapIndexed { index, value -> value == other[index] }.all { it } asSequence().mapIndexed { index, value -> value == other[index] }.all { it }
@ -124,10 +124,9 @@ inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
override fun iterator(): Iterator<T> = list.iterator() override fun iterator(): Iterator<T> = list.iterator()
} }
fun <T> List<T>.asBuffer() = ListBuffer<T>(this) fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
@Suppress("FunctionName") inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> = List(size, init).asBuffer()
inline fun <T> ListBuffer(size: Int, init: (Int) -> T) = List(size, init).asBuffer()
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> { inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
@ -165,13 +164,13 @@ fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this)
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> { inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
override val size: Int get() = buffer.size override val size: Int get() = buffer.size
override fun get(index: Int): T = buffer.get(index) override fun get(index: Int): T = buffer[index]
override fun iterator() = buffer.iterator() override fun iterator(): Iterator<T> = buffer.iterator()
} }
/** /**
* A buffer with content calculated on-demand. The calculated contect is not stored, so it is recalculated on each call. * A buffer with content calculated on-demand. The calculated content is not stored, so it is recalculated on each call.
* Useful when one needs single element from the buffer. * Useful when one needs single element from the buffer.
*/ */
class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> { class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> {
@ -205,4 +204,4 @@ fun <T> Buffer<T>.asReadOnly(): Buffer<T> = if (this is MutableBuffer) {
*/ */
typealias BufferTransform<T, R> = (Buffer<T>) -> Buffer<R> typealias BufferTransform<T, R> = (Buffer<T>) -> Buffer<R>
typealias SuspendBufferTransform<T, R> = suspend (Buffer<T>) -> Buffer<R> typealias SuspendBufferTransform<T, R> = suspend (Buffer<T>) -> Buffer<R>

View File

@ -17,8 +17,8 @@ class ComplexNDField(override val shape: IntArray) :
override val strides: Strides = DefaultStrides(shape) override val strides: Strides = DefaultStrides(shape)
override val elementContext: ComplexField get() = ComplexField override val elementContext: ComplexField get() = ComplexField
override val zero by lazy { produce { zero } } override val zero: ComplexNDElement by lazy { produce { zero } }
override val one by lazy { produce { one } } override val one: ComplexNDElement by lazy { produce { one } }
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer<Complex> = inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer<Complex> =
Buffer.complex(size) { initializer(it) } Buffer.complex(size) { initializer(it) }
@ -69,23 +69,23 @@ class ComplexNDField(override val shape: IntArray) :
override fun NDBuffer<Complex>.toElement(): FieldElement<NDBuffer<Complex>, *, out BufferedNDField<Complex, ComplexField>> = override fun NDBuffer<Complex>.toElement(): FieldElement<NDBuffer<Complex>, *, out BufferedNDField<Complex, ComplexField>> =
BufferedNDFieldElement(this@ComplexNDField, buffer) BufferedNDFieldElement(this@ComplexNDField, buffer)
override fun power(arg: NDBuffer<Complex>, pow: Number) = map(arg) { power(it, pow) } override fun power(arg: NDBuffer<Complex>, pow: Number): ComplexNDElement = map(arg) { power(it, pow) }
override fun exp(arg: NDBuffer<Complex>) = map(arg) { exp(it) } override fun exp(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { exp(it) }
override fun ln(arg: NDBuffer<Complex>) = map(arg) { ln(it) } override fun ln(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { ln(it) }
override fun sin(arg: NDBuffer<Complex>) = map(arg) { sin(it) } override fun sin(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { sin(it) }
override fun cos(arg: NDBuffer<Complex>) = map(arg) { cos(it) } override fun cos(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { cos(it) }
override fun tan(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) { tan(it) } override fun tan(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { tan(it) }
override fun asin(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) { asin(it) } override fun asin(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { asin(it) }
override fun acos(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) {acos(it)} override fun acos(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { acos(it) }
override fun atan(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) {atan(it)} override fun atan(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { atan(it) }
} }
@ -100,7 +100,7 @@ inline fun BufferedNDField<Complex, ComplexField>.produceInline(crossinline init
/** /**
* Map one [ComplexNDElement] using function with indexes * Map one [ComplexNDElement] using function with indexes
*/ */
inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex) = inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex): ComplexNDElement =
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) } context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
/** /**
@ -114,7 +114,7 @@ inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) ->
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array. Just like in numpy
*/ */
operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement) = operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement): ComplexNDElement =
ndElement.map { this@invoke(it) } ndElement.map { this@invoke(it) }
@ -123,19 +123,18 @@ operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement) =
/** /**
* Summation operation for [BufferedNDElement] and single element * Summation operation for [BufferedNDElement] and single element
*/ */
operator fun ComplexNDElement.plus(arg: Complex) = operator fun ComplexNDElement.plus(arg: Complex): ComplexNDElement = map { it + arg }
map { it + arg }
/** /**
* Subtraction operation between [BufferedNDElement] and single element * Subtraction operation between [BufferedNDElement] and single element
*/ */
operator fun ComplexNDElement.minus(arg: Complex) = operator fun ComplexNDElement.minus(arg: Complex): ComplexNDElement =
map { it - arg } map { it - arg }
operator fun ComplexNDElement.plus(arg: Double) = operator fun ComplexNDElement.plus(arg: Double): ComplexNDElement =
map { it + arg } map { it + arg }
operator fun ComplexNDElement.minus(arg: Double) = operator fun ComplexNDElement.minus(arg: Double): ComplexNDElement =
map { it - arg } map { it - arg }
fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape) fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape)
@ -148,4 +147,4 @@ fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(In
*/ */
inline fun <R> ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R { inline fun <R> ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R {
return NDField.complex(*shape).run(action) return NDField.complex(*shape).run(action)
} }

View File

@ -4,7 +4,6 @@ import scientifik.kmath.operations.ExtendedField
interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : NDField<T, F, N>, ExtendedField<N> interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : NDField<T, F, N>, ExtendedField<N>
///** ///**
// * NDField that supports [ExtendedField] operations on its elements // * NDField that supports [ExtendedField] operations on its elements
// */ // */
@ -36,5 +35,3 @@ interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : N
// return produce { with(elementContext) { cos(arg[it]) } } // return produce { with(elementContext) { cos(arg[it]) } }
// } // }
//} //}

View File

@ -19,11 +19,11 @@ interface FlaggedBuffer<T> : Buffer<T> {
/** /**
* The value is valid if all flags are down * The value is valid if all flags are down
*/ */
fun FlaggedBuffer<*>.isValid(index: Int) = getFlag(index) != 0.toByte() fun FlaggedBuffer<*>.isValid(index: Int): Boolean = getFlag(index) != 0.toByte()
fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag) = (getFlag(index) and flag.mask) != 0.toByte() fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag): Boolean = (getFlag(index) and flag.mask) != 0.toByte()
fun FlaggedBuffer<*>.isMissing(index: Int) = hasFlag(index, ValueFlag.MISSING) fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, ValueFlag.MISSING)
/** /**
* A real buffer which supports flags for each value like NaN or Missing * A real buffer which supports flags for each value like NaN or Missing
@ -45,9 +45,9 @@ class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : Flagged
} }
inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) { inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
for(i in indices){ for (i in indices) {
if(isValid(i)){ if (isValid(i)) {
block(values[i]) block(values[i])
} }
} }
} }

View File

@ -9,12 +9,11 @@ inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
array[index] = value array[index] = value
} }
override fun iterator() = array.iterator() override fun iterator(): IntIterator = array.iterator()
override fun copy(): MutableBuffer<Int> = override fun copy(): MutableBuffer<Int> =
IntBuffer(array.copyOf()) IntBuffer(array.copyOf())
} }
fun IntArray.asBuffer(): IntBuffer = IntBuffer(this)
fun IntArray.asBuffer() = IntBuffer(this)

View File

@ -9,11 +9,11 @@ inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
array[index] = value array[index] = value
} }
override fun iterator() = array.iterator() override fun iterator(): LongIterator = array.iterator()
override fun copy(): MutableBuffer<Long> = override fun copy(): MutableBuffer<Long> =
LongBuffer(array.copyOf()) LongBuffer(array.copyOf())
} }
fun LongArray.asBuffer() = LongBuffer(this) fun LongArray.asBuffer(): LongBuffer = LongBuffer(this)

View File

@ -3,7 +3,7 @@ package scientifik.kmath.structures
import scientifik.memory.* import scientifik.memory.*
/** /**
* A non-boxing buffer based on [ByteBuffer] storage * A non-boxing buffer over [Memory] object.
*/ */
open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spec: MemorySpec<T>) : Buffer<T> { open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spec: MemorySpec<T>) : Buffer<T> {
@ -17,7 +17,7 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
companion object { companion object {
fun <T : Any> create(spec: MemorySpec<T>, size: Int) = fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =
MemoryBuffer(Memory.allocate(size * spec.objectSize), spec) MemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
inline fun <T : Any> create( inline fun <T : Any> create(
@ -36,25 +36,25 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : MemoryBuffer<T>(memory, spec), class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : MemoryBuffer<T>(memory, spec),
MutableBuffer<T> { MutableBuffer<T> {
private val writer = memory.writer() private val writer: MemoryWriter = memory.writer()
override fun set(index: Int, value: T) = writer.write(spec, spec.objectSize * index, value) override fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value)
override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec) override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec)
companion object { companion object {
fun <T : Any> create(spec: MemorySpec<T>, size: Int) = fun <T : Any> create(spec: MemorySpec<T>, size: Int): MutableMemoryBuffer<T> =
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec) MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
inline fun <T : Any> create( inline fun <T : Any> create(
spec: MemorySpec<T>, spec: MemorySpec<T>,
size: Int, size: Int,
crossinline initializer: (Int) -> T crossinline initializer: (Int) -> T
) = ): MutableMemoryBuffer<T> =
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer -> MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
(0 until size).forEach { (0 until size).forEach {
buffer[it] = initializer(it) buffer[it] = initializer(it)
} }
} }
} }
} }

View File

@ -56,7 +56,7 @@ interface NDAlgebra<T, C, N : NDStructure<T>> {
/** /**
* element-by-element invoke a function working on [T] on a [NDStructure] * element-by-element invoke a function working on [T] on a [NDStructure]
*/ */
operator fun Function1<T, T>.invoke(structure: N) = map(structure) { value -> this@invoke(value) } operator fun Function1<T, T>.invoke(structure: N): N = map(structure) { value -> this@invoke(value) }
companion object companion object
} }
@ -76,12 +76,12 @@ interface NDSpace<T, S : Space<T>, N : NDStructure<T>> : Space<N>, NDAlgebra<T,
override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) } override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) }
//TODO move to extensions after KEEP-176 //TODO move to extensions after KEEP-176
operator fun N.plus(arg: T) = map(this) { value -> add(arg, value) } operator fun N.plus(arg: T): N = map(this) { value -> add(arg, value) }
operator fun N.minus(arg: T) = map(this) { value -> add(arg, -value) } operator fun N.minus(arg: T): N = map(this) { value -> add(arg, -value) }
operator fun T.plus(arg: N) = map(arg) { value -> add(this@plus, value) } operator fun T.plus(arg: N): N = map(arg) { value -> add(this@plus, value) }
operator fun T.minus(arg: N) = map(arg) { value -> add(-this@minus, value) } operator fun T.minus(arg: N): N = map(arg) { value -> add(-this@minus, value) }
companion object companion object
} }
@ -97,20 +97,18 @@ interface NDRing<T, R : Ring<T>, N : NDStructure<T>> : Ring<N>, NDSpace<T, R, N>
override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) } override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
//TODO move to extensions after KEEP-176 //TODO move to extensions after KEEP-176
operator fun N.times(arg: T) = map(this) { value -> multiply(arg, value) } operator fun N.times(arg: T): N = map(this) { value -> multiply(arg, value) }
operator fun T.times(arg: N) = map(arg) { value -> multiply(this@times, value) } operator fun T.times(arg: N): N = map(arg) { value -> multiply(this@times, value) }
companion object companion object
} }
/** /**
* Field for n-dimensional structures. * Field for n-dimensional structures.
* @param shape - the list of dimensions of the array *
* @param elementField - operations field defined on individual array element
* @param T - the type of the element contained in ND structure * @param T - the type of the element contained in ND structure
* @param F - field of structure elements * @param F - field of structure elements
* @param R - actual nd-element type of this field
*/ */
interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F, N> { interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F, N> {
@ -120,9 +118,9 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) } override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
//TODO move to extensions after KEEP-176 //TODO move to extensions after KEEP-176
operator fun N.div(arg: T) = map(this) { value -> divide(arg, value) } operator fun N.div(arg: T): N = map(this) { value -> divide(arg, value) }
operator fun T.div(arg: N) = map(arg) { divide(it, this@div) } operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) }
companion object { companion object {
@ -131,7 +129,7 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
/** /**
* Create a nd-field for [Double] values or pull it from cache if it was created previously * Create a nd-field for [Double] values or pull it from cache if it was created previously
*/ */
fun real(vararg shape: Int) = realNDFieldCache.getOrPut(shape) { RealNDField(shape) } fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
/** /**
* Create a nd-field with boxing generic buffer * Create a nd-field with boxing generic buffer
@ -140,7 +138,7 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
field: F, field: F,
vararg shape: Int, vararg shape: Int,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
) = BoxingNDField(shape, field, bufferFactory) ): BoxingNDField<T, F> = BoxingNDField(shape, field, bufferFactory)
/** /**
* Create a most suitable implementation for nd-field using reified class. * Create a most suitable implementation for nd-field using reified class.

View File

@ -23,19 +23,23 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
/** /**
* Create a optimized NDArray of doubles * Create a optimized NDArray of doubles
*/ */
fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }) = fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement =
NDField.real(*shape).produce(initializer) NDField.real(*shape).produce(initializer)
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }) = fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement =
real(intArrayOf(dim)) { initializer(it[0]) } real(intArrayOf(dim)) { initializer(it[0]) }
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }) = fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): RealNDElement =
real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }) = fun real3D(
real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } dim1: Int,
dim2: Int,
dim3: Int,
initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }
): RealNDElement = real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
/** /**
@ -62,16 +66,17 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
} }
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index: IntArray, T) -> T) = fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index: IntArray, T) -> T): NDElement<T, C, N> =
context.mapIndexed(unwrap(), transform).wrap() context.mapIndexed(unwrap(), transform).wrap()
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T) = context.map(unwrap(), transform).wrap() fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T): NDElement<T, C, N> =
context.map(unwrap(), transform).wrap()
/** /**
* Element by element application of any operation on elements to the whole [NDElement] * Element by element application of any operation on elements to the whole [NDElement]
*/ */
operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElement<T, C, N>) = operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElement<T, C, N>): NDElement<T, C, N> =
ndElement.map { value -> this@invoke(value) } ndElement.map { value -> this@invoke(value) }
/* plus and minus */ /* plus and minus */
@ -79,13 +84,13 @@ operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElem
/** /**
* Summation operation for [NDElement] and single element * Summation operation for [NDElement] and single element
*/ */
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.plus(arg: T) = operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.plus(arg: T): NDElement<T, S, N> =
map { value -> arg + value } map { value -> arg + value }
/** /**
* Subtraction operation between [NDElement] and single element * Subtraction operation between [NDElement] and single element
*/ */
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg: T) = operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg: T): NDElement<T, S, N> =
map { value -> arg - value } map { value -> arg - value }
/* prod and div */ /* prod and div */
@ -93,13 +98,13 @@ operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg:
/** /**
* Product operation for [NDElement] and single element * Product operation for [NDElement] and single element
*/ */
operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg: T) = operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg: T): NDElement<T, R, N> =
map { value -> arg * value } map { value -> arg * value }
/** /**
* Division operation between [NDElement] and single element * Division operation between [NDElement] and single element
*/ */
operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T) = operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T): NDElement<T, F, N> =
map { value -> arg / value } map { value -> arg / value }

View File

@ -8,7 +8,7 @@ interface NDStructure<T> {
val shape: IntArray val shape: IntArray
val dimension get() = shape.size val dimension: Int get() = shape.size
operator fun get(index: IntArray): T operator fun get(index: IntArray): T
@ -44,32 +44,49 @@ interface NDStructure<T> {
strides: Strides, strides: Strides,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
initializer: (IntArray) -> T initializer: (IntArray) -> T
) = ): BufferNDStructure<T> =
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
/** /**
* Inline create NDStructure with non-boxing buffer implementation if it is possible * Inline create NDStructure with non-boxing buffer implementation if it is possible
*/ */
inline fun <reified T : Any> auto(strides: Strides, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> auto(
strides: Strides,
crossinline initializer: (IntArray) -> T
): BufferNDStructure<T> =
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
inline fun <T : Any> auto(type: KClass<T>, strides: Strides, crossinline initializer: (IntArray) -> T) = inline fun <T : Any> auto(
type: KClass<T>,
strides: Strides,
crossinline initializer: (IntArray) -> T
): BufferNDStructure<T> =
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
fun <T> build( fun <T> build(
shape: IntArray, shape: IntArray,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
initializer: (IntArray) -> T initializer: (IntArray) -> T
) = build(DefaultStrides(shape), bufferFactory, initializer) ): BufferNDStructure<T> = build(DefaultStrides(shape), bufferFactory, initializer)
inline fun <reified T : Any> auto(shape: IntArray, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> auto(
shape: IntArray,
crossinline initializer: (IntArray) -> T
): BufferNDStructure<T> =
auto(DefaultStrides(shape), initializer) auto(DefaultStrides(shape), initializer)
@JvmName("autoVarArg") @JvmName("autoVarArg")
inline fun <reified T : Any> auto(vararg shape: Int, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> auto(
vararg shape: Int,
crossinline initializer: (IntArray) -> T
): BufferNDStructure<T> =
auto(DefaultStrides(shape), initializer) auto(DefaultStrides(shape), initializer)
inline fun <T : Any> auto(type: KClass<T>, vararg shape: Int, crossinline initializer: (IntArray) -> T) = inline fun <T : Any> auto(
type: KClass<T>,
vararg shape: Int,
crossinline initializer: (IntArray) -> T
): BufferNDStructure<T> =
auto(type, DefaultStrides(shape), initializer) auto(type, DefaultStrides(shape), initializer)
} }
} }
@ -128,7 +145,7 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
/** /**
* Strides for memory access * Strides for memory access
*/ */
override val strides by lazy { override val strides: List<Int> by lazy {
sequence { sequence {
var current = 1 var current = 1
yield(1) yield(1)
@ -238,7 +255,7 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
} }
/** /**
* Mutable ND buffer based on linear [autoBuffer] * Mutable ND buffer based on linear [MutableBuffer].
*/ */
class MutableBufferNDStructure<T>( class MutableBufferNDStructure<T>(
override val strides: Strides, override val strides: Strides,
@ -251,7 +268,7 @@ class MutableBufferNDStructure<T>(
} }
} }
override fun set(index: IntArray, value: T) = buffer.set(strides.offset(index), value) override fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value)
} }
inline fun <reified T : Any> NDStructure<T>.combine( inline fun <reified T : Any> NDStructure<T>.combine(
@ -260,4 +277,4 @@ inline fun <reified T : Any> NDStructure<T>.combine(
): NDStructure<T> { ): NDStructure<T> {
if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination") if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination")
return NDStructure.auto(shape) { block(this[it], struct[it]) } return NDStructure.auto(shape) { block(this[it], struct[it]) }
} }

View File

@ -9,7 +9,7 @@ inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
array[index] = value array[index] = value
} }
override fun iterator() = array.iterator() override fun iterator(): DoubleIterator = array.iterator()
override fun copy(): MutableBuffer<Double> = override fun copy(): MutableBuffer<Double> =
RealBuffer(array.copyOf()) RealBuffer(array.copyOf())
@ -31,4 +31,4 @@ val MutableBuffer<out Double>.array: DoubleArray
DoubleArray(size) { get(it) } DoubleArray(size) { get(it) }
} }
fun DoubleArray.asBuffer() = RealBuffer(this) fun DoubleArray.asBuffer(): RealBuffer = RealBuffer(this)

View File

@ -12,8 +12,8 @@ class RealNDField(override val shape: IntArray) :
override val strides: Strides = DefaultStrides(shape) override val strides: Strides = DefaultStrides(shape)
override val elementContext: RealField get() = RealField override val elementContext: RealField get() = RealField
override val zero by lazy { produce { zero } } override val zero: RealNDElement by lazy { produce { zero } }
override val one by lazy { produce { one } } override val one: RealNDElement by lazy { produce { one } }
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> = inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
RealBuffer(DoubleArray(size) { initializer(it) }) RealBuffer(DoubleArray(size) { initializer(it) })
@ -64,15 +64,15 @@ class RealNDField(override val shape: IntArray) :
override fun NDBuffer<Double>.toElement(): FieldElement<NDBuffer<Double>, *, out BufferedNDField<Double, RealField>> = override fun NDBuffer<Double>.toElement(): FieldElement<NDBuffer<Double>, *, out BufferedNDField<Double, RealField>> =
BufferedNDFieldElement(this@RealNDField, buffer) BufferedNDFieldElement(this@RealNDField, buffer)
override fun power(arg: NDBuffer<Double>, pow: Number) = map(arg) { power(it, pow) } override fun power(arg: NDBuffer<Double>, pow: Number): RealNDElement = map(arg) { power(it, pow) }
override fun exp(arg: NDBuffer<Double>) = map(arg) { exp(it) } override fun exp(arg: NDBuffer<Double>): RealNDElement = map(arg) { exp(it) }
override fun ln(arg: NDBuffer<Double>) = map(arg) { ln(it) } override fun ln(arg: NDBuffer<Double>): RealNDElement = map(arg) { ln(it) }
override fun sin(arg: NDBuffer<Double>) = map(arg) { sin(it) } override fun sin(arg: NDBuffer<Double>): RealNDElement = map(arg) { sin(it) }
override fun cos(arg: NDBuffer<Double>) = map(arg) { cos(it) } override fun cos(arg: NDBuffer<Double>): RealNDElement = map(arg) { cos(it) }
override fun tan(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { tan(it) } override fun tan(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { tan(it) }
@ -95,7 +95,7 @@ inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initiali
/** /**
* Map one [RealNDElement] using function with indexes * Map one [RealNDElement] using function with indexes
*/ */
inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: IntArray, Double) -> Double) = inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: IntArray, Double) -> Double): RealNDElement =
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) } context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
/** /**
@ -107,9 +107,9 @@ inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double
} }
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array. Just like in numpy.
*/ */
operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) = operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement): RealNDElement =
ndElement.map { this@invoke(it) } ndElement.map { this@invoke(it) }
@ -118,13 +118,13 @@ operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) =
/** /**
* Summation operation for [BufferedNDElement] and single element * Summation operation for [BufferedNDElement] and single element
*/ */
operator fun RealNDElement.plus(arg: Double) = operator fun RealNDElement.plus(arg: Double): RealNDElement =
map { it + arg } map { it + arg }
/** /**
* Subtraction operation between [BufferedNDElement] and single element * Subtraction operation between [BufferedNDElement] and single element
*/ */
operator fun RealNDElement.minus(arg: Double) = operator fun RealNDElement.minus(arg: Double): RealNDElement =
map { it - arg } map { it - arg }
/** /**
@ -132,4 +132,4 @@ operator fun RealNDElement.minus(arg: Double) =
*/ */
inline fun <R> RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R { inline fun <R> RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R {
return NDField.real(*shape).run(action) return NDField.real(*shape).run(action)
} }

View File

@ -9,12 +9,11 @@ inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
array[index] = value array[index] = value
} }
override fun iterator() = array.iterator() override fun iterator(): ShortIterator = array.iterator()
override fun copy(): MutableBuffer<Short> = override fun copy(): MutableBuffer<Short> =
ShortBuffer(array.copyOf()) ShortBuffer(array.copyOf())
} }
fun ShortArray.asBuffer(): ShortBuffer = ShortBuffer(this)
fun ShortArray.asBuffer() = ShortBuffer(this)

View File

@ -12,8 +12,8 @@ class ShortNDRing(override val shape: IntArray) :
override val strides: Strides = DefaultStrides(shape) override val strides: Strides = DefaultStrides(shape)
override val elementContext: ShortRing get() = ShortRing override val elementContext: ShortRing get() = ShortRing
override val zero by lazy { produce { ShortRing.zero } } override val zero: ShortNDElement by lazy { produce { zero } }
override val one by lazy { produce { ShortRing.one } } override val one: ShortNDElement by lazy { produce { one } }
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer<Short> = inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer<Short> =
ShortBuffer(ShortArray(size) { initializer(it) }) ShortBuffer(ShortArray(size) { initializer(it) })
@ -40,6 +40,7 @@ class ShortNDRing(override val shape: IntArray) :
transform: ShortRing.(index: IntArray, Short) -> Short transform: ShortRing.(index: IntArray, Short) -> Short
): ShortNDElement { ): ShortNDElement {
check(arg) check(arg)
return BufferedNDRingElement( return BufferedNDRingElement(
this, this,
buildBuffer(arg.strides.linearSize) { offset -> buildBuffer(arg.strides.linearSize) { offset ->
@ -67,7 +68,7 @@ class ShortNDRing(override val shape: IntArray) :
/** /**
* Fast element production using function inlining * Fast element production using function inlining.
*/ */
inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): ShortNDElement { inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): ShortNDElement {
val array = ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) } val array = ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }
@ -75,22 +76,22 @@ inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initialize
} }
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array.
*/ */
operator fun Function1<Short, Short>.invoke(ndElement: ShortNDElement) = operator fun Function1<Short, Short>.invoke(ndElement: ShortNDElement): ShortNDElement =
ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) } ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) }
/* plus and minus */ /* plus and minus */
/** /**
* Summation operation for [StridedNDFieldElement] and single element * Summation operation for [ShortNDElement] and single element.
*/ */
operator fun ShortNDElement.plus(arg: Short) = operator fun ShortNDElement.plus(arg: Short): ShortNDElement =
context.produceInline { i -> (buffer[i] + arg).toShort() } context.produceInline { i -> (buffer[i] + arg).toShort() }
/** /**
* Subtraction operation between [StridedNDFieldElement] and single element * Subtraction operation between [ShortNDElement] and single element.
*/ */
operator fun ShortNDElement.minus(arg: Short) = operator fun ShortNDElement.minus(arg: Short): ShortNDElement =
context.produceInline { i -> (buffer[i] - arg).toShort() } context.produceInline { i -> (buffer[i] - arg).toShort() }

View File

@ -17,7 +17,7 @@ interface Structure1D<T> : NDStructure<T>, Buffer<T> {
/** /**
* A 1D wrapper for nd-structure * A 1D wrapper for nd-structure
*/ */
private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T>{ private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T> {
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
override val size: Int get() = structure.shape[0] override val size: Int get() = structure.shape[0]
@ -39,14 +39,14 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
override fun elements(): Sequence<Pair<IntArray, T>> = override fun elements(): Sequence<Pair<IntArray, T>> =
asSequence().mapIndexed { index, value -> intArrayOf(index) to value } asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
override fun get(index: Int): T = buffer.get(index) override fun get(index: Int): T = buffer[index]
} }
/** /**
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch * Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
*/ */
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) { fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
if( this is NDBuffer){ if (this is NDBuffer) {
Buffer1DWrapper(this.buffer) Buffer1DWrapper(this.buffer)
} else { } else {
Structure1DWrapper(this) Structure1DWrapper(this)
@ -59,4 +59,4 @@ fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
/** /**
* Represent this buffer as 1D structure * Represent this buffer as 1D structure
*/ */
fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this) fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)

View File

@ -32,9 +32,7 @@ interface Structure2D<T> : NDStructure<T> {
} }
} }
companion object { companion object
}
} }
/** /**
@ -57,4 +55,4 @@ fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
error("Can't create 2d-structure from ${shape.size}d-structure") error("Can't create 2d-structure from ${shape.size}d-structure")
} }
typealias Matrix<T> = Structure2D<T> typealias Matrix<T> = Structure2D<T>

View File

@ -31,7 +31,7 @@ class ExpressionFieldTest {
@Test @Test
fun separateContext() { fun separateContext() {
fun <T> FunctionalExpressionField<T,*>.expression(): Expression<T> { fun <T> FunctionalExpressionField<T, *>.expression(): Expression<T> {
val x = variable("x") val x = variable("x")
return x * x + 2 * x + one return x * x + 2 * x + one
} }
@ -42,7 +42,7 @@ class ExpressionFieldTest {
@Test @Test
fun valueExpression() { fun valueExpression() {
val expressionBuilder: FunctionalExpressionField<Double,*>.() -> Expression<Double> = { val expressionBuilder: FunctionalExpressionField<Double, *>.() -> Expression<Double> = {
val x = variable("x") val x = variable("x")
x * x + 2 * x + one x * x + 2 * x + one
} }
@ -50,4 +50,4 @@ class ExpressionFieldTest {
val expression = FunctionalExpressionField(RealField).expressionBuilder() val expression = FunctionalExpressionField(RealField).expressionBuilder()
assertEquals(expression("x" to 1.0), 4.0) assertEquals(expression("x" to 1.0), 4.0)
} }
} }

View File

@ -49,17 +49,17 @@ class MatrixTest {
@Test @Test
fun test2DDot() { fun test2DDot() {
val firstMatrix = NDStructure.auto(2,3){ (i, j) -> (i + j).toDouble() }.as2D() val firstMatrix = NDStructure.auto(2, 3) { (i, j) -> (i + j).toDouble() }.as2D()
val secondMatrix = NDStructure.auto(3,2){ (i, j) -> (i + j).toDouble() }.as2D() val secondMatrix = NDStructure.auto(3, 2) { (i, j) -> (i + j).toDouble() }.as2D()
MatrixContext.real.run { MatrixContext.real.run {
// val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() } // val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() }
// val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() } // val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() }
val result = firstMatrix dot secondMatrix val result = firstMatrix dot secondMatrix
assertEquals(2, result.rowNum) assertEquals(2, result.rowNum)
assertEquals(2, result.colNum) assertEquals(2, result.colNum)
assertEquals(8.0, result[0,1]) assertEquals(8.0, result[0, 1])
assertEquals(8.0, result[1,0]) assertEquals(8.0, result[1, 0])
assertEquals(14.0, result[1,1]) assertEquals(14.0, result[1, 1])
} }
} }
} }

View File

@ -48,4 +48,4 @@ class RealLUSolverTest {
assertEquals(expected, inverted) assertEquals(expected, inverted)
} }
} }

View File

@ -8,10 +8,10 @@ import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
class AutoDiffTest { class AutoDiffTest {
fun Variable(int: Int): Variable<Double> = Variable(int.toDouble())
fun Variable(int: Int) = Variable(int.toDouble()) fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>): DerivationResult<Double> =
RealField.deriv(body)
fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>) = RealField.deriv(body)
@Test @Test
fun testPlusX2() { fun testPlusX2() {
@ -178,5 +178,4 @@ class AutoDiffTest {
private fun assertApprox(a: Double, b: Double) { private fun assertApprox(a: Double, b: Double) {
if ((a - b) > 1e-10) assertEquals(a, b) if ((a - b) > 1e-10) assertEquals(a, b)
} }
}
}

View File

@ -10,4 +10,4 @@ class CumulativeKtTest {
val cumulative = initial.cumulativeSum() val cumulative = initial.cumulativeSum()
assertEquals(listOf(-1.0, 1.0, 2.0, 3.0), cumulative) assertEquals(listOf(-1.0, 1.0, 2.0, 3.0), cumulative)
} }
} }

View File

@ -47,4 +47,3 @@ class BigIntAlgebraTest {
} }
} }

View File

@ -19,8 +19,8 @@ class BigIntConstructorTest {
@Test @Test
fun testConstructor_0xffffffffaL() { fun testConstructor_0xffffffffaL() {
val x = -0xffffffffaL.toBigInt() val x = (-0xffffffffaL).toBigInt()
val y = uintArrayOf(0xfffffffaU, 0xfU).toBigInt(-1) val y = uintArrayOf(0xfffffffaU, 0xfU).toBigInt(-1)
assertEquals(x, y) assertEquals(x, y)
} }
} }

View File

@ -19,7 +19,7 @@ class BigIntConversionsTest {
@Test @Test
fun testToString_0x17ead2ffffd() { fun testToString_0x17ead2ffffd() {
val x = -0x17ead2ffffdL.toBigInt() val x = (-0x17ead2ffffdL).toBigInt()
assertEquals("-0x17ead2ffffd", x.toString()) assertEquals("-0x17ead2ffffd", x.toString())
} }
@ -40,4 +40,4 @@ class BigIntConversionsTest {
val x = "-7059135710711894913860".parseBigInteger() val x = "-7059135710711894913860".parseBigInteger()
assertEquals("-0x17ead2ffffd11223344", x.toString()) assertEquals("-0x17ead2ffffd11223344", x.toString())
} }
} }

View File

@ -31,7 +31,7 @@ class BigIntOperationsTest {
@Test @Test
fun testUnaryMinus() { fun testUnaryMinus() {
val x = 1234.toBigInt() val x = 1234.toBigInt()
val y = -1234.toBigInt() val y = (-1234).toBigInt()
assertEquals(-x, y) assertEquals(-x, y)
} }
@ -48,18 +48,18 @@ class BigIntOperationsTest {
@Test @Test
fun testMinus__2_1() { fun testMinus__2_1() {
val x = -2.toBigInt() val x = (-2).toBigInt()
val y = 1.toBigInt() val y = 1.toBigInt()
val res = x - y val res = x - y
val sum = -3.toBigInt() val sum = (-3).toBigInt()
assertEquals(sum, res) assertEquals(sum, res)
} }
@Test @Test
fun testMinus___2_1() { fun testMinus___2_1() {
val x = -2.toBigInt() val x = (-2).toBigInt()
val y = 1.toBigInt() val y = 1.toBigInt()
val res = -x - y val res = -x - y
@ -74,7 +74,7 @@ class BigIntOperationsTest {
val y = 0xffffffffaL.toBigInt() val y = 0xffffffffaL.toBigInt()
val res = x - y val res = x - y
val sum = -0xfffffcfc1L.toBigInt() val sum = (-0xfffffcfc1L).toBigInt()
assertEquals(sum, res) assertEquals(sum, res)
} }
@ -92,11 +92,11 @@ class BigIntOperationsTest {
@Test @Test
fun testMultiply__2_3() { fun testMultiply__2_3() {
val x = -2.toBigInt() val x = (-2).toBigInt()
val y = 3.toBigInt() val y = 3.toBigInt()
val res = x * y val res = x * y
val prod = -6.toBigInt() val prod = (-6).toBigInt()
assertEquals(prod, res) assertEquals(prod, res)
} }
@ -129,7 +129,7 @@ class BigIntOperationsTest {
val y = -0xfff456 val y = -0xfff456
val res = x * y val res = x * y
val prod = -0xffe579ad5dc2L.toBigInt() val prod = (-0xffe579ad5dc2L).toBigInt()
assertEquals(prod, res) assertEquals(prod, res)
} }
@ -259,7 +259,7 @@ class BigIntOperationsTest {
val y = -3 val y = -3
val res = x / y val res = x / y
val div = -6.toBigInt() val div = (-6).toBigInt()
assertEquals(div, res) assertEquals(div, res)
} }
@ -267,10 +267,10 @@ class BigIntOperationsTest {
@Test @Test
fun testBigDivision_20__3() { fun testBigDivision_20__3() {
val x = 20.toBigInt() val x = 20.toBigInt()
val y = -3.toBigInt() val y = (-3).toBigInt()
val res = x / y val res = x / y
val div = -6.toBigInt() val div = (-6).toBigInt()
assertEquals(div, res) assertEquals(div, res)
} }
@ -378,4 +378,4 @@ class BigIntOperationsTest {
return assertEquals(res, x % mod) return assertEquals(res, x % mod)
} }
} }

View File

@ -11,4 +11,4 @@ class RealFieldTest {
} }
assertEquals(5.0, sqrt) assertEquals(5.0, sqrt)
} }
} }

View File

@ -11,4 +11,4 @@ class ComplexBufferSpecTest {
val buffer = Buffer.complex(20) { Complex(it.toDouble(), -it.toDouble()) } val buffer = Buffer.complex(20) { Complex(it.toDouble(), -it.toDouble()) }
assertEquals(Complex(5.0, -5.0), buffer[5]) assertEquals(Complex(5.0, -5.0), buffer[5])
} }
} }

View File

@ -10,4 +10,4 @@ class NDFieldTest {
val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() } val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() }
assertEquals(ndArray[5, 5], 10.0) assertEquals(ndArray[5, 5], 10.0)
} }
} }

View File

@ -8,8 +8,8 @@ import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class NumberNDFieldTest { class NumberNDFieldTest {
val array1 = real2D(3, 3) { i, j -> (i + j).toDouble() } val array1: RealNDElement = real2D(3, 3) { i, j -> (i + j).toDouble() }
val array2 = real2D(3, 3) { i, j -> (i - j).toDouble() } val array2: RealNDElement = real2D(3, 3) { i, j -> (i - j).toDouble() }
@Test @Test
fun testSum() { fun testSum() {