Merge pull request #380 from mipt-npm/commandertvis/contracts

Add contracts to some functions, fix multiple style issues
This commit is contained in:
Iaroslav Postovalov 2021-07-13 23:12:48 +07:00 committed by GitHub
commit ecd70f2139
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
73 changed files with 210 additions and 200 deletions

View File

@ -13,6 +13,8 @@ import space.kscience.kmath.jafama.JafamaDoubleField
import space.kscience.kmath.jafama.StrictJafamaDoubleField
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.random.Random
@State(Scope.Benchmark)
@ -31,9 +33,10 @@ internal class JafamaBenchmark {
fun strictJafama(blackhole: Blackhole) = invokeBenchmarks(blackhole) { x ->
StrictJafamaDoubleField { x * power(x, 4) * exp(x) / cos(x) + sin(x) }
}
}
private inline fun invokeBenchmarks(blackhole: Blackhole, expr: (Double) -> Double) {
contract { callsInPlace(expr, InvocationKind.AT_LEAST_ONCE) }
val rng = Random(0)
repeat(1000000) { blackhole.consume(expr(rng.nextDouble())) }
}
}

View File

@ -40,14 +40,14 @@ internal class MatrixInverseBenchmark {
@Benchmark
fun cmLUPInversion(blackhole: Blackhole) {
with(CMLinearSpace) {
CMLinearSpace {
blackhole.consume(inverse(matrix))
}
}
@Benchmark
fun ejmlInverse(blackhole: Blackhole) {
with(EjmlLinearSpaceDDRM) {
EjmlLinearSpaceDDRM {
blackhole.consume(matrix.getFeature<InverseMatrixFeature<Double>>()?.inverse)
}
}

View File

@ -115,6 +115,8 @@ via extension function.
Usually it is bad idea to compare the direct numerical operation performance in different languages, but it hard to
work completely without frame of reference. In this case, simple numpy code:
```python
import numpy as np
res = np.ones((1000,1000))
for i in range(1000):
res = res + 1.0

View File

@ -10,7 +10,7 @@ import space.kscience.kmath.ast.rendering.LatexSyntaxRenderer
import space.kscience.kmath.ast.rendering.MathMLSyntaxRenderer
import space.kscience.kmath.ast.rendering.renderWithStringBuilder
public fun main() {
fun main() {
val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(-12)".parseMath()
val syntax = FeaturedMathRendererWithPostProcess.Default.render(mst)
println("MathSyntax:")

View File

@ -13,7 +13,7 @@ import space.kscience.kmath.operations.*
import kotlin.reflect.KClass
/**
* Prints any [Symbol] as a [SymbolSyntax] containing the [Symbol.value] of it.
* Prints any [Symbol] as a [SymbolSyntax] containing the [Symbol.identity] of it.
*
* @author Iaroslav Postovalov
*/

View File

@ -11,7 +11,6 @@ import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals

View File

@ -9,7 +9,6 @@ import space.kscience.kmath.expressions.MstRing
import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals

View File

@ -3,6 +3,8 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
@file:Suppress("ClassName")
package space.kscience.kmath.internal.estree
import kotlin.js.RegExp

View File

@ -10,6 +10,8 @@ import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import space.kscience.kmath.estree.compile as estreeCompile
import space.kscience.kmath.estree.compileToExpression as estreeCompileToExpression
import space.kscience.kmath.wasm.compile as wasmCompile
@ -34,6 +36,7 @@ private object ESTreeCompilerTestContext : CompilerTestContext {
}
internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) {
contract { callsInPlace(action, InvocationKind.AT_LEAST_ONCE) }
action(WasmCompilerTestContext)
action(ESTreeCompilerTestContext)
}

View File

@ -11,7 +11,6 @@ import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.expressions.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals

View File

@ -10,6 +10,8 @@ import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import space.kscience.kmath.asm.compile as asmCompile
import space.kscience.kmath.asm.compileToExpression as asmCompileToExpression
@ -22,4 +24,7 @@ private object AsmCompilerTestContext : CompilerTestContext {
asmCompile(algebra, arguments)
}
internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) = action(AsmCompilerTestContext)
internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
action(AsmCompilerTestContext)
}

View File

@ -15,7 +15,7 @@ import space.kscience.kmath.operations.NumbersAddOperations
* A field over commons-math [DerivativeStructure].
*
* @property order The derivation order.
* @property bindings The map of bindings values. All bindings are considered free parameters
* @param bindings The map of bindings values. All bindings are considered free parameters
*/
@OptIn(UnstableKMathAPI::class)
public class DerivativeStructureField(

View File

@ -52,11 +52,11 @@ public class CMOptimization(
public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
public override fun initialGuess(map: Map<Symbol, Double>): Unit {
public override fun initialGuess(map: Map<Symbol, Double>) {
addOptimizationData(InitialGuess(map.toDoubleArray()))
}
public override fun function(expression: Expression<Double>): Unit {
public override fun function(expression: Expression<Double>) {
val objectiveFunction = ObjectiveFunction {
val args = it.toMap()
expression(args)

View File

@ -32,7 +32,7 @@ public object Transformations {
/**
* Create a virtual buffer on top of array
*/
private fun Array<org.apache.commons.math3.complex.Complex>.asBuffer() = VirtualBuffer<Complex>(size) {
private fun Array<org.apache.commons.math3.complex.Complex>.asBuffer() = VirtualBuffer(size) {
val value = get(it)
Complex(value.real, value.imaginary)
}

View File

@ -16,7 +16,7 @@ internal inline fun diff(
order: Int,
vararg parameters: Pair<Symbol, Double>,
block: DerivativeStructureField.() -> Unit,
): Unit {
) {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
DerivativeStructureField(order, mapOf(*parameters)).run(block)
}

View File

@ -147,7 +147,7 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
return if (arg.w > 0)
Quaternion(ln(arg.w), 0, 0, 0)
else {
val l = ComplexField { ComplexField.ln(arg.w.toComplex()) }
val l = ComplexField { ln(arg.w.toComplex()) }
Quaternion(l.re, l.im, 0, 0)
}

View File

@ -16,7 +16,7 @@ import kotlin.math.max
* The buffer of X values.
*/
@UnstableKMathAPI
public interface XYColumnarData<T, out X : T, out Y : T> : ColumnarData<T> {
public interface XYColumnarData<out T, out X : T, out Y : T> : ColumnarData<T> {
/**
* The buffer of X values
*/

View File

@ -14,7 +14,7 @@ import space.kscience.kmath.structures.Buffer
* Inherits [XYColumnarData].
*/
@UnstableKMathAPI
public interface XYZColumnarData<T, out X : T, out Y : T, out Z : T> : XYColumnarData<T, X, Y> {
public interface XYZColumnarData<out T, out X : T, out Y : T, out Z : T> : XYColumnarData<T, X, Y> {
public val z: Buffer<Z>
override fun get(symbol: Symbol): Buffer<T>? = when (symbol) {

View File

@ -12,7 +12,7 @@ import space.kscience.kmath.linear.Point
*
* @param T the type of element of this domain.
*/
public interface Domain<T : Any> {
public interface Domain<in T : Any> {
/**
* Checks if the specified point is contained in this domain.
*/

View File

@ -9,7 +9,6 @@ package space.kscience.kmath.expressions
* Represents expression which structure can be differentiated.
*
* @param T the type this expression takes as argument and returns.
* @param R the type of expression this expression can be differentiated to.
*/
public interface DifferentiableExpression<T> : Expression<T> {
/**
@ -31,9 +30,11 @@ public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression
derivative(StringSymbol(name))
/**
* A special type of [DifferentiableExpression] which returns typed expressions as derivatives
* A special type of [DifferentiableExpression] which returns typed expressions as derivatives.
*
* @param R the type of expression this expression can be differentiated to.
*/
public interface SpecialDifferentiableExpression<T, R: Expression<T>>: DifferentiableExpression<T> {
public interface SpecialDifferentiableExpression<T, out R : Expression<T>> : DifferentiableExpression<T> {
override fun derivativeOrNull(symbols: List<Symbol>): R?
}
@ -64,6 +65,6 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
/**
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
*/
public fun interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>, out R : Expression<T>> {
public fun interface AutoDiffProcessor<T : Any, I : Any, out A : ExpressionAlgebra<T, I>, out R : Expression<T>> {
public fun process(function: A.() -> I): DifferentiableExpression<T>
}

View File

@ -6,13 +6,15 @@
package space.kscience.kmath.expressions
import space.kscience.kmath.operations.*
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/**
* A context class for [Expression] construction.
*
* @param algebra The algebra to provide for Expressions built.
*/
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
public abstract class FunctionalExpressionAlgebra<T, out A : Algebra<T>>(
public val algebra: A,
) : ExpressionAlgebra<T, Expression<T>> {
/**
@ -29,9 +31,6 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
?: error("Symbol '$value' is not supported in $this")
}
/**
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
*/
public override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
{ left, right ->
Expression { arguments ->
@ -39,9 +38,6 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
}
}
/**
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
*/
public override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
Expression { arguments -> algebra.unaryOperationFunction(operation)(arg.invoke(arguments)) }
}
@ -50,7 +46,7 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
/**
* A context class for [Expression] construction for [Ring] algebras.
*/
public open class FunctionalExpressionGroup<T, A : Group<T>>(
public open class FunctionalExpressionGroup<T, out A : Group<T>>(
algebra: A,
) : FunctionalExpressionAlgebra<T, A>(algebra), Group<Expression<T>> {
public override val zero: Expression<T> get() = const(algebra.zero)
@ -84,7 +80,7 @@ public open class FunctionalExpressionGroup<T, A : Group<T>>(
}
public open class FunctionalExpressionRing<T, A : Ring<T>>(
public open class FunctionalExpressionRing<T, out A : Ring<T>>(
algebra: A,
) : FunctionalExpressionGroup<T, A>(algebra), Ring<Expression<T>> {
public override val one: Expression<T> get() = const(algebra.one)
@ -105,7 +101,7 @@ public open class FunctionalExpressionRing<T, A : Ring<T>>(
super<FunctionalExpressionGroup>.binaryOperationFunction(operation)
}
public open class FunctionalExpressionField<T, A : Field<T>>(
public open class FunctionalExpressionField<T, out A : Field<T>>(
algebra: A,
) : FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>>, ScaleOperations<Expression<T>> {
/**
@ -131,7 +127,7 @@ public open class FunctionalExpressionField<T, A : Field<T>>(
super<FunctionalExpressionRing>.bindSymbolOrNull(value)
}
public open class FunctionalExpressionExtendedField<T, A : ExtendedField<T>>(
public open class FunctionalExpressionExtendedField<T, out A : ExtendedField<T>>(
algebra: A,
) : FunctionalExpressionField<T, A>(algebra), ExtendedField<Expression<T>> {
public override fun number(value: Number): Expression<T> = const(algebra.number(value))
@ -172,14 +168,26 @@ public open class FunctionalExpressionExtendedField<T, A : ExtendedField<T>>(
public override fun bindSymbol(value: String): Expression<T> = super<FunctionalExpressionField>.bindSymbol(value)
}
public inline fun <T, A : Ring<T>> A.expressionInSpace(block: FunctionalExpressionGroup<T, A>.() -> Expression<T>): Expression<T> =
FunctionalExpressionGroup(this).block()
public inline fun <T, A : Group<T>> A.expressionInGroup(
block: FunctionalExpressionGroup<T, A>.() -> Expression<T>,
): Expression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionGroup(this).block()
}
public inline fun <T, A : Ring<T>> A.expressionInRing(block: FunctionalExpressionRing<T, A>.() -> Expression<T>): Expression<T> =
FunctionalExpressionRing(this).block()
public inline fun <T, A : Ring<T>> A.expressionInRing(
block: FunctionalExpressionRing<T, A>.() -> Expression<T>,
): Expression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionRing(this).block()
}
public inline fun <T, A : Field<T>> A.expressionInField(block: FunctionalExpressionField<T, A>.() -> Expression<T>): Expression<T> =
FunctionalExpressionField(this).block()
public inline fun <T, A : Field<T>> A.expressionInField(
block: FunctionalExpressionField<T, A>.() -> Expression<T>,
): Expression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionField(this).block()
}
public inline fun <T, A : ExtendedField<T>> A.expressionInExtendedField(
block: FunctionalExpressionExtendedField<T, A>.() -> Expression<T>,

View File

@ -119,8 +119,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
get() = getDerivative(this)
set(value) = setDerivative(this, value)
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
/**
* Performs update of derivative after the rest of the formula in the back-pass.
*
@ -194,6 +192,11 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
}
}
public inline fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.const(block: F.() -> T): AutoDiffValue<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return const(context.block())
}
/**
* Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code.
@ -208,7 +211,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
* assertEquals(9.0, x.d) // dy/dx
* ```
*
* @param body the action in [SimpleAutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
* @param body the action in [SimpleAutoDiffField] context returning [AutoDiffValue] to differentiate with respect to.
* @return the result of differentiation.
*/
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(

View File

@ -9,6 +9,8 @@ import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.structures.BufferFactory
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.jvm.JvmInline
/**
@ -65,9 +67,13 @@ public value class SimpleSymbolIndexer(override val symbols: List<Symbol>) : Sym
* Execute the block with symbol indexer based on given symbol order
*/
@UnstableKMathAPI
public inline fun <R> withSymbols(vararg symbols: Symbol, block: SymbolIndexer.() -> R): R =
with(SimpleSymbolIndexer(symbols.toList()), block)
public inline fun <R> withSymbols(vararg symbols: Symbol, block: SymbolIndexer.() -> R): R {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return with(SimpleSymbolIndexer(symbols.toList()), block)
}
@UnstableKMathAPI
public inline fun <R> withSymbols(symbols: Collection<Symbol>, block: SymbolIndexer.() -> R): R =
with(SimpleSymbolIndexer(symbols.toList()), block)
public inline fun <R> withSymbols(symbols: Collection<Symbol>, block: SymbolIndexer.() -> R): R {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return with(SimpleSymbolIndexer(symbols.toList()), block)
}

View File

@ -15,7 +15,7 @@ import space.kscience.kmath.structures.VirtualBuffer
import space.kscience.kmath.structures.indices
public class BufferedLinearSpace<T : Any, A : Ring<T>>(
public class BufferedLinearSpace<T : Any, out A : Ring<T>>(
override val elementAlgebra: A,
private val bufferFactory: BufferFactory<T>,
) : LinearSpace<T, A> {

View File

@ -32,7 +32,7 @@ public typealias Point<T> = Buffer<T>
* Basic operations on matrices and vectors. Operates on [Matrix].
*
* @param T the type of items in the matrices.
* @param M the type of operated matrices.
* @param A the type of ring over [T].
*/
public interface LinearSpace<T : Any, out A : Ring<T>> {
public val elementAlgebra: A

View File

@ -35,7 +35,7 @@ public object UnitFeature : DiagonalFeature
*
* @param T the type of matrices' items.
*/
public interface InverseMatrixFeature<T : Any> : MatrixFeature {
public interface InverseMatrixFeature<out T : Any> : MatrixFeature {
/**
* The inverse matrix of the matrix that owns this feature.
*/
@ -47,7 +47,7 @@ public interface InverseMatrixFeature<T : Any> : MatrixFeature {
*
* @param T the type of matrices' items.
*/
public interface DeterminantFeature<T : Any> : MatrixFeature {
public interface DeterminantFeature<out T : Any> : MatrixFeature {
/**
* The determinant of the matrix that owns this feature.
*/
@ -80,7 +80,7 @@ public object UFeature : MatrixFeature
*
* @param T the type of matrices' items.
*/
public interface LUDecompositionFeature<T : Any> : MatrixFeature {
public interface LUDecompositionFeature<out T : Any> : MatrixFeature {
/**
* The lower triangular matrix in this decomposition. It may have [LFeature].
*/
@ -98,7 +98,7 @@ public interface LUDecompositionFeature<T : Any> : MatrixFeature {
*
* @param T the type of matrices' items.
*/
public interface LupDecompositionFeature<T : Any> : MatrixFeature {
public interface LupDecompositionFeature<out T : Any> : MatrixFeature {
/**
* The lower triangular matrix in this decomposition. It may have [LFeature].
*/
@ -126,7 +126,7 @@ public object OrthogonalFeature : MatrixFeature
*
* @param T the type of matrices' items.
*/
public interface QRDecompositionFeature<T : Any> : MatrixFeature {
public interface QRDecompositionFeature<out T : Any> : MatrixFeature {
/**
* The orthogonal matrix in this decomposition. It may have [OrthogonalFeature].
*/
@ -144,7 +144,7 @@ public interface QRDecompositionFeature<T : Any> : MatrixFeature {
*
* @param T the type of matrices' items.
*/
public interface CholeskyDecompositionFeature<T : Any> : MatrixFeature {
public interface CholeskyDecompositionFeature<out T : Any> : MatrixFeature {
/**
* The triangular matrix in this decomposition. It may have either [UFeature] or [LFeature].
*/
@ -157,7 +157,7 @@ public interface CholeskyDecompositionFeature<T : Any> : MatrixFeature {
*
* @param T the type of matrices' items.
*/
public interface SingularValueDecompositionFeature<T : Any> : MatrixFeature {
public interface SingularValueDecompositionFeature<out T : Any> : MatrixFeature {
/**
* The matrix in this decomposition. It is unitary, and it consists from left singular vectors.
*/

View File

@ -34,7 +34,7 @@ public inline fun <T, R> Iterable<T>.cumulative(initial: R, crossinline operatio
public inline fun <T, R> Sequence<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Sequence<R> =
Sequence { this@cumulative.iterator().cumulative(initial, operation) }
public fun <T, R> List<T>.cumulative(initial: R, operation: (R, T) -> R): List<R> =
public inline fun <T, R> List<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): List<R> =
iterator().cumulative(initial, operation).asSequence().toList()
//Cumulative sum

View File

@ -24,7 +24,6 @@ public class ShapeMismatchException(public val expected: IntArray, public val ac
*
* @param T the type of ND-structure element.
* @param C the type of the element context.
* @param N the type of the structure.
*/
public interface AlgebraND<T, out C : Algebra<T>> {
/**
@ -118,8 +117,7 @@ internal fun <T, C : Algebra<T>> AlgebraND<T, C>.checkShape(element: StructureND
* Space of [StructureND].
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param S the type of space of structure elements.
* @param S the type of group over structure elements.
*/
public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND<T, S> {
/**
@ -186,8 +184,7 @@ public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND
* Ring of [StructureND].
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param R the type of ring of structure elements.
* @param R the type of ring over structure elements.
*/
public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R> {
/**
@ -227,7 +224,7 @@ public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R
* Field of [StructureND].
*
* @param T the type of the element contained in ND structure.
* @param F the type field of structure elements.
* @param F the type field over structure elements.
*/
public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T, F>, ScaleOperations<StructureND<T>> {
/**

View File

@ -57,7 +57,7 @@ public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
}
}
public open class BufferedGroupND<T, A : Group<T>>(
public open class BufferedGroupND<T, out A : Group<T>>(
final override val shape: IntArray,
final override val elementContext: A,
final override val bufferFactory: BufferFactory<T>,
@ -67,7 +67,7 @@ public open class BufferedGroupND<T, A : Group<T>>(
override fun StructureND<T>.unaryMinus(): StructureND<T> = produce { -get(it) }
}
public open class BufferedRingND<T, R : Ring<T>>(
public open class BufferedRingND<T, out R : Ring<T>>(
shape: IntArray,
elementContext: R,
bufferFactory: BufferFactory<T>,
@ -75,7 +75,7 @@ public open class BufferedRingND<T, R : Ring<T>>(
override val one: BufferND<T> by lazy { produce { one } }
}
public open class BufferedFieldND<T, R : Field<T>>(
public open class BufferedFieldND<T, out R : Field<T>>(
shape: IntArray,
elementContext: R,
bufferFactory: BufferFactory<T>,

View File

@ -31,9 +31,8 @@ public class ShortRingND(
/**
* Fast element production using function inlining.
*/
public inline fun BufferedRingND<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): BufferND<Short> {
return BufferND(strides, ShortBuffer(ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }))
}
public inline fun BufferedRingND<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): BufferND<Short> =
BufferND(strides, ShortBuffer(ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }))
public inline fun <R> ShortRing.nd(vararg shape: Int, action: ShortRingND.() -> R): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }

View File

@ -53,7 +53,7 @@ public interface StructureND<out T> {
public fun elements(): Sequence<Pair<IntArray, T>>
/**
* Feature is some additional strucure information which allows to access it special properties or hints.
* Feature is some additional structure information which allows to access it special properties or hints.
* If the feature is not present, null is returned.
*/
@UnstableKMathAPI

View File

@ -15,7 +15,7 @@ import space.kscience.kmath.misc.UnstableKMathAPI
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public interface AlgebraElement<T, C : Algebra<T>> {
public interface AlgebraElement<T, out C : Algebra<T>> {
/**
* The context this element belongs to.
*/

View File

@ -445,7 +445,7 @@ public fun UIntArray.toBigInt(sign: Byte): BigInt {
* Returns null if a valid number can not be read from a string
*/
public fun String.parseBigInteger(): BigInt? {
if (this.isEmpty()) return null
if (isEmpty()) return null
val sign: Int
val positivePartIndex = when (this[0]) {

View File

@ -14,30 +14,30 @@ import space.kscience.kmath.nd.as2D
* A context that allows to operate on a [MutableBuffer] as on 2d array
*/
internal class BufferAccessor2D<T : Any>(
public val rowNum: Int,
public val colNum: Int,
val rowNum: Int,
val colNum: Int,
val factory: MutableBufferFactory<T>,
) {
public operator fun Buffer<T>.get(i: Int, j: Int): T = get(i * colNum + j)
operator fun Buffer<T>.get(i: Int, j: Int): T = get(i * colNum + j)
public operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
set(i * colNum + j, value)
}
public inline fun create(crossinline init: (i: Int, j: Int) -> T): MutableBuffer<T> =
inline fun create(crossinline init: (i: Int, j: Int) -> T): MutableBuffer<T> =
factory(rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
public fun create(mat: Structure2D<T>): MutableBuffer<T> = create { i, j -> mat[i, j] }
fun create(mat: Structure2D<T>): MutableBuffer<T> = create { i, j -> mat[i, j] }
//TODO optimize wrapper
public fun MutableBuffer<T>.collect(): Structure2D<T> = StructureND.buffered(
fun MutableBuffer<T>.collect(): Structure2D<T> = StructureND.buffered(
DefaultStrides(intArrayOf(rowNum, colNum)),
factory
) { (i, j) ->
get(i, j)
}.as2D()
public inner class Row(public val buffer: MutableBuffer<T>, public val rowIndex: Int) : MutableBuffer<T> {
inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> {
override val size: Int get() = colNum
override operator fun get(index: Int): T = buffer[rowIndex, index]
@ -54,5 +54,5 @@ internal class BufferAccessor2D<T : Any>(
/**
* Get row
*/
public fun MutableBuffer<T>.row(i: Int): Row = Row(this, i)
fun MutableBuffer<T>.row(i: Int): Row = Row(this, i)
}

View File

@ -67,9 +67,9 @@ public inline fun <T : Any, reified R : Any> Buffer<T>.map(block: (T) -> R): Buf
* Create a new buffer from this one with the given mapping function.
* Provided [bufferFactory] is used to construct the new buffer.
*/
public fun <T : Any, R : Any> Buffer<T>.map(
public inline fun <T : Any, R : Any> Buffer<T>.map(
bufferFactory: BufferFactory<R>,
block: (T) -> R,
crossinline block: (T) -> R,
): Buffer<R> = bufferFactory(size) { block(get(it)) }
/**

View File

@ -1,5 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/

View File

@ -10,7 +10,7 @@ import space.kscience.kmath.operations.invoke
import kotlin.test.assertEquals
import kotlin.test.assertNotEquals
internal class FieldVerifier<T, A : Field<T>>(
internal class FieldVerifier<T, out A : Field<T>>(
algebra: A, a: T, b: T, c: T, x: Number,
) : RingVerifier<T, A>(algebra, a, b, c, x) {

View File

@ -10,7 +10,7 @@ import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.operations.invoke
import kotlin.test.assertEquals
internal open class RingVerifier<T, A>(algebra: A, a: T, b: T, c: T, x: Number) :
internal open class RingVerifier<T, out A>(algebra: A, a: T, b: T, c: T, x: Number) :
SpaceVerifier<T, A>(algebra, a, b, c, x) where A : Ring<T>, A : ScaleOperations<T> {
override fun verify() {

View File

@ -15,30 +15,33 @@ public val Dispatchers.Math: CoroutineDispatcher
/**
* An imitator of [Deferred] which holds a suspended function block and dispatcher
*/
internal class LazyDeferred<T>(val dispatcher: CoroutineDispatcher, val block: suspend CoroutineScope.() -> T) {
@PublishedApi
internal class LazyDeferred<out T>(val dispatcher: CoroutineDispatcher, val block: suspend CoroutineScope.() -> T) {
private var deferred: Deferred<T>? = null
internal fun start(scope: CoroutineScope) {
fun start(scope: CoroutineScope) {
if (deferred == null) deferred = scope.async(dispatcher, block = block)
}
suspend fun await(): T = deferred?.await() ?: error("Coroutine not started")
}
public class AsyncFlow<T> internal constructor(internal val deferredFlow: Flow<LazyDeferred<T>>) : Flow<T> {
public class AsyncFlow<out T> @PublishedApi internal constructor(
@PublishedApi internal val deferredFlow: Flow<LazyDeferred<T>>,
) : Flow<T> {
override suspend fun collect(collector: FlowCollector<T>): Unit =
deferredFlow.collect { collector.emit((it.await())) }
}
public fun <T, R> Flow<T>.async(
public inline fun <T, R> Flow<T>.async(
dispatcher: CoroutineDispatcher = Dispatchers.Default,
block: suspend CoroutineScope.(T) -> R,
crossinline block: suspend CoroutineScope.(T) -> R,
): AsyncFlow<R> {
val flow = map { LazyDeferred(dispatcher) { block(it) } }
return AsyncFlow(flow)
}
public fun <T, R> AsyncFlow<T>.map(action: (T) -> R): AsyncFlow<R> =
public inline fun <T, R> AsyncFlow<T>.map(crossinline action: (T) -> R): AsyncFlow<R> =
AsyncFlow(deferredFlow.map { input ->
//TODO add function composition
LazyDeferred(input.dispatcher) {

View File

@ -17,7 +17,7 @@ import kotlin.jvm.JvmInline
/**
* A matrix with compile-time controlled dimension
*/
public interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
public interface DMatrix<out T, R : Dimension, C : Dimension> : Structure2D<T> {
public companion object {
/**
* Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed
@ -46,7 +46,7 @@ public interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
* An inline wrapper for a Matrix
*/
@JvmInline
public value class DMatrixWrapper<T, R : Dimension, C : Dimension>(
public value class DMatrixWrapper<out T, R : Dimension, C : Dimension>(
private val structure: Structure2D<T>,
) : DMatrix<T, R, C> {
override val shape: IntArray get() = structure.shape
@ -58,7 +58,7 @@ public value class DMatrixWrapper<T, R : Dimension, C : Dimension>(
/**
* Dimension-safe point
*/
public interface DPoint<T, D : Dimension> : Point<T> {
public interface DPoint<out T, D : Dimension> : Point<T> {
public companion object {
public inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> {
require(point.size == Dimension.dim<D>().toInt()) {
@ -76,7 +76,7 @@ public interface DPoint<T, D : Dimension> : Point<T> {
* Dimension-safe point wrapper
*/
@JvmInline
public value class DPointWrapper<T, D : Dimension>(public val point: Point<T>) :
public value class DPointWrapper<out T, D : Dimension>(public val point: Point<T>) :
DPoint<T, D> {
override val size: Int get() = point.size
@ -111,7 +111,7 @@ public value class DMatrixContext<T : Any, out A : Ring<T>>(public val context:
): DMatrix<T, R, C> {
val rows = Dimension.dim<R>()
val cols = Dimension.dim<C>()
return context.buildMatrix(rows.toInt(), cols.toInt(), initializer).coerce<R, C>()
return context.buildMatrix(rows.toInt(), cols.toInt(), initializer).coerce()
}
public inline fun <reified D : Dimension> point(noinline initializer: A.(Int) -> T): DPoint<T, D> {

View File

@ -32,7 +32,7 @@ internal class DoubleMatrixTest {
@Test
fun testSequenceToMatrix() {
val m = Sequence<DoubleArray> {
val m = Sequence {
listOf(
DoubleArray(10) { 10.0 },
DoubleArray(10) { 20.0 },

View File

@ -14,7 +14,7 @@ import space.kscience.kmath.operations.Ring
* @param T the piece key type.
* @param R the sub-function type.
*/
public fun interface Piecewise<T, R> {
public fun interface Piecewise<in T, out R> {
/**
* Returns the appropriate sub-function for given piece key.
*/
@ -23,7 +23,9 @@ public fun interface Piecewise<T, R> {
/**
* Represents piecewise-defined function where all the sub-functions are polynomials.
* @param pieces An ordered list of range-polynomial pairs. The list does not in general guarantee that there are no "holes" in it.
*
* @property pieces An ordered list of range-polynomial pairs. The list does not in general guarantee that there are no
* "holes" in it.
*/
public interface PiecewisePolynomial<T : Comparable<T>> : Piecewise<T, Polynomial<T>> {
public val pieces: Collection<Pair<ClosedRange<T>, Polynomial<T>>>
@ -45,7 +47,7 @@ public fun <T : Comparable<T>> PiecewisePolynomial(
/**
* An optimized piecewise which uses not separate pieces, but a range separated by delimiters.
* The pices search is logarithmic
* The pieces search is logarithmic
*/
private class OrderedPiecewisePolynomial<T : Comparable<T>>(
override val pieces: List<Pair<ClosedRange<T>, Polynomial<T>>>,

View File

@ -17,7 +17,7 @@ import kotlin.math.pow
*
* @param coefficients constant is the leftmost coefficient.
*/
public class Polynomial<T>(public val coefficients: List<T>) {
public class Polynomial<out T>(public val coefficients: List<T>) {
override fun toString(): String = "Polynomial$coefficients"
}
@ -69,7 +69,7 @@ public fun <T, A> Polynomial<T>.differentiate(
public fun <T, A> Polynomial<T>.integrate(
algebra: A,
): Polynomial<T> where A : Field<T>, A : NumericAlgebra<T> = algebra {
val integratedCoefficients = buildList<T>(coefficients.size + 1) {
val integratedCoefficients = buildList(coefficients.size + 1) {
add(zero)
coefficients.forEachIndexed{ index, t -> add(t / (number(index) + one)) }
}

View File

@ -18,7 +18,7 @@ public interface Integrand {
public inline fun <reified T : IntegrandFeature> Integrand.getFeature(): T? = getFeature(T::class)
public class IntegrandValue<T : Any>(public val value: T) : IntegrandFeature {
public class IntegrandValue<out T : Any>(public val value: T) : IntegrandFeature {
override fun toString(): String = "Value($value)"
}

View File

@ -93,7 +93,7 @@ public fun <T : Any> UnivariateIntegrator<T>.integrate(
): UnivariateIntegrand<T> = integrate(UnivariateIntegrand(function, IntegrationRange(range), *features))
/**
* A shortcut method to integrate a [function] in [range] with additional [features].
* A shortcut method to integrate a [function] in [range] with additional features.
* The [function] is placed in the end position to allow passing a lambda.
*/
@UnstableKMathAPI

View File

@ -18,7 +18,7 @@ import space.kscience.kmath.structures.asBuffer
/**
* And interpolator for data with x column type [X], y column type [Y].
*/
public fun interface Interpolator<T, X : T, Y : T> {
public fun interface Interpolator<T, in X : T, Y : T> {
public fun interpolate(points: XYColumnarData<T, X, Y>): (X) -> Y
}

View File

@ -5,7 +5,7 @@
package space.kscience.kmath.geometry
public data class Line<V : Vector>(val base: V, val direction: V)
public data class Line<out V : Vector>(val base: V, val direction: V)
public typealias Line2D = Line<Vector2D>
public typealias Line3D = Line<Vector3D>

View File

@ -13,14 +13,14 @@ import space.kscience.kmath.structures.asBuffer
/**
* The binned data element. Could be a histogram bin with a number of counts or an artificial construct
*/
public interface Bin<T : Any> : Domain<T> {
public interface Bin<in T : Any> : Domain<T> {
/**
* The value of this bin.
*/
public val value: Number
}
public interface Histogram<T : Any, out B : Bin<T>> {
public interface Histogram<in T : Any, out B : Bin<T>> {
/**
* Find existing bin, corresponding to given coordinates
*/
@ -34,7 +34,7 @@ public interface Histogram<T : Any, out B : Bin<T>> {
public val bins: Iterable<B>
}
public fun interface HistogramBuilder<T : Any> {
public fun interface HistogramBuilder<in T : Any> {
/**
* Increment appropriate bin

View File

@ -18,7 +18,7 @@ import space.kscience.kmath.operations.invoke
/**
* A simple histogram bin based on domain
*/
public data class DomainBin<T : Comparable<T>>(
public data class DomainBin<in T : Comparable<T>>(
public val domain: Domain<T>,
public override val value: Number,
) : Bin<T>, Domain<T> by domain

View File

@ -36,9 +36,10 @@ public class TreeHistogram(
}
@OptIn(UnstableKMathAPI::class)
private class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain) : UnivariateHistogramBuilder {
@PublishedApi
internal class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain) : UnivariateHistogramBuilder {
private class BinCounter(val domain: UnivariateDomain, val counter: Counter<Double> = Counter.real()) :
internal class BinCounter(val domain: UnivariateDomain, val counter: Counter<Double> = Counter.real()) :
ClosedFloatingPointRange<Double> by domain.range
private val bins: TreeMap<Double, BinCounter> = TreeMap()
@ -80,10 +81,10 @@ private class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain)
*/
@UnstableKMathAPI
public class TreeHistogramSpace(
private val binFactory: (Double) -> UnivariateDomain,
@PublishedApi internal val binFactory: (Double) -> UnivariateDomain,
) : Group<UnivariateHistogram>, ScaleOperations<UnivariateHistogram> {
public fun fill(block: UnivariateHistogramBuilder.() -> Unit): UnivariateHistogram =
public inline fun fill(block: UnivariateHistogramBuilder.() -> Unit): UnivariateHistogram =
TreeHistogramBuilder(binFactory).apply(block).build()
override fun add(
@ -115,8 +116,8 @@ public class TreeHistogramSpace(
bin.domain.center,
UnivariateBin(
bin.domain,
value = bin.value * value.toDouble(),
standardDeviation = abs(bin.standardDeviation * value.toDouble())
value = bin.value * value,
standardDeviation = abs(bin.standardDeviation * value)
)
)
}

View File

@ -18,14 +18,14 @@ readme {
feature(
"differentiable-mst-expression",
"src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt",
"src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt",
) {
"MST based DifferentiableExpression."
}
feature(
"differentiable-mst-expression",
"src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt",
"scalars-adapters",
"src/main/kotlin/space/kscience/kmath/kotlingrad/scalarsAdapters.kt",
) {
"Conversions between Kotlin∇'s SFun and MST"
}

View File

@ -25,7 +25,7 @@ private class Nd4jArrayIndicesIterator(private val iterateOver: INDArray) : Iter
internal fun INDArray.indicesIterator(): Iterator<IntArray> = Nd4jArrayIndicesIterator(this)
private sealed class Nd4jArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
private sealed class Nd4jArrayIteratorBase<out T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
private var i: Int = 0
final override fun hasNext(): Boolean = i < iterateOver.length()

View File

@ -9,7 +9,7 @@ import space.kscience.kmath.expressions.Symbol
public interface OptimizationFeature
public class OptimizationResult<T>(
public class OptimizationResult<out T>(
public val point: Map<Symbol, T>,
public val value: T,
public val features: Set<OptimizationFeature> = emptySet(),

View File

@ -1,18 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingBufferChain
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler
import space.kscience.kmath.structures.Buffer
public class ConstantSampler<T : Any>(public val const: T) : Sampler<T> {
override fun sample(generator: RandomGenerator): BlockingBufferChain<T> = object : BlockingBufferChain<T> {
override fun nextBufferBlocking(size: Int): Buffer<T> = Buffer.boxing(size) { const }
override suspend fun fork(): BlockingBufferChain<T> = this
}
}

View File

@ -6,6 +6,8 @@
package space.kscience.kmath.stat
import kotlinx.coroutines.*
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.coroutineContext
@ -23,14 +25,18 @@ public class MCScope(
/**
* Launches a supervised Monte-Carlo scope
*/
public suspend inline fun <T> mcScope(generator: RandomGenerator, block: MCScope.() -> T): T =
MCScope(coroutineContext, generator).block()
public suspend inline fun <T> mcScope(generator: RandomGenerator, block: MCScope.() -> T): T {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MCScope(coroutineContext, generator).block()
}
/**
* Launch mc scope with a given seed
*/
public suspend inline fun <T> mcScope(seed: Long, block: MCScope.() -> T): T =
mcScope(RandomGenerator.default(seed), block)
public suspend inline fun <T> mcScope(seed: Long, block: MCScope.() -> T): T {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return mcScope(RandomGenerator.default(seed), block)
}
/**
* Specialized launch for [MCScope]. Behaves the same way as regular [CoroutineScope.launch], but also stores the generator fork.

View File

@ -12,7 +12,7 @@ import space.kscience.kmath.structures.*
import kotlin.jvm.JvmName
/**
* Sampler that generates chains of values of type [T] in a chain of type [C].
* Sampler that generates chains of values of type [T].
*/
public fun interface Sampler<out T : Any> {
/**

View File

@ -18,7 +18,7 @@ import space.kscience.kmath.operations.invoke
*
* @property value the value to sample.
*/
public class ConstantSampler<T : Any>(public val value: T) : Sampler<T> {
public class ConstantSampler<out T : Any>(public val value: T) : Sampler<T> {
public override fun sample(generator: RandomGenerator): Chain<T> = ConstantChain(value)
}
@ -27,7 +27,7 @@ public class ConstantSampler<T : Any>(public val value: T) : Sampler<T> {
*
* @property chainBuilder the provider of [Chain].
*/
public class BasicSampler<T : Any>(public val chainBuilder: (RandomGenerator) -> Chain<T>) : Sampler<T> {
public class BasicSampler<out T : Any>(public val chainBuilder: (RandomGenerator) -> Chain<T>) : Sampler<T> {
public override fun sample(generator: RandomGenerator): Chain<T> = chainBuilder(generator)
}
@ -36,7 +36,7 @@ public class BasicSampler<T : Any>(public val chainBuilder: (RandomGenerator) ->
*
* @property algebra the space to provide addition and scalar multiplication for [T].
*/
public class SamplerSpace<T : Any, S>(public val algebra: S) : Group<Sampler<T>>,
public class SamplerSpace<T : Any, out S>(public val algebra: S) : Group<Sampler<T>>,
ScaleOperations<Sampler<T>> where S : Group<T>, S : ScaleOperations<T> {
public override val zero: Sampler<T> = ConstantSampler(algebra.zero)

View File

@ -18,11 +18,11 @@ import space.kscience.kmath.structures.Buffer
/**
* A function, that transforms a buffer of random quantities to some resulting value
*/
public interface Statistic<T, R> {
public interface Statistic<in T, out R> {
public suspend fun evaluate(data: Buffer<T>): R
}
public interface BlockingStatistic<T,R>: Statistic<T,R>{
public interface BlockingStatistic<in T, out R> : Statistic<T, R> {
public fun evaluateBlocking(data: Buffer<T>): R
override suspend fun evaluate(data: Buffer<T>): R = evaluateBlocking(data)
@ -34,7 +34,7 @@ public interface BlockingStatistic<T,R>: Statistic<T,R>{
* @param I - intermediate block type
* @param R - result type
*/
public interface ComposableStatistic<T, I, R> : Statistic<T, R> {
public interface ComposableStatistic<in T, I, out R> : Statistic<T, R> {
//compute statistic on a single block
public suspend fun computeIntermediate(data: Buffer<T>): I

View File

@ -14,7 +14,8 @@ import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.operations.NumericAlgebra
/**
* Represents [MST] based [DifferentiableExpression] relying on [Symja](https://github.com/axkr/symja_android_library).
* Represents [MST] based [space.kscience.kmath.expressions.DifferentiableExpression] relying on
* [Symja](https://github.com/axkr/symja_android_library).
*
* The principle of this API is converting the [mst] to an [org.matheclipse.core.interfaces.IExpr], differentiating it
* with Symja's [F.D], then converting [org.matheclipse.core.interfaces.IExpr] back to [MST].

View File

@ -3,7 +3,7 @@
Common linear algebra operations on tensors.
- [tensor algebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt) : Basic linear algebra operations on tensors (plus, dot, etc.)
- [tensor algebra with broadcasting](src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting.
- [tensor algebra with broadcasting](src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting.
- [linear algebra operations](src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Advanced linear algebra operations like LU decomposition, SVD, etc.

View File

@ -24,20 +24,16 @@ readme {
feature(
id = "tensor algebra",
description = "Basic linear algebra operations on tensors (plus, dot, etc.)",
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt"
)
) { "Basic linear algebra operations on tensors (plus, dot, etc.)" }
feature(
id = "tensor algebra with broadcasting",
description = "Basic linear algebra operations implemented with broadcasting.",
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt"
)
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt"
) { "Basic linear algebra operations implemented with broadcasting." }
feature(
id = "linear algebra operations",
description = "Advanced linear algebra operations like LU decomposition, SVD, etc.",
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt"
)
) { "Advanced linear algebra operations like LU decomposition, SVD, etc." }
}

View File

@ -11,8 +11,8 @@ import space.kscience.kmath.tensors.core.internal.TensorLinearStructure
*/
public open class BufferedTensor<T> internal constructor(
override val shape: IntArray,
internal val mutableBuffer: MutableBuffer<T>,
internal val bufferStart: Int
@PublishedApi internal val mutableBuffer: MutableBuffer<T>,
@PublishedApi internal val bufferStart: Int,
) : Tensor<T> {
/**

View File

@ -11,7 +11,7 @@ import space.kscience.kmath.tensors.core.internal.toPrettyString
/**
* Default [BufferedTensor] implementation for [Double] values
*/
public class DoubleTensor internal constructor(
public class DoubleTensor @PublishedApi internal constructor(
shape: IntArray,
buffer: DoubleArray,
offset: Int = 0

View File

@ -426,13 +426,11 @@ public open class DoubleTensorAlgebra :
* @param transform the function to be applied to each element of the tensor.
* @return the resulting tensor after applying the function.
*/
public fun Tensor<Double>.map(transform: (Double) -> Double): DoubleTensor {
return DoubleTensor(
public inline fun Tensor<Double>.map(transform: (Double) -> Double): DoubleTensor = DoubleTensor(
tensor.shape,
tensor.mutableBuffer.array().map { transform(it) }.toDoubleArray(),
tensor.bufferStart
)
}
/**
* Compares element-wise two tensors with a specified precision.
@ -519,14 +517,12 @@ public open class DoubleTensorAlgebra :
* @param indices the [IntArray] of 1-dimensional indices
* @return tensor with rows corresponding to rows by [indices]
*/
public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor {
return stack(indices.map { this[it] })
}
public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] })
internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
internal inline fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
foldFunction(tensor.toDoubleArray())
internal fun Tensor<Double>.foldDim(
internal inline fun Tensor<Double>.foldDim(
foldFunction: (DoubleArray) -> Double,
dim: Int,
keepDim: Boolean,

View File

@ -10,15 +10,15 @@ import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.*
import space.kscience.kmath.tensors.core.BufferedTensor
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra.Companion.valueOrNull
import space.kscience.kmath.tensors.core.IntTensor
import kotlin.math.abs
import kotlin.math.min
import kotlin.math.sign
import kotlin.math.sqrt
internal fun <T> BufferedTensor<T>.vectorSequence(): Sequence<BufferedTensor<T>> = sequence {
val n = shape.size
val vectorOffset = shape[n - 1]

View File

@ -31,6 +31,7 @@ internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
else -> this.copyToBufferedTensor()
}
@PublishedApi
internal val Tensor<Double>.tensor: DoubleTensor
get() = when (this) {
is DoubleTensor -> this

View File

@ -24,6 +24,7 @@ internal fun Buffer<Int>.array(): IntArray = when (this) {
/**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer] or copy the data.
*/
@PublishedApi
internal fun Buffer<Double>.array(): DoubleArray = when (this) {
is DoubleBuffer -> array
else -> this.toDoubleArray()

View File

@ -183,7 +183,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
}
private fun DoubleTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10): Unit {
private fun DoubleTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10) {
val svd = tensor.svd()
val tensorSVD = svd.first

View File

@ -17,7 +17,7 @@ internal class TestDoubleTensorAlgebra {
}
@Test
fun TestDoubleDiv() = DoubleTensorAlgebra {
fun testDoubleDiv() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(2), doubleArrayOf(2.0, 4.0))
val res = 2.0/tensor
assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 0.5))

View File

@ -92,7 +92,7 @@ public class ViktorFieldND(public override val shape: IntArray) : FieldND<Double
(a.f64Buffer + b.f64Buffer).asStructure()
public override fun scale(a: StructureND<Double>, value: Double): ViktorStructureND =
(a.f64Buffer * value.toDouble()).asStructure()
(a.f64Buffer * value).asStructure()
public override inline fun StructureND<Double>.plus(b: StructureND<Double>): ViktorStructureND =
(f64Buffer + b.f64Buffer).asStructure()