Merge pull request #380 from mipt-npm/commandertvis/contracts

Add contracts to some functions, fix multiple style issues
This commit is contained in:
Iaroslav Postovalov 2021-07-13 23:12:48 +07:00 committed by GitHub
commit ecd70f2139
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
73 changed files with 210 additions and 200 deletions

View File

@ -13,6 +13,8 @@ import space.kscience.kmath.jafama.JafamaDoubleField
import space.kscience.kmath.jafama.StrictJafamaDoubleField import space.kscience.kmath.jafama.StrictJafamaDoubleField
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.random.Random import kotlin.random.Random
@State(Scope.Benchmark) @State(Scope.Benchmark)
@ -31,9 +33,10 @@ internal class JafamaBenchmark {
fun strictJafama(blackhole: Blackhole) = invokeBenchmarks(blackhole) { x -> fun strictJafama(blackhole: Blackhole) = invokeBenchmarks(blackhole) { x ->
StrictJafamaDoubleField { x * power(x, 4) * exp(x) / cos(x) + sin(x) } StrictJafamaDoubleField { x * power(x, 4) * exp(x) / cos(x) + sin(x) }
} }
}
private inline fun invokeBenchmarks(blackhole: Blackhole, expr: (Double) -> Double) {
val rng = Random(0) private inline fun invokeBenchmarks(blackhole: Blackhole, expr: (Double) -> Double) {
repeat(1000000) { blackhole.consume(expr(rng.nextDouble())) } contract { callsInPlace(expr, InvocationKind.AT_LEAST_ONCE) }
} val rng = Random(0)
repeat(1000000) { blackhole.consume(expr(rng.nextDouble())) }
} }

View File

@ -40,14 +40,14 @@ internal class MatrixInverseBenchmark {
@Benchmark @Benchmark
fun cmLUPInversion(blackhole: Blackhole) { fun cmLUPInversion(blackhole: Blackhole) {
with(CMLinearSpace) { CMLinearSpace {
blackhole.consume(inverse(matrix)) blackhole.consume(inverse(matrix))
} }
} }
@Benchmark @Benchmark
fun ejmlInverse(blackhole: Blackhole) { fun ejmlInverse(blackhole: Blackhole) {
with(EjmlLinearSpaceDDRM) { EjmlLinearSpaceDDRM {
blackhole.consume(matrix.getFeature<InverseMatrixFeature<Double>>()?.inverse) blackhole.consume(matrix.getFeature<InverseMatrixFeature<Double>>()?.inverse)
} }
} }

View File

@ -115,6 +115,8 @@ via extension function.
Usually it is bad idea to compare the direct numerical operation performance in different languages, but it hard to Usually it is bad idea to compare the direct numerical operation performance in different languages, but it hard to
work completely without frame of reference. In this case, simple numpy code: work completely without frame of reference. In this case, simple numpy code:
```python ```python
import numpy as np
res = np.ones((1000,1000)) res = np.ones((1000,1000))
for i in range(1000): for i in range(1000):
res = res + 1.0 res = res + 1.0

View File

@ -10,7 +10,7 @@ import space.kscience.kmath.ast.rendering.LatexSyntaxRenderer
import space.kscience.kmath.ast.rendering.MathMLSyntaxRenderer import space.kscience.kmath.ast.rendering.MathMLSyntaxRenderer
import space.kscience.kmath.ast.rendering.renderWithStringBuilder import space.kscience.kmath.ast.rendering.renderWithStringBuilder
public fun main() { fun main() {
val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(-12)".parseMath() val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(-12)".parseMath()
val syntax = FeaturedMathRendererWithPostProcess.Default.render(mst) val syntax = FeaturedMathRendererWithPostProcess.Default.render(mst)
println("MathSyntax:") println("MathSyntax:")

View File

@ -13,7 +13,7 @@ import space.kscience.kmath.operations.*
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
* Prints any [Symbol] as a [SymbolSyntax] containing the [Symbol.value] of it. * Prints any [Symbol] as a [SymbolSyntax] containing the [Symbol.identity] of it.
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */

View File

@ -11,7 +11,6 @@ import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.interpret import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals

View File

@ -9,7 +9,6 @@ import space.kscience.kmath.expressions.MstRing
import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.IntRing import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals

View File

@ -3,6 +3,8 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/ */
@file:Suppress("ClassName")
package space.kscience.kmath.internal.estree package space.kscience.kmath.internal.estree
import kotlin.js.RegExp import kotlin.js.RegExp

View File

@ -10,6 +10,8 @@ import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing import space.kscience.kmath.operations.IntRing
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import space.kscience.kmath.estree.compile as estreeCompile import space.kscience.kmath.estree.compile as estreeCompile
import space.kscience.kmath.estree.compileToExpression as estreeCompileToExpression import space.kscience.kmath.estree.compileToExpression as estreeCompileToExpression
import space.kscience.kmath.wasm.compile as wasmCompile import space.kscience.kmath.wasm.compile as wasmCompile
@ -34,6 +36,7 @@ private object ESTreeCompilerTestContext : CompilerTestContext {
} }
internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) { internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) {
contract { callsInPlace(action, InvocationKind.AT_LEAST_ONCE) }
action(WasmCompilerTestContext) action(WasmCompilerTestContext)
action(ESTreeCompilerTestContext) action(ESTreeCompilerTestContext)
} }

View File

@ -11,7 +11,6 @@ import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.expressions.symbol import space.kscience.kmath.expressions.symbol
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals

View File

@ -10,6 +10,8 @@ import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing import space.kscience.kmath.operations.IntRing
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import space.kscience.kmath.asm.compile as asmCompile import space.kscience.kmath.asm.compile as asmCompile
import space.kscience.kmath.asm.compileToExpression as asmCompileToExpression import space.kscience.kmath.asm.compileToExpression as asmCompileToExpression
@ -22,4 +24,7 @@ private object AsmCompilerTestContext : CompilerTestContext {
asmCompile(algebra, arguments) asmCompile(algebra, arguments)
} }
internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) = action(AsmCompilerTestContext) internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
action(AsmCompilerTestContext)
}

View File

@ -15,7 +15,7 @@ import space.kscience.kmath.operations.NumbersAddOperations
* A field over commons-math [DerivativeStructure]. * A field over commons-math [DerivativeStructure].
* *
* @property order The derivation order. * @property order The derivation order.
* @property bindings The map of bindings values. All bindings are considered free parameters * @param bindings The map of bindings values. All bindings are considered free parameters
*/ */
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public class DerivativeStructureField( public class DerivativeStructureField(

View File

@ -52,11 +52,11 @@ public class CMOptimization(
public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList() public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
public override fun initialGuess(map: Map<Symbol, Double>): Unit { public override fun initialGuess(map: Map<Symbol, Double>) {
addOptimizationData(InitialGuess(map.toDoubleArray())) addOptimizationData(InitialGuess(map.toDoubleArray()))
} }
public override fun function(expression: Expression<Double>): Unit { public override fun function(expression: Expression<Double>) {
val objectiveFunction = ObjectiveFunction { val objectiveFunction = ObjectiveFunction {
val args = it.toMap() val args = it.toMap()
expression(args) expression(args)

View File

@ -32,7 +32,7 @@ public object Transformations {
/** /**
* Create a virtual buffer on top of array * Create a virtual buffer on top of array
*/ */
private fun Array<org.apache.commons.math3.complex.Complex>.asBuffer() = VirtualBuffer<Complex>(size) { private fun Array<org.apache.commons.math3.complex.Complex>.asBuffer() = VirtualBuffer(size) {
val value = get(it) val value = get(it)
Complex(value.real, value.imaginary) Complex(value.real, value.imaginary)
} }

View File

@ -16,7 +16,7 @@ internal inline fun diff(
order: Int, order: Int,
vararg parameters: Pair<Symbol, Double>, vararg parameters: Pair<Symbol, Double>,
block: DerivativeStructureField.() -> Unit, block: DerivativeStructureField.() -> Unit,
): Unit { ) {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
DerivativeStructureField(order, mapOf(*parameters)).run(block) DerivativeStructureField(order, mapOf(*parameters)).run(block)
} }

View File

@ -147,7 +147,7 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
return if (arg.w > 0) return if (arg.w > 0)
Quaternion(ln(arg.w), 0, 0, 0) Quaternion(ln(arg.w), 0, 0, 0)
else { else {
val l = ComplexField { ComplexField.ln(arg.w.toComplex()) } val l = ComplexField { ln(arg.w.toComplex()) }
Quaternion(l.re, l.im, 0, 0) Quaternion(l.re, l.im, 0, 0)
} }

View File

@ -16,7 +16,7 @@ import kotlin.math.max
* The buffer of X values. * The buffer of X values.
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public interface XYColumnarData<T, out X : T, out Y : T> : ColumnarData<T> { public interface XYColumnarData<out T, out X : T, out Y : T> : ColumnarData<T> {
/** /**
* The buffer of X values * The buffer of X values
*/ */

View File

@ -14,7 +14,7 @@ import space.kscience.kmath.structures.Buffer
* Inherits [XYColumnarData]. * Inherits [XYColumnarData].
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public interface XYZColumnarData<T, out X : T, out Y : T, out Z : T> : XYColumnarData<T, X, Y> { public interface XYZColumnarData<out T, out X : T, out Y : T, out Z : T> : XYColumnarData<T, X, Y> {
public val z: Buffer<Z> public val z: Buffer<Z>
override fun get(symbol: Symbol): Buffer<T>? = when (symbol) { override fun get(symbol: Symbol): Buffer<T>? = when (symbol) {
@ -23,4 +23,4 @@ public interface XYZColumnarData<T, out X : T, out Y : T, out Z : T> : XYColumna
Symbol.z -> z Symbol.z -> z
else -> null else -> null
} }
} }

View File

@ -12,7 +12,7 @@ import space.kscience.kmath.linear.Point
* *
* @param T the type of element of this domain. * @param T the type of element of this domain.
*/ */
public interface Domain<T : Any> { public interface Domain<in T : Any> {
/** /**
* Checks if the specified point is contained in this domain. * Checks if the specified point is contained in this domain.
*/ */

View File

@ -9,7 +9,6 @@ package space.kscience.kmath.expressions
* Represents expression which structure can be differentiated. * Represents expression which structure can be differentiated.
* *
* @param T the type this expression takes as argument and returns. * @param T the type this expression takes as argument and returns.
* @param R the type of expression this expression can be differentiated to.
*/ */
public interface DifferentiableExpression<T> : Expression<T> { public interface DifferentiableExpression<T> : Expression<T> {
/** /**
@ -24,16 +23,18 @@ public interface DifferentiableExpression<T> : Expression<T> {
public fun <T> DifferentiableExpression<T>.derivative(symbols: List<Symbol>): Expression<T> = public fun <T> DifferentiableExpression<T>.derivative(symbols: List<Symbol>): Expression<T> =
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided") derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
public fun <T> DifferentiableExpression<T>.derivative(vararg symbols: Symbol): Expression<T> = public fun <T> DifferentiableExpression<T>.derivative(vararg symbols: Symbol): Expression<T> =
derivative(symbols.toList()) derivative(symbols.toList())
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
derivative(StringSymbol(name)) derivative(StringSymbol(name))
/** /**
* A special type of [DifferentiableExpression] which returns typed expressions as derivatives * A special type of [DifferentiableExpression] which returns typed expressions as derivatives.
*
* @param R the type of expression this expression can be differentiated to.
*/ */
public interface SpecialDifferentiableExpression<T, R: Expression<T>>: DifferentiableExpression<T> { public interface SpecialDifferentiableExpression<T, out R : Expression<T>> : DifferentiableExpression<T> {
override fun derivativeOrNull(symbols: List<Symbol>): R? override fun derivativeOrNull(symbols: List<Symbol>): R?
} }
@ -53,9 +54,9 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
/** /**
* Returns first derivative of this expression by given [symbol]. * Returns first derivative of this expression by given [symbol].
*/ */
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>? public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
public final override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? { public final override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? {
val dSymbol = symbols.firstOrNull() ?: return null val dSymbol = symbols.firstOrNull() ?: return null
return derivativeOrNull(dSymbol) return derivativeOrNull(dSymbol)
} }
@ -64,6 +65,6 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
/** /**
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression] * A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
*/ */
public fun interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>, out R : Expression<T>> { public fun interface AutoDiffProcessor<T : Any, I : Any, out A : ExpressionAlgebra<T, I>, out R : Expression<T>> {
public fun process(function: A.() -> I): DifferentiableExpression<T> public fun process(function: A.() -> I): DifferentiableExpression<T>
} }

View File

@ -6,13 +6,15 @@
package space.kscience.kmath.expressions package space.kscience.kmath.expressions
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/** /**
* A context class for [Expression] construction. * A context class for [Expression] construction.
* *
* @param algebra The algebra to provide for Expressions built. * @param algebra The algebra to provide for Expressions built.
*/ */
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>( public abstract class FunctionalExpressionAlgebra<T, out A : Algebra<T>>(
public val algebra: A, public val algebra: A,
) : ExpressionAlgebra<T, Expression<T>> { ) : ExpressionAlgebra<T, Expression<T>> {
/** /**
@ -29,9 +31,6 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
?: error("Symbol '$value' is not supported in $this") ?: error("Symbol '$value' is not supported in $this")
} }
/**
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
*/
public override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> = public override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
{ left, right -> { left, right ->
Expression { arguments -> Expression { arguments ->
@ -39,9 +38,6 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
} }
} }
/**
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
*/
public override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg -> public override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
Expression { arguments -> algebra.unaryOperationFunction(operation)(arg.invoke(arguments)) } Expression { arguments -> algebra.unaryOperationFunction(operation)(arg.invoke(arguments)) }
} }
@ -50,7 +46,7 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
/** /**
* A context class for [Expression] construction for [Ring] algebras. * A context class for [Expression] construction for [Ring] algebras.
*/ */
public open class FunctionalExpressionGroup<T, A : Group<T>>( public open class FunctionalExpressionGroup<T, out A : Group<T>>(
algebra: A, algebra: A,
) : FunctionalExpressionAlgebra<T, A>(algebra), Group<Expression<T>> { ) : FunctionalExpressionAlgebra<T, A>(algebra), Group<Expression<T>> {
public override val zero: Expression<T> get() = const(algebra.zero) public override val zero: Expression<T> get() = const(algebra.zero)
@ -84,7 +80,7 @@ public open class FunctionalExpressionGroup<T, A : Group<T>>(
} }
public open class FunctionalExpressionRing<T, A : Ring<T>>( public open class FunctionalExpressionRing<T, out A : Ring<T>>(
algebra: A, algebra: A,
) : FunctionalExpressionGroup<T, A>(algebra), Ring<Expression<T>> { ) : FunctionalExpressionGroup<T, A>(algebra), Ring<Expression<T>> {
public override val one: Expression<T> get() = const(algebra.one) public override val one: Expression<T> get() = const(algebra.one)
@ -105,7 +101,7 @@ public open class FunctionalExpressionRing<T, A : Ring<T>>(
super<FunctionalExpressionGroup>.binaryOperationFunction(operation) super<FunctionalExpressionGroup>.binaryOperationFunction(operation)
} }
public open class FunctionalExpressionField<T, A : Field<T>>( public open class FunctionalExpressionField<T, out A : Field<T>>(
algebra: A, algebra: A,
) : FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>>, ScaleOperations<Expression<T>> { ) : FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>>, ScaleOperations<Expression<T>> {
/** /**
@ -131,7 +127,7 @@ public open class FunctionalExpressionField<T, A : Field<T>>(
super<FunctionalExpressionRing>.bindSymbolOrNull(value) super<FunctionalExpressionRing>.bindSymbolOrNull(value)
} }
public open class FunctionalExpressionExtendedField<T, A : ExtendedField<T>>( public open class FunctionalExpressionExtendedField<T, out A : ExtendedField<T>>(
algebra: A, algebra: A,
) : FunctionalExpressionField<T, A>(algebra), ExtendedField<Expression<T>> { ) : FunctionalExpressionField<T, A>(algebra), ExtendedField<Expression<T>> {
public override fun number(value: Number): Expression<T> = const(algebra.number(value)) public override fun number(value: Number): Expression<T> = const(algebra.number(value))
@ -172,14 +168,26 @@ public open class FunctionalExpressionExtendedField<T, A : ExtendedField<T>>(
public override fun bindSymbol(value: String): Expression<T> = super<FunctionalExpressionField>.bindSymbol(value) public override fun bindSymbol(value: String): Expression<T> = super<FunctionalExpressionField>.bindSymbol(value)
} }
public inline fun <T, A : Ring<T>> A.expressionInSpace(block: FunctionalExpressionGroup<T, A>.() -> Expression<T>): Expression<T> = public inline fun <T, A : Group<T>> A.expressionInGroup(
FunctionalExpressionGroup(this).block() block: FunctionalExpressionGroup<T, A>.() -> Expression<T>,
): Expression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionGroup(this).block()
}
public inline fun <T, A : Ring<T>> A.expressionInRing(block: FunctionalExpressionRing<T, A>.() -> Expression<T>): Expression<T> = public inline fun <T, A : Ring<T>> A.expressionInRing(
FunctionalExpressionRing(this).block() block: FunctionalExpressionRing<T, A>.() -> Expression<T>,
): Expression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionRing(this).block()
}
public inline fun <T, A : Field<T>> A.expressionInField(block: FunctionalExpressionField<T, A>.() -> Expression<T>): Expression<T> = public inline fun <T, A : Field<T>> A.expressionInField(
FunctionalExpressionField(this).block() block: FunctionalExpressionField<T, A>.() -> Expression<T>,
): Expression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionField(this).block()
}
public inline fun <T, A : ExtendedField<T>> A.expressionInExtendedField( public inline fun <T, A : ExtendedField<T>> A.expressionInExtendedField(
block: FunctionalExpressionExtendedField<T, A>.() -> Expression<T>, block: FunctionalExpressionExtendedField<T, A>.() -> Expression<T>,

View File

@ -119,8 +119,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
get() = getDerivative(this) get() = getDerivative(this)
set(value) = setDerivative(this, value) set(value) = setDerivative(this, value)
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
/** /**
* Performs update of derivative after the rest of the formula in the back-pass. * Performs update of derivative after the rest of the formula in the back-pass.
* *
@ -194,6 +192,11 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
} }
} }
public inline fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.const(block: F.() -> T): AutoDiffValue<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return const(context.block())
}
/** /**
* Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code. * Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code.
@ -208,7 +211,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
* assertEquals(9.0, x.d) // dy/dx * assertEquals(9.0, x.d) // dy/dx
* ``` * ```
* *
* @param body the action in [SimpleAutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to. * @param body the action in [SimpleAutoDiffField] context returning [AutoDiffValue] to differentiate with respect to.
* @return the result of differentiation. * @return the result of differentiation.
*/ */
public fun <T : Any, F : Field<T>> F.simpleAutoDiff( public fun <T : Any, F : Field<T>> F.simpleAutoDiff(

View File

@ -9,6 +9,8 @@ import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.Structure2D import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.BufferFactory
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.jvm.JvmInline import kotlin.jvm.JvmInline
/** /**
@ -65,9 +67,13 @@ public value class SimpleSymbolIndexer(override val symbols: List<Symbol>) : Sym
* Execute the block with symbol indexer based on given symbol order * Execute the block with symbol indexer based on given symbol order
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public inline fun <R> withSymbols(vararg symbols: Symbol, block: SymbolIndexer.() -> R): R = public inline fun <R> withSymbols(vararg symbols: Symbol, block: SymbolIndexer.() -> R): R {
with(SimpleSymbolIndexer(symbols.toList()), block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return with(SimpleSymbolIndexer(symbols.toList()), block)
}
@UnstableKMathAPI @UnstableKMathAPI
public inline fun <R> withSymbols(symbols: Collection<Symbol>, block: SymbolIndexer.() -> R): R = public inline fun <R> withSymbols(symbols: Collection<Symbol>, block: SymbolIndexer.() -> R): R {
with(SimpleSymbolIndexer(symbols.toList()), block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return with(SimpleSymbolIndexer(symbols.toList()), block)
}

View File

@ -15,7 +15,7 @@ import space.kscience.kmath.structures.VirtualBuffer
import space.kscience.kmath.structures.indices import space.kscience.kmath.structures.indices
public class BufferedLinearSpace<T : Any, A : Ring<T>>( public class BufferedLinearSpace<T : Any, out A : Ring<T>>(
override val elementAlgebra: A, override val elementAlgebra: A,
private val bufferFactory: BufferFactory<T>, private val bufferFactory: BufferFactory<T>,
) : LinearSpace<T, A> { ) : LinearSpace<T, A> {
@ -88,4 +88,4 @@ public class BufferedLinearSpace<T : Any, A : Ring<T>>(
override fun Matrix<T>.times(value: T): Matrix<T> = ndRing(rowNum, colNum).run { override fun Matrix<T>.times(value: T): Matrix<T> = ndRing(rowNum, colNum).run {
unwrap().map { it * value }.as2D() unwrap().map { it * value }.as2D()
} }
} }

View File

@ -32,7 +32,7 @@ public typealias Point<T> = Buffer<T>
* Basic operations on matrices and vectors. Operates on [Matrix]. * Basic operations on matrices and vectors. Operates on [Matrix].
* *
* @param T the type of items in the matrices. * @param T the type of items in the matrices.
* @param M the type of operated matrices. * @param A the type of ring over [T].
*/ */
public interface LinearSpace<T : Any, out A : Ring<T>> { public interface LinearSpace<T : Any, out A : Ring<T>> {
public val elementAlgebra: A public val elementAlgebra: A

View File

@ -35,7 +35,7 @@ public object UnitFeature : DiagonalFeature
* *
* @param T the type of matrices' items. * @param T the type of matrices' items.
*/ */
public interface InverseMatrixFeature<T : Any> : MatrixFeature { public interface InverseMatrixFeature<out T : Any> : MatrixFeature {
/** /**
* The inverse matrix of the matrix that owns this feature. * The inverse matrix of the matrix that owns this feature.
*/ */
@ -47,7 +47,7 @@ public interface InverseMatrixFeature<T : Any> : MatrixFeature {
* *
* @param T the type of matrices' items. * @param T the type of matrices' items.
*/ */
public interface DeterminantFeature<T : Any> : MatrixFeature { public interface DeterminantFeature<out T : Any> : MatrixFeature {
/** /**
* The determinant of the matrix that owns this feature. * The determinant of the matrix that owns this feature.
*/ */
@ -80,7 +80,7 @@ public object UFeature : MatrixFeature
* *
* @param T the type of matrices' items. * @param T the type of matrices' items.
*/ */
public interface LUDecompositionFeature<T : Any> : MatrixFeature { public interface LUDecompositionFeature<out T : Any> : MatrixFeature {
/** /**
* The lower triangular matrix in this decomposition. It may have [LFeature]. * The lower triangular matrix in this decomposition. It may have [LFeature].
*/ */
@ -98,7 +98,7 @@ public interface LUDecompositionFeature<T : Any> : MatrixFeature {
* *
* @param T the type of matrices' items. * @param T the type of matrices' items.
*/ */
public interface LupDecompositionFeature<T : Any> : MatrixFeature { public interface LupDecompositionFeature<out T : Any> : MatrixFeature {
/** /**
* The lower triangular matrix in this decomposition. It may have [LFeature]. * The lower triangular matrix in this decomposition. It may have [LFeature].
*/ */
@ -126,7 +126,7 @@ public object OrthogonalFeature : MatrixFeature
* *
* @param T the type of matrices' items. * @param T the type of matrices' items.
*/ */
public interface QRDecompositionFeature<T : Any> : MatrixFeature { public interface QRDecompositionFeature<out T : Any> : MatrixFeature {
/** /**
* The orthogonal matrix in this decomposition. It may have [OrthogonalFeature]. * The orthogonal matrix in this decomposition. It may have [OrthogonalFeature].
*/ */
@ -144,7 +144,7 @@ public interface QRDecompositionFeature<T : Any> : MatrixFeature {
* *
* @param T the type of matrices' items. * @param T the type of matrices' items.
*/ */
public interface CholeskyDecompositionFeature<T : Any> : MatrixFeature { public interface CholeskyDecompositionFeature<out T : Any> : MatrixFeature {
/** /**
* The triangular matrix in this decomposition. It may have either [UFeature] or [LFeature]. * The triangular matrix in this decomposition. It may have either [UFeature] or [LFeature].
*/ */
@ -157,7 +157,7 @@ public interface CholeskyDecompositionFeature<T : Any> : MatrixFeature {
* *
* @param T the type of matrices' items. * @param T the type of matrices' items.
*/ */
public interface SingularValueDecompositionFeature<T : Any> : MatrixFeature { public interface SingularValueDecompositionFeature<out T : Any> : MatrixFeature {
/** /**
* The matrix in this decomposition. It is unitary, and it consists from left singular vectors. * The matrix in this decomposition. It is unitary, and it consists from left singular vectors.
*/ */

View File

@ -34,7 +34,7 @@ public inline fun <T, R> Iterable<T>.cumulative(initial: R, crossinline operatio
public inline fun <T, R> Sequence<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Sequence<R> = public inline fun <T, R> Sequence<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Sequence<R> =
Sequence { this@cumulative.iterator().cumulative(initial, operation) } Sequence { this@cumulative.iterator().cumulative(initial, operation) }
public fun <T, R> List<T>.cumulative(initial: R, operation: (R, T) -> R): List<R> = public inline fun <T, R> List<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): List<R> =
iterator().cumulative(initial, operation).asSequence().toList() iterator().cumulative(initial, operation).asSequence().toList()
//Cumulative sum //Cumulative sum

View File

@ -24,7 +24,6 @@ public class ShapeMismatchException(public val expected: IntArray, public val ac
* *
* @param T the type of ND-structure element. * @param T the type of ND-structure element.
* @param C the type of the element context. * @param C the type of the element context.
* @param N the type of the structure.
*/ */
public interface AlgebraND<T, out C : Algebra<T>> { public interface AlgebraND<T, out C : Algebra<T>> {
/** /**
@ -118,8 +117,7 @@ internal fun <T, C : Algebra<T>> AlgebraND<T, C>.checkShape(element: StructureND
* Space of [StructureND]. * Space of [StructureND].
* *
* @param T the type of the element contained in ND structure. * @param T the type of the element contained in ND structure.
* @param N the type of ND structure. * @param S the type of group over structure elements.
* @param S the type of space of structure elements.
*/ */
public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND<T, S> { public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND<T, S> {
/** /**
@ -186,8 +184,7 @@ public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND
* Ring of [StructureND]. * Ring of [StructureND].
* *
* @param T the type of the element contained in ND structure. * @param T the type of the element contained in ND structure.
* @param N the type of ND structure. * @param R the type of ring over structure elements.
* @param R the type of ring of structure elements.
*/ */
public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R> { public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R> {
/** /**
@ -227,7 +224,7 @@ public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R
* Field of [StructureND]. * Field of [StructureND].
* *
* @param T the type of the element contained in ND structure. * @param T the type of the element contained in ND structure.
* @param F the type field of structure elements. * @param F the type field over structure elements.
*/ */
public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T, F>, ScaleOperations<StructureND<T>> { public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T, F>, ScaleOperations<StructureND<T>> {
/** /**

View File

@ -57,7 +57,7 @@ public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
} }
} }
public open class BufferedGroupND<T, A : Group<T>>( public open class BufferedGroupND<T, out A : Group<T>>(
final override val shape: IntArray, final override val shape: IntArray,
final override val elementContext: A, final override val elementContext: A,
final override val bufferFactory: BufferFactory<T>, final override val bufferFactory: BufferFactory<T>,
@ -67,7 +67,7 @@ public open class BufferedGroupND<T, A : Group<T>>(
override fun StructureND<T>.unaryMinus(): StructureND<T> = produce { -get(it) } override fun StructureND<T>.unaryMinus(): StructureND<T> = produce { -get(it) }
} }
public open class BufferedRingND<T, R : Ring<T>>( public open class BufferedRingND<T, out R : Ring<T>>(
shape: IntArray, shape: IntArray,
elementContext: R, elementContext: R,
bufferFactory: BufferFactory<T>, bufferFactory: BufferFactory<T>,
@ -75,7 +75,7 @@ public open class BufferedRingND<T, R : Ring<T>>(
override val one: BufferND<T> by lazy { produce { one } } override val one: BufferND<T> by lazy { produce { one } }
} }
public open class BufferedFieldND<T, R : Field<T>>( public open class BufferedFieldND<T, out R : Field<T>>(
shape: IntArray, shape: IntArray,
elementContext: R, elementContext: R,
bufferFactory: BufferFactory<T>, bufferFactory: BufferFactory<T>,

View File

@ -31,11 +31,10 @@ public class ShortRingND(
/** /**
* Fast element production using function inlining. * Fast element production using function inlining.
*/ */
public inline fun BufferedRingND<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): BufferND<Short> { public inline fun BufferedRingND<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): BufferND<Short> =
return BufferND(strides, ShortBuffer(ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) })) BufferND(strides, ShortBuffer(ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }))
}
public inline fun <R> ShortRing.nd(vararg shape: Int, action: ShortRingND.() -> R): R { public inline fun <R> ShortRing.nd(vararg shape: Int, action: ShortRingND.() -> R): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return ShortRingND(shape).run(action) return ShortRingND(shape).run(action)
} }

View File

@ -53,7 +53,7 @@ public interface StructureND<out T> {
public fun elements(): Sequence<Pair<IntArray, T>> public fun elements(): Sequence<Pair<IntArray, T>>
/** /**
* Feature is some additional strucure information which allows to access it special properties or hints. * Feature is some additional structure information which allows to access it special properties or hints.
* If the feature is not present, null is returned. * If the feature is not present, null is returned.
*/ */
@UnstableKMathAPI @UnstableKMathAPI

View File

@ -15,7 +15,7 @@ import space.kscience.kmath.misc.UnstableKMathAPI
*/ */
@UnstableKMathAPI @UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") @Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public interface AlgebraElement<T, C : Algebra<T>> { public interface AlgebraElement<T, out C : Algebra<T>> {
/** /**
* The context this element belongs to. * The context this element belongs to.
*/ */

View File

@ -445,7 +445,7 @@ public fun UIntArray.toBigInt(sign: Byte): BigInt {
* Returns null if a valid number can not be read from a string * Returns null if a valid number can not be read from a string
*/ */
public fun String.parseBigInteger(): BigInt? { public fun String.parseBigInteger(): BigInt? {
if (this.isEmpty()) return null if (isEmpty()) return null
val sign: Int val sign: Int
val positivePartIndex = when (this[0]) { val positivePartIndex = when (this[0]) {

View File

@ -14,30 +14,30 @@ import space.kscience.kmath.nd.as2D
* A context that allows to operate on a [MutableBuffer] as on 2d array * A context that allows to operate on a [MutableBuffer] as on 2d array
*/ */
internal class BufferAccessor2D<T : Any>( internal class BufferAccessor2D<T : Any>(
public val rowNum: Int, val rowNum: Int,
public val colNum: Int, val colNum: Int,
val factory: MutableBufferFactory<T>, val factory: MutableBufferFactory<T>,
) { ) {
public operator fun Buffer<T>.get(i: Int, j: Int): T = get(i * colNum + j) operator fun Buffer<T>.get(i: Int, j: Int): T = get(i * colNum + j)
public operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) { operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
set(i * colNum + j, value) set(i * colNum + j, value)
} }
public inline fun create(crossinline init: (i: Int, j: Int) -> T): MutableBuffer<T> = inline fun create(crossinline init: (i: Int, j: Int) -> T): MutableBuffer<T> =
factory(rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) } factory(rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
public fun create(mat: Structure2D<T>): MutableBuffer<T> = create { i, j -> mat[i, j] } fun create(mat: Structure2D<T>): MutableBuffer<T> = create { i, j -> mat[i, j] }
//TODO optimize wrapper //TODO optimize wrapper
public fun MutableBuffer<T>.collect(): Structure2D<T> = StructureND.buffered( fun MutableBuffer<T>.collect(): Structure2D<T> = StructureND.buffered(
DefaultStrides(intArrayOf(rowNum, colNum)), DefaultStrides(intArrayOf(rowNum, colNum)),
factory factory
) { (i, j) -> ) { (i, j) ->
get(i, j) get(i, j)
}.as2D() }.as2D()
public inner class Row(public val buffer: MutableBuffer<T>, public val rowIndex: Int) : MutableBuffer<T> { inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> {
override val size: Int get() = colNum override val size: Int get() = colNum
override operator fun get(index: Int): T = buffer[rowIndex, index] override operator fun get(index: Int): T = buffer[rowIndex, index]
@ -54,5 +54,5 @@ internal class BufferAccessor2D<T : Any>(
/** /**
* Get row * Get row
*/ */
public fun MutableBuffer<T>.row(i: Int): Row = Row(this, i) fun MutableBuffer<T>.row(i: Int): Row = Row(this, i)
} }

View File

@ -67,9 +67,9 @@ public inline fun <T : Any, reified R : Any> Buffer<T>.map(block: (T) -> R): Buf
* Create a new buffer from this one with the given mapping function. * Create a new buffer from this one with the given mapping function.
* Provided [bufferFactory] is used to construct the new buffer. * Provided [bufferFactory] is used to construct the new buffer.
*/ */
public fun <T : Any, R : Any> Buffer<T>.map( public inline fun <T : Any, R : Any> Buffer<T>.map(
bufferFactory: BufferFactory<R>, bufferFactory: BufferFactory<R>,
block: (T) -> R, crossinline block: (T) -> R,
): Buffer<R> = bufferFactory(size) { block(get(it)) } ): Buffer<R> = bufferFactory(size) { block(get(it)) }
/** /**

View File

@ -1,5 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/

View File

@ -10,7 +10,7 @@ import space.kscience.kmath.operations.invoke
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertNotEquals import kotlin.test.assertNotEquals
internal class FieldVerifier<T, A : Field<T>>( internal class FieldVerifier<T, out A : Field<T>>(
algebra: A, a: T, b: T, c: T, x: Number, algebra: A, a: T, b: T, c: T, x: Number,
) : RingVerifier<T, A>(algebra, a, b, c, x) { ) : RingVerifier<T, A>(algebra, a, b, c, x) {

View File

@ -10,7 +10,7 @@ import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import kotlin.test.assertEquals import kotlin.test.assertEquals
internal open class RingVerifier<T, A>(algebra: A, a: T, b: T, c: T, x: Number) : internal open class RingVerifier<T, out A>(algebra: A, a: T, b: T, c: T, x: Number) :
SpaceVerifier<T, A>(algebra, a, b, c, x) where A : Ring<T>, A : ScaleOperations<T> { SpaceVerifier<T, A>(algebra, a, b, c, x) where A : Ring<T>, A : ScaleOperations<T> {
override fun verify() { override fun verify() {

View File

@ -29,4 +29,4 @@ public fun BlockingDoubleChain.map(transform: (Double) -> Double): BlockingDoubl
} }
override suspend fun fork(): BlockingDoubleChain = this@map.fork().map(transform) override suspend fun fork(): BlockingDoubleChain = this@map.fork().map(transform)
} }

View File

@ -15,30 +15,33 @@ public val Dispatchers.Math: CoroutineDispatcher
/** /**
* An imitator of [Deferred] which holds a suspended function block and dispatcher * An imitator of [Deferred] which holds a suspended function block and dispatcher
*/ */
internal class LazyDeferred<T>(val dispatcher: CoroutineDispatcher, val block: suspend CoroutineScope.() -> T) { @PublishedApi
internal class LazyDeferred<out T>(val dispatcher: CoroutineDispatcher, val block: suspend CoroutineScope.() -> T) {
private var deferred: Deferred<T>? = null private var deferred: Deferred<T>? = null
internal fun start(scope: CoroutineScope) { fun start(scope: CoroutineScope) {
if (deferred == null) deferred = scope.async(dispatcher, block = block) if (deferred == null) deferred = scope.async(dispatcher, block = block)
} }
suspend fun await(): T = deferred?.await() ?: error("Coroutine not started") suspend fun await(): T = deferred?.await() ?: error("Coroutine not started")
} }
public class AsyncFlow<T> internal constructor(internal val deferredFlow: Flow<LazyDeferred<T>>) : Flow<T> { public class AsyncFlow<out T> @PublishedApi internal constructor(
@PublishedApi internal val deferredFlow: Flow<LazyDeferred<T>>,
) : Flow<T> {
override suspend fun collect(collector: FlowCollector<T>): Unit = override suspend fun collect(collector: FlowCollector<T>): Unit =
deferredFlow.collect { collector.emit((it.await())) } deferredFlow.collect { collector.emit((it.await())) }
} }
public fun <T, R> Flow<T>.async( public inline fun <T, R> Flow<T>.async(
dispatcher: CoroutineDispatcher = Dispatchers.Default, dispatcher: CoroutineDispatcher = Dispatchers.Default,
block: suspend CoroutineScope.(T) -> R, crossinline block: suspend CoroutineScope.(T) -> R,
): AsyncFlow<R> { ): AsyncFlow<R> {
val flow = map { LazyDeferred(dispatcher) { block(it) } } val flow = map { LazyDeferred(dispatcher) { block(it) } }
return AsyncFlow(flow) return AsyncFlow(flow)
} }
public fun <T, R> AsyncFlow<T>.map(action: (T) -> R): AsyncFlow<R> = public inline fun <T, R> AsyncFlow<T>.map(crossinline action: (T) -> R): AsyncFlow<R> =
AsyncFlow(deferredFlow.map { input -> AsyncFlow(deferredFlow.map { input ->
//TODO add function composition //TODO add function composition
LazyDeferred(input.dispatcher) { LazyDeferred(input.dispatcher) {

View File

@ -17,7 +17,7 @@ import kotlin.jvm.JvmInline
/** /**
* A matrix with compile-time controlled dimension * A matrix with compile-time controlled dimension
*/ */
public interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> { public interface DMatrix<out T, R : Dimension, C : Dimension> : Structure2D<T> {
public companion object { public companion object {
/** /**
* Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed * Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed
@ -46,7 +46,7 @@ public interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
* An inline wrapper for a Matrix * An inline wrapper for a Matrix
*/ */
@JvmInline @JvmInline
public value class DMatrixWrapper<T, R : Dimension, C : Dimension>( public value class DMatrixWrapper<out T, R : Dimension, C : Dimension>(
private val structure: Structure2D<T>, private val structure: Structure2D<T>,
) : DMatrix<T, R, C> { ) : DMatrix<T, R, C> {
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
@ -58,7 +58,7 @@ public value class DMatrixWrapper<T, R : Dimension, C : Dimension>(
/** /**
* Dimension-safe point * Dimension-safe point
*/ */
public interface DPoint<T, D : Dimension> : Point<T> { public interface DPoint<out T, D : Dimension> : Point<T> {
public companion object { public companion object {
public inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> { public inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> {
require(point.size == Dimension.dim<D>().toInt()) { require(point.size == Dimension.dim<D>().toInt()) {
@ -76,7 +76,7 @@ public interface DPoint<T, D : Dimension> : Point<T> {
* Dimension-safe point wrapper * Dimension-safe point wrapper
*/ */
@JvmInline @JvmInline
public value class DPointWrapper<T, D : Dimension>(public val point: Point<T>) : public value class DPointWrapper<out T, D : Dimension>(public val point: Point<T>) :
DPoint<T, D> { DPoint<T, D> {
override val size: Int get() = point.size override val size: Int get() = point.size
@ -111,7 +111,7 @@ public value class DMatrixContext<T : Any, out A : Ring<T>>(public val context:
): DMatrix<T, R, C> { ): DMatrix<T, R, C> {
val rows = Dimension.dim<R>() val rows = Dimension.dim<R>()
val cols = Dimension.dim<C>() val cols = Dimension.dim<C>()
return context.buildMatrix(rows.toInt(), cols.toInt(), initializer).coerce<R, C>() return context.buildMatrix(rows.toInt(), cols.toInt(), initializer).coerce()
} }
public inline fun <reified D : Dimension> point(noinline initializer: A.(Int) -> T): DPoint<T, D> { public inline fun <reified D : Dimension> point(noinline initializer: A.(Int) -> T): DPoint<T, D> {

View File

@ -184,4 +184,4 @@ public fun tan(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.tan(it) }
public fun ln(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.ln(it) } public fun ln(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.ln(it) }
public fun log10(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.log10(it) } public fun log10(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.log10(it) }

View File

@ -32,7 +32,7 @@ internal class DoubleMatrixTest {
@Test @Test
fun testSequenceToMatrix() { fun testSequenceToMatrix() {
val m = Sequence<DoubleArray> { val m = Sequence {
listOf( listOf(
DoubleArray(10) { 10.0 }, DoubleArray(10) { 10.0 },
DoubleArray(10) { 20.0 }, DoubleArray(10) { 20.0 },

View File

@ -14,7 +14,7 @@ import space.kscience.kmath.operations.Ring
* @param T the piece key type. * @param T the piece key type.
* @param R the sub-function type. * @param R the sub-function type.
*/ */
public fun interface Piecewise<T, R> { public fun interface Piecewise<in T, out R> {
/** /**
* Returns the appropriate sub-function for given piece key. * Returns the appropriate sub-function for given piece key.
*/ */
@ -23,7 +23,9 @@ public fun interface Piecewise<T, R> {
/** /**
* Represents piecewise-defined function where all the sub-functions are polynomials. * Represents piecewise-defined function where all the sub-functions are polynomials.
* @param pieces An ordered list of range-polynomial pairs. The list does not in general guarantee that there are no "holes" in it. *
* @property pieces An ordered list of range-polynomial pairs. The list does not in general guarantee that there are no
* "holes" in it.
*/ */
public interface PiecewisePolynomial<T : Comparable<T>> : Piecewise<T, Polynomial<T>> { public interface PiecewisePolynomial<T : Comparable<T>> : Piecewise<T, Polynomial<T>> {
public val pieces: Collection<Pair<ClosedRange<T>, Polynomial<T>>> public val pieces: Collection<Pair<ClosedRange<T>, Polynomial<T>>>
@ -45,7 +47,7 @@ public fun <T : Comparable<T>> PiecewisePolynomial(
/** /**
* An optimized piecewise which uses not separate pieces, but a range separated by delimiters. * An optimized piecewise which uses not separate pieces, but a range separated by delimiters.
* The pices search is logarithmic * The pieces search is logarithmic
*/ */
private class OrderedPiecewisePolynomial<T : Comparable<T>>( private class OrderedPiecewisePolynomial<T : Comparable<T>>(
override val pieces: List<Pair<ClosedRange<T>, Polynomial<T>>>, override val pieces: List<Pair<ClosedRange<T>, Polynomial<T>>>,

View File

@ -17,7 +17,7 @@ import kotlin.math.pow
* *
* @param coefficients constant is the leftmost coefficient. * @param coefficients constant is the leftmost coefficient.
*/ */
public class Polynomial<T>(public val coefficients: List<T>) { public class Polynomial<out T>(public val coefficients: List<T>) {
override fun toString(): String = "Polynomial$coefficients" override fun toString(): String = "Polynomial$coefficients"
} }
@ -69,7 +69,7 @@ public fun <T, A> Polynomial<T>.differentiate(
public fun <T, A> Polynomial<T>.integrate( public fun <T, A> Polynomial<T>.integrate(
algebra: A, algebra: A,
): Polynomial<T> where A : Field<T>, A : NumericAlgebra<T> = algebra { ): Polynomial<T> where A : Field<T>, A : NumericAlgebra<T> = algebra {
val integratedCoefficients = buildList<T>(coefficients.size + 1) { val integratedCoefficients = buildList(coefficients.size + 1) {
add(zero) add(zero)
coefficients.forEachIndexed{ index, t -> add(t / (number(index) + one)) } coefficients.forEachIndexed{ index, t -> add(t / (number(index) + one)) }
} }

View File

@ -18,7 +18,7 @@ public interface Integrand {
public inline fun <reified T : IntegrandFeature> Integrand.getFeature(): T? = getFeature(T::class) public inline fun <reified T : IntegrandFeature> Integrand.getFeature(): T? = getFeature(T::class)
public class IntegrandValue<T : Any>(public val value: T) : IntegrandFeature { public class IntegrandValue<out T : Any>(public val value: T) : IntegrandFeature {
override fun toString(): String = "Value($value)" override fun toString(): String = "Value($value)"
} }

View File

@ -93,7 +93,7 @@ public fun <T : Any> UnivariateIntegrator<T>.integrate(
): UnivariateIntegrand<T> = integrate(UnivariateIntegrand(function, IntegrationRange(range), *features)) ): UnivariateIntegrand<T> = integrate(UnivariateIntegrand(function, IntegrationRange(range), *features))
/** /**
* A shortcut method to integrate a [function] in [range] with additional [features]. * A shortcut method to integrate a [function] in [range] with additional features.
* The [function] is placed in the end position to allow passing a lambda. * The [function] is placed in the end position to allow passing a lambda.
*/ */
@UnstableKMathAPI @UnstableKMathAPI

View File

@ -18,7 +18,7 @@ import space.kscience.kmath.structures.asBuffer
/** /**
* And interpolator for data with x column type [X], y column type [Y]. * And interpolator for data with x column type [X], y column type [Y].
*/ */
public fun interface Interpolator<T, X : T, Y : T> { public fun interface Interpolator<T, in X : T, Y : T> {
public fun interpolate(points: XYColumnarData<T, X, Y>): (X) -> Y public fun interpolate(points: XYColumnarData<T, X, Y>): (X) -> Y
} }

View File

@ -5,7 +5,7 @@
package space.kscience.kmath.geometry package space.kscience.kmath.geometry
public data class Line<V : Vector>(val base: V, val direction: V) public data class Line<out V : Vector>(val base: V, val direction: V)
public typealias Line2D = Line<Vector2D> public typealias Line2D = Line<Vector2D>
public typealias Line3D = Line<Vector3D> public typealias Line3D = Line<Vector3D>

View File

@ -13,14 +13,14 @@ import space.kscience.kmath.structures.asBuffer
/** /**
* The binned data element. Could be a histogram bin with a number of counts or an artificial construct * The binned data element. Could be a histogram bin with a number of counts or an artificial construct
*/ */
public interface Bin<T : Any> : Domain<T> { public interface Bin<in T : Any> : Domain<T> {
/** /**
* The value of this bin. * The value of this bin.
*/ */
public val value: Number public val value: Number
} }
public interface Histogram<T : Any, out B : Bin<T>> { public interface Histogram<in T : Any, out B : Bin<T>> {
/** /**
* Find existing bin, corresponding to given coordinates * Find existing bin, corresponding to given coordinates
*/ */
@ -34,7 +34,7 @@ public interface Histogram<T : Any, out B : Bin<T>> {
public val bins: Iterable<B> public val bins: Iterable<B>
} }
public fun interface HistogramBuilder<T : Any> { public fun interface HistogramBuilder<in T : Any> {
/** /**
* Increment appropriate bin * Increment appropriate bin

View File

@ -18,7 +18,7 @@ import space.kscience.kmath.operations.invoke
/** /**
* A simple histogram bin based on domain * A simple histogram bin based on domain
*/ */
public data class DomainBin<T : Comparable<T>>( public data class DomainBin<in T : Comparable<T>>(
public val domain: Domain<T>, public val domain: Domain<T>,
public override val value: Number, public override val value: Number,
) : Bin<T>, Domain<T> by domain ) : Bin<T>, Domain<T> by domain

View File

@ -36,9 +36,10 @@ public class TreeHistogram(
} }
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
private class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain) : UnivariateHistogramBuilder { @PublishedApi
internal class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain) : UnivariateHistogramBuilder {
private class BinCounter(val domain: UnivariateDomain, val counter: Counter<Double> = Counter.real()) : internal class BinCounter(val domain: UnivariateDomain, val counter: Counter<Double> = Counter.real()) :
ClosedFloatingPointRange<Double> by domain.range ClosedFloatingPointRange<Double> by domain.range
private val bins: TreeMap<Double, BinCounter> = TreeMap() private val bins: TreeMap<Double, BinCounter> = TreeMap()
@ -80,10 +81,10 @@ private class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain)
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public class TreeHistogramSpace( public class TreeHistogramSpace(
private val binFactory: (Double) -> UnivariateDomain, @PublishedApi internal val binFactory: (Double) -> UnivariateDomain,
) : Group<UnivariateHistogram>, ScaleOperations<UnivariateHistogram> { ) : Group<UnivariateHistogram>, ScaleOperations<UnivariateHistogram> {
public fun fill(block: UnivariateHistogramBuilder.() -> Unit): UnivariateHistogram = public inline fun fill(block: UnivariateHistogramBuilder.() -> Unit): UnivariateHistogram =
TreeHistogramBuilder(binFactory).apply(block).build() TreeHistogramBuilder(binFactory).apply(block).build()
override fun add( override fun add(
@ -115,8 +116,8 @@ public class TreeHistogramSpace(
bin.domain.center, bin.domain.center,
UnivariateBin( UnivariateBin(
bin.domain, bin.domain,
value = bin.value * value.toDouble(), value = bin.value * value,
standardDeviation = abs(bin.standardDeviation * value.toDouble()) standardDeviation = abs(bin.standardDeviation * value)
) )
) )
} }

View File

@ -18,14 +18,14 @@ readme {
feature( feature(
"differentiable-mst-expression", "differentiable-mst-expression",
"src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt", "src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt",
) { ) {
"MST based DifferentiableExpression." "MST based DifferentiableExpression."
} }
feature( feature(
"differentiable-mst-expression", "scalars-adapters",
"src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt", "src/main/kotlin/space/kscience/kmath/kotlingrad/scalarsAdapters.kt",
) { ) {
"Conversions between Kotlin∇'s SFun and MST" "Conversions between Kotlin∇'s SFun and MST"
} }

View File

@ -25,7 +25,7 @@ private class Nd4jArrayIndicesIterator(private val iterateOver: INDArray) : Iter
internal fun INDArray.indicesIterator(): Iterator<IntArray> = Nd4jArrayIndicesIterator(this) internal fun INDArray.indicesIterator(): Iterator<IntArray> = Nd4jArrayIndicesIterator(this)
private sealed class Nd4jArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> { private sealed class Nd4jArrayIteratorBase<out T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
private var i: Int = 0 private var i: Int = 0
final override fun hasNext(): Boolean = i < iterateOver.length() final override fun hasNext(): Boolean = i < iterateOver.length()

View File

@ -9,7 +9,7 @@ import space.kscience.kmath.expressions.Symbol
public interface OptimizationFeature public interface OptimizationFeature
public class OptimizationResult<T>( public class OptimizationResult<out T>(
public val point: Map<Symbol, T>, public val point: Map<Symbol, T>,
public val value: T, public val value: T,
public val features: Set<OptimizationFeature> = emptySet(), public val features: Set<OptimizationFeature> = emptySet(),

View File

@ -1,18 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingBufferChain
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler
import space.kscience.kmath.structures.Buffer
public class ConstantSampler<T : Any>(public val const: T) : Sampler<T> {
override fun sample(generator: RandomGenerator): BlockingBufferChain<T> = object : BlockingBufferChain<T> {
override fun nextBufferBlocking(size: Int): Buffer<T> = Buffer.boxing(size) { const }
override suspend fun fork(): BlockingBufferChain<T> = this
}
}

View File

@ -6,6 +6,8 @@
package space.kscience.kmath.stat package space.kscience.kmath.stat
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.coroutines.CoroutineContext import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.coroutineContext import kotlin.coroutines.coroutineContext
@ -23,14 +25,18 @@ public class MCScope(
/** /**
* Launches a supervised Monte-Carlo scope * Launches a supervised Monte-Carlo scope
*/ */
public suspend inline fun <T> mcScope(generator: RandomGenerator, block: MCScope.() -> T): T = public suspend inline fun <T> mcScope(generator: RandomGenerator, block: MCScope.() -> T): T {
MCScope(coroutineContext, generator).block() contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MCScope(coroutineContext, generator).block()
}
/** /**
* Launch mc scope with a given seed * Launch mc scope with a given seed
*/ */
public suspend inline fun <T> mcScope(seed: Long, block: MCScope.() -> T): T = public suspend inline fun <T> mcScope(seed: Long, block: MCScope.() -> T): T {
mcScope(RandomGenerator.default(seed), block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return mcScope(RandomGenerator.default(seed), block)
}
/** /**
* Specialized launch for [MCScope]. Behaves the same way as regular [CoroutineScope.launch], but also stores the generator fork. * Specialized launch for [MCScope]. Behaves the same way as regular [CoroutineScope.launch], but also stores the generator fork.

View File

@ -12,7 +12,7 @@ import space.kscience.kmath.structures.*
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
/** /**
* Sampler that generates chains of values of type [T] in a chain of type [C]. * Sampler that generates chains of values of type [T].
*/ */
public fun interface Sampler<out T : Any> { public fun interface Sampler<out T : Any> {
/** /**

View File

@ -18,7 +18,7 @@ import space.kscience.kmath.operations.invoke
* *
* @property value the value to sample. * @property value the value to sample.
*/ */
public class ConstantSampler<T : Any>(public val value: T) : Sampler<T> { public class ConstantSampler<out T : Any>(public val value: T) : Sampler<T> {
public override fun sample(generator: RandomGenerator): Chain<T> = ConstantChain(value) public override fun sample(generator: RandomGenerator): Chain<T> = ConstantChain(value)
} }
@ -27,7 +27,7 @@ public class ConstantSampler<T : Any>(public val value: T) : Sampler<T> {
* *
* @property chainBuilder the provider of [Chain]. * @property chainBuilder the provider of [Chain].
*/ */
public class BasicSampler<T : Any>(public val chainBuilder: (RandomGenerator) -> Chain<T>) : Sampler<T> { public class BasicSampler<out T : Any>(public val chainBuilder: (RandomGenerator) -> Chain<T>) : Sampler<T> {
public override fun sample(generator: RandomGenerator): Chain<T> = chainBuilder(generator) public override fun sample(generator: RandomGenerator): Chain<T> = chainBuilder(generator)
} }
@ -36,7 +36,7 @@ public class BasicSampler<T : Any>(public val chainBuilder: (RandomGenerator) ->
* *
* @property algebra the space to provide addition and scalar multiplication for [T]. * @property algebra the space to provide addition and scalar multiplication for [T].
*/ */
public class SamplerSpace<T : Any, S>(public val algebra: S) : Group<Sampler<T>>, public class SamplerSpace<T : Any, out S>(public val algebra: S) : Group<Sampler<T>>,
ScaleOperations<Sampler<T>> where S : Group<T>, S : ScaleOperations<T> { ScaleOperations<Sampler<T>> where S : Group<T>, S : ScaleOperations<T> {
public override val zero: Sampler<T> = ConstantSampler(algebra.zero) public override val zero: Sampler<T> = ConstantSampler(algebra.zero)

View File

@ -18,14 +18,14 @@ import space.kscience.kmath.structures.Buffer
/** /**
* A function, that transforms a buffer of random quantities to some resulting value * A function, that transforms a buffer of random quantities to some resulting value
*/ */
public interface Statistic<T, R> { public interface Statistic<in T, out R> {
public suspend fun evaluate(data: Buffer<T>): R public suspend fun evaluate(data: Buffer<T>): R
} }
public interface BlockingStatistic<T,R>: Statistic<T,R>{ public interface BlockingStatistic<in T, out R> : Statistic<T, R> {
public fun evaluateBlocking(data: Buffer<T>): R public fun evaluateBlocking(data: Buffer<T>): R
override suspend fun evaluate(data: Buffer<T>): R = evaluateBlocking(data) override suspend fun evaluate(data: Buffer<T>): R = evaluateBlocking(data)
} }
/** /**
@ -34,7 +34,7 @@ public interface BlockingStatistic<T,R>: Statistic<T,R>{
* @param I - intermediate block type * @param I - intermediate block type
* @param R - result type * @param R - result type
*/ */
public interface ComposableStatistic<T, I, R> : Statistic<T, R> { public interface ComposableStatistic<in T, I, out R> : Statistic<T, R> {
//compute statistic on a single block //compute statistic on a single block
public suspend fun computeIntermediate(data: Buffer<T>): I public suspend fun computeIntermediate(data: Buffer<T>): I

View File

@ -14,7 +14,8 @@ import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.operations.NumericAlgebra import space.kscience.kmath.operations.NumericAlgebra
/** /**
* Represents [MST] based [DifferentiableExpression] relying on [Symja](https://github.com/axkr/symja_android_library). * Represents [MST] based [space.kscience.kmath.expressions.DifferentiableExpression] relying on
* [Symja](https://github.com/axkr/symja_android_library).
* *
* The principle of this API is converting the [mst] to an [org.matheclipse.core.interfaces.IExpr], differentiating it * The principle of this API is converting the [mst] to an [org.matheclipse.core.interfaces.IExpr], differentiating it
* with Symja's [F.D], then converting [org.matheclipse.core.interfaces.IExpr] back to [MST]. * with Symja's [F.D], then converting [org.matheclipse.core.interfaces.IExpr] back to [MST].

View File

@ -3,7 +3,7 @@
Common linear algebra operations on tensors. Common linear algebra operations on tensors.
- [tensor algebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt) : Basic linear algebra operations on tensors (plus, dot, etc.) - [tensor algebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt) : Basic linear algebra operations on tensors (plus, dot, etc.)
- [tensor algebra with broadcasting](src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting. - [tensor algebra with broadcasting](src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting.
- [linear algebra operations](src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Advanced linear algebra operations like LU decomposition, SVD, etc. - [linear algebra operations](src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Advanced linear algebra operations like LU decomposition, SVD, etc.

View File

@ -24,20 +24,16 @@ readme {
feature( feature(
id = "tensor algebra", id = "tensor algebra",
description = "Basic linear algebra operations on tensors (plus, dot, etc.)",
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt" ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt"
) ) { "Basic linear algebra operations on tensors (plus, dot, etc.)" }
feature( feature(
id = "tensor algebra with broadcasting", id = "tensor algebra with broadcasting",
description = "Basic linear algebra operations implemented with broadcasting.", ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt"
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt" ) { "Basic linear algebra operations implemented with broadcasting." }
)
feature( feature(
id = "linear algebra operations", id = "linear algebra operations",
description = "Advanced linear algebra operations like LU decomposition, SVD, etc.",
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt" ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt"
) ) { "Advanced linear algebra operations like LU decomposition, SVD, etc." }
}
}

View File

@ -11,8 +11,8 @@ import space.kscience.kmath.tensors.core.internal.TensorLinearStructure
*/ */
public open class BufferedTensor<T> internal constructor( public open class BufferedTensor<T> internal constructor(
override val shape: IntArray, override val shape: IntArray,
internal val mutableBuffer: MutableBuffer<T>, @PublishedApi internal val mutableBuffer: MutableBuffer<T>,
internal val bufferStart: Int @PublishedApi internal val bufferStart: Int,
) : Tensor<T> { ) : Tensor<T> {
/** /**

View File

@ -11,7 +11,7 @@ import space.kscience.kmath.tensors.core.internal.toPrettyString
/** /**
* Default [BufferedTensor] implementation for [Double] values * Default [BufferedTensor] implementation for [Double] values
*/ */
public class DoubleTensor internal constructor( public class DoubleTensor @PublishedApi internal constructor(
shape: IntArray, shape: IntArray,
buffer: DoubleArray, buffer: DoubleArray,
offset: Int = 0 offset: Int = 0

View File

@ -426,13 +426,11 @@ public open class DoubleTensorAlgebra :
* @param transform the function to be applied to each element of the tensor. * @param transform the function to be applied to each element of the tensor.
* @return the resulting tensor after applying the function. * @return the resulting tensor after applying the function.
*/ */
public fun Tensor<Double>.map(transform: (Double) -> Double): DoubleTensor { public inline fun Tensor<Double>.map(transform: (Double) -> Double): DoubleTensor = DoubleTensor(
return DoubleTensor( tensor.shape,
tensor.shape, tensor.mutableBuffer.array().map { transform(it) }.toDoubleArray(),
tensor.mutableBuffer.array().map { transform(it) }.toDoubleArray(), tensor.bufferStart
tensor.bufferStart )
)
}
/** /**
* Compares element-wise two tensors with a specified precision. * Compares element-wise two tensors with a specified precision.
@ -519,14 +517,12 @@ public open class DoubleTensorAlgebra :
* @param indices the [IntArray] of 1-dimensional indices * @param indices the [IntArray] of 1-dimensional indices
* @return tensor with rows corresponding to rows by [indices] * @return tensor with rows corresponding to rows by [indices]
*/ */
public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor { public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] })
return stack(indices.map { this[it] })
}
internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double = internal inline fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
foldFunction(tensor.toDoubleArray()) foldFunction(tensor.toDoubleArray())
internal fun Tensor<Double>.foldDim( internal inline fun Tensor<Double>.foldDim(
foldFunction: (DoubleArray) -> Double, foldFunction: (DoubleArray) -> Double,
dim: Int, dim: Int,
keepDim: Boolean, keepDim: Boolean,

View File

@ -10,15 +10,15 @@ import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.* import space.kscience.kmath.tensors.core.BufferedTensor
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra.Companion.valueOrNull import space.kscience.kmath.tensors.core.IntTensor
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.min import kotlin.math.min
import kotlin.math.sign import kotlin.math.sign
import kotlin.math.sqrt import kotlin.math.sqrt
internal fun <T> BufferedTensor<T>.vectorSequence(): Sequence<BufferedTensor<T>> = sequence { internal fun <T> BufferedTensor<T>.vectorSequence(): Sequence<BufferedTensor<T>> = sequence {
val n = shape.size val n = shape.size
val vectorOffset = shape[n - 1] val vectorOffset = shape[n - 1]

View File

@ -31,6 +31,7 @@ internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
else -> this.copyToBufferedTensor() else -> this.copyToBufferedTensor()
} }
@PublishedApi
internal val Tensor<Double>.tensor: DoubleTensor internal val Tensor<Double>.tensor: DoubleTensor
get() = when (this) { get() = when (this) {
is DoubleTensor -> this is DoubleTensor -> this

View File

@ -24,6 +24,7 @@ internal fun Buffer<Int>.array(): IntArray = when (this) {
/** /**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer] or copy the data. * Returns a reference to [DoubleArray] containing all of the elements of this [Buffer] or copy the data.
*/ */
@PublishedApi
internal fun Buffer<Double>.array(): DoubleArray = when (this) { internal fun Buffer<Double>.array(): DoubleArray = when (this) {
is DoubleBuffer -> array is DoubleBuffer -> array
else -> this.toDoubleArray() else -> this.toDoubleArray()

View File

@ -183,7 +183,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
} }
private fun DoubleTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10): Unit { private fun DoubleTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10) {
val svd = tensor.svd() val svd = tensor.svd()
val tensorSVD = svd.first val tensorSVD = svd.first

View File

@ -17,7 +17,7 @@ internal class TestDoubleTensorAlgebra {
} }
@Test @Test
fun TestDoubleDiv() = DoubleTensorAlgebra { fun testDoubleDiv() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(2), doubleArrayOf(2.0, 4.0)) val tensor = fromArray(intArrayOf(2), doubleArrayOf(2.0, 4.0))
val res = 2.0/tensor val res = 2.0/tensor
assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 0.5)) assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 0.5))

View File

@ -92,7 +92,7 @@ public class ViktorFieldND(public override val shape: IntArray) : FieldND<Double
(a.f64Buffer + b.f64Buffer).asStructure() (a.f64Buffer + b.f64Buffer).asStructure()
public override fun scale(a: StructureND<Double>, value: Double): ViktorStructureND = public override fun scale(a: StructureND<Double>, value: Double): ViktorStructureND =
(a.f64Buffer * value.toDouble()).asStructure() (a.f64Buffer * value).asStructure()
public override inline fun StructureND<Double>.plus(b: StructureND<Double>): ViktorStructureND = public override inline fun StructureND<Double>.plus(b: StructureND<Double>): ViktorStructureND =
(f64Buffer + b.f64Buffer).asStructure() (f64Buffer + b.f64Buffer).asStructure()