This commit is contained in:
Alexander Nozik 2023-11-03 09:56:19 +03:00
parent ea887b8c72
commit 46eacbb750
44 changed files with 451 additions and 269 deletions

View File

@ -3,10 +3,11 @@
## Unreleased ## Unreleased
### Added ### Added
- Integer division algebras - Explicit `SafeType` for algebras and buffers.
- Float32 geometries - Integer division algebras.
- New Attributes-kt module that could be used as stand-alone. It declares type-safe attributes containers. - Float32 geometries.
- Explicit `mutableStructureND` builders for mutable structures - New Attributes-kt module that could be used as stand-alone. It declares. type-safe attributes containers.
- Explicit `mutableStructureND` builders for mutable structures.
### Changed ### Changed
- Default naming for algebra and buffers now uses IntXX/FloatXX notation instead of Java types. - Default naming for algebra and buffers now uses IntXX/FloatXX notation instead of Java types.

View File

@ -16,7 +16,7 @@ import kotlin.reflect.typeOf
* @param kType raw [KType] * @param kType raw [KType]
*/ */
@JvmInline @JvmInline
public value class SafeType<T> @PublishedApi internal constructor(public val kType: KType) public value class SafeType<out T> @PublishedApi internal constructor(public val kType: KType)
public inline fun <reified T> safeTypeOf(): SafeType<T> = SafeType(typeOf<T>()) public inline fun <reified T> safeTypeOf(): SafeType<T> = SafeType(typeOf<T>())

View File

@ -52,10 +52,10 @@ public interface XYColumnarData<out T, out X : T, out Y : T> : ColumnarData<T> {
* Create two-column data from a list of row-objects (zero-copy) * Create two-column data from a list of row-objects (zero-copy)
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun <I, T, X : T, Y : T> ofList( public inline fun <I, T, reified X : T, reified Y : T> ofList(
list: List<I>, list: List<I>,
xConverter: (I) -> X, noinline xConverter: (I) -> X,
yConverter: (I) -> Y, noinline yConverter: (I) -> Y,
): XYColumnarData<T, X, Y> = object : XYColumnarData<T, X, Y> { ): XYColumnarData<T, X, Y> = object : XYColumnarData<T, X, Y> {
override val size: Int get() = list.size override val size: Int get() = list.size

View File

@ -13,8 +13,6 @@ import space.kscience.kmath.structures.MutableBufferFactory
import space.kscience.kmath.structures.asBuffer import space.kscience.kmath.structures.asBuffer
import kotlin.math.max import kotlin.math.max
import kotlin.math.min import kotlin.math.min
import kotlin.reflect.KType
import kotlin.reflect.typeOf
/** /**
* Class representing both the value and the differentials of a function. * Class representing both the value and the differentials of a function.
@ -84,6 +82,8 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
bindings: Map<Symbol, T>, bindings: Map<Symbol, T>,
) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer { ) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer {
override val bufferFactory: MutableBufferFactory<DS<T, A>> = MutableBufferFactory()
/** /**
* Get the compiler for number of free parameters and order. * Get the compiler for number of free parameters and order.
* *

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.expressions package space.kscience.kmath.expressions
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.MutableBufferFactory
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
@ -17,6 +18,8 @@ import kotlin.contracts.contract
public abstract class FunctionalExpressionAlgebra<T, out A : Algebra<T>>( public abstract class FunctionalExpressionAlgebra<T, out A : Algebra<T>>(
public val algebra: A, public val algebra: A,
) : ExpressionAlgebra<T, Expression<T>> { ) : ExpressionAlgebra<T, Expression<T>> {
override val bufferFactory: MutableBufferFactory<Expression<T>> = MutableBufferFactory<Expression<T>>()
/** /**
* Builds an Expression of constant expression that does not depend on arguments. * Builds an Expression of constant expression that does not depend on arguments.
*/ */
@ -49,6 +52,7 @@ public abstract class FunctionalExpressionAlgebra<T, out A : Algebra<T>>(
public open class FunctionalExpressionGroup<T, out A : Group<T>>( public open class FunctionalExpressionGroup<T, out A : Group<T>>(
algebra: A, algebra: A,
) : FunctionalExpressionAlgebra<T, A>(algebra), Group<Expression<T>> { ) : FunctionalExpressionAlgebra<T, A>(algebra), Group<Expression<T>> {
override val zero: Expression<T> get() = const(algebra.zero) override val zero: Expression<T> get() = const(algebra.zero)
override fun Expression<T>.unaryMinus(): Expression<T> = override fun Expression<T>.unaryMinus(): Expression<T> =

View File

@ -7,11 +7,14 @@ package space.kscience.kmath.expressions
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.MutableBufferFactory
/** /**
* [Algebra] over [MST] nodes. * [Algebra] over [MST] nodes.
*/ */
public object MstNumericAlgebra : NumericAlgebra<MST> { public object MstNumericAlgebra : NumericAlgebra<MST> {
override val bufferFactory: MutableBufferFactory<MST> = MutableBufferFactory()
override fun number(value: Number): MST.Numeric = MST.Numeric(value) override fun number(value: Number): MST.Numeric = MST.Numeric(value)
override fun bindSymbolOrNull(value: String): Symbol = StringSymbol(value) override fun bindSymbolOrNull(value: String): Symbol = StringSymbol(value)
override fun bindSymbol(value: String): Symbol = bindSymbolOrNull(value) override fun bindSymbol(value: String): Symbol = bindSymbolOrNull(value)
@ -27,6 +30,9 @@ public object MstNumericAlgebra : NumericAlgebra<MST> {
* [Group] over [MST] nodes. * [Group] over [MST] nodes.
*/ */
public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> { public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
override val bufferFactory: MutableBufferFactory<MST> = MutableBufferFactory()
override val zero: MST.Numeric = number(0.0) override val zero: MST.Numeric = number(0.0)
override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value) override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value)
@ -57,7 +63,11 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public object MstRing : Ring<MST>, NumbersAddOps<MST>, ScaleOperations<MST> { public object MstRing : Ring<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
override val bufferFactory: MutableBufferFactory<MST> = MutableBufferFactory()
override inline val zero: MST.Numeric get() = MstGroup.zero override inline val zero: MST.Numeric get() = MstGroup.zero
override val one: MST.Numeric = number(1.0) override val one: MST.Numeric = number(1.0)
override fun number(value: Number): MST.Numeric = MstGroup.number(value) override fun number(value: Number): MST.Numeric = MstGroup.number(value)
@ -87,7 +97,11 @@ public object MstRing : Ring<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public object MstField : Field<MST>, NumbersAddOps<MST>, ScaleOperations<MST> { public object MstField : Field<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
override val bufferFactory: MutableBufferFactory<MST> = MutableBufferFactory()
override inline val zero: MST.Numeric get() = MstRing.zero override inline val zero: MST.Numeric get() = MstRing.zero
override inline val one: MST.Numeric get() = MstRing.one override inline val one: MST.Numeric get() = MstRing.one
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
@ -117,7 +131,11 @@ public object MstField : Field<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
*/ */
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> { public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
override val bufferFactory: MutableBufferFactory<MST> = MutableBufferFactory()
override inline val zero: MST.Numeric get() = MstField.zero override inline val zero: MST.Numeric get() = MstField.zero
override inline val one: MST.Numeric get() = MstField.one override inline val one: MST.Numeric get() = MstField.one
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
@ -164,6 +182,9 @@ public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public object MstLogicAlgebra : LogicAlgebra<MST> { public object MstLogicAlgebra : LogicAlgebra<MST> {
override val bufferFactory: MutableBufferFactory<MST> = MutableBufferFactory()
override fun bindSymbolOrNull(value: String): MST = super.bindSymbolOrNull(value) ?: StringSymbol(value) override fun bindSymbolOrNull(value: String): MST = super.bindSymbolOrNull(value) ?: StringSymbol(value)
override fun const(boolean: Boolean): Symbol = if (boolean) { override fun const(boolean: Boolean): Symbol = if (boolean) {

View File

@ -5,9 +5,11 @@
package space.kscience.kmath.expressions package space.kscience.kmath.expressions
import space.kscience.attributes.SafeType
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.linear.Point import space.kscience.kmath.linear.Point
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.MutableBufferFactory
import space.kscience.kmath.structures.asBuffer import space.kscience.kmath.structures.asBuffer
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
@ -30,9 +32,10 @@ public open class AutoDiffValue<out T>(public val value: T)
*/ */
public class DerivationResult<T : Any>( public class DerivationResult<T : Any>(
public val value: T, public val value: T,
override val type: SafeType<T>,
private val derivativeValues: Map<String, T>, private val derivativeValues: Map<String, T>,
public val context: Field<T>, public val context: Field<T>,
) { ) : WithType<T> {
/** /**
* Returns derivative of [variable] or returns [Ring.zero] in [context]. * Returns derivative of [variable] or returns [Ring.zero] in [context].
*/ */
@ -49,19 +52,23 @@ public class DerivationResult<T : Any>(
*/ */
public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T> { public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T> {
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" } check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
return variables.map(::derivative).asBuffer() return variables.map(::derivative).asBuffer(type)
} }
/** /**
* Represents field in context of which functions can be derived. * Represents field. Function derivatives could be computed in this field
*/ */
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public open class SimpleAutoDiffField<T : Any, F : Field<T>>( public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
public val context: F, public val algebra: F,
bindings: Map<Symbol, T>, bindings: Map<Symbol, T>,
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, NumbersAddOps<AutoDiffValue<T>> { ) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, NumbersAddOps<AutoDiffValue<T>> {
override val zero: AutoDiffValue<T> get() = const(context.zero)
override val one: AutoDiffValue<T> get() = const(context.one) override val bufferFactory: MutableBufferFactory<AutoDiffValue<T>> = MutableBufferFactory<AutoDiffValue<T>>()
override val zero: AutoDiffValue<T> get() = const(algebra.zero)
override val one: AutoDiffValue<T> get() = const(algebra.one)
// this stack contains pairs of blocks and values to apply them to // this stack contains pairs of blocks and values to apply them to
private var stack: Array<Any?> = arrayOfNulls<Any?>(8) private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
@ -69,7 +76,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf() private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate { private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero) it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, algebra.zero)
} }
/** /**
@ -92,7 +99,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
override fun bindSymbolOrNull(value: String): AutoDiffValue<T>? = bindings[value] override fun bindSymbolOrNull(value: String): AutoDiffValue<T>? = bindings[value]
private fun getDerivative(variable: AutoDiffValue<T>): T = private fun getDerivative(variable: AutoDiffValue<T>): T =
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero (variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: algebra.zero
private fun setDerivative(variable: AutoDiffValue<T>, value: T) { private fun setDerivative(variable: AutoDiffValue<T>, value: T) {
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
@ -103,7 +110,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
while (sp > 0) { while (sp > 0) {
val value = stack[--sp] val value = stack[--sp]
val block = stack[--sp] as F.(Any?) -> Unit val block = stack[--sp] as F.(Any?) -> Unit
context.block(value) algebra.block(value)
} }
} }
@ -130,7 +137,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
* } * }
* ``` * ```
*/ */
@Suppress("UNCHECKED_CAST")
public fun <R> derive(value: R, block: F.(R) -> Unit): R { public fun <R> derive(value: R, block: F.(R) -> Unit): R {
// save block to stack for backward pass // save block to stack for backward pass
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2) if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
@ -142,9 +148,9 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
internal fun differentiate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> { internal fun differentiate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
val result = function() val result = function()
result.d = context.one // computing derivative w.r.t result result.d = algebra.one // computing derivative w.r.t result
runBackwardPass() runBackwardPass()
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context) return DerivationResult(result.value, algebra.type, bindings.mapValues { it.value.d }, algebra)
} }
// // Overloads for Double constants // // Overloads for Double constants
@ -194,7 +200,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
public inline fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.const(block: F.() -> T): AutoDiffValue<T> { public inline fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.const(block: F.() -> T): AutoDiffValue<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return const(context.block()) return const(algebra.block())
} }

View File

@ -10,10 +10,9 @@ import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.* import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.BufferRingOps import space.kscience.kmath.operations.BufferRingOps
import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.WithType
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import kotlin.reflect.KType
import kotlin.reflect.typeOf
/** /**
* Alias for [Structure2D] with more familiar name. * Alias for [Structure2D] with more familiar name.
@ -34,7 +33,7 @@ public typealias Point<T> = Buffer<T>
* A marker interface for algebras that operate on matrices * A marker interface for algebras that operate on matrices
* @param T type of matrix element * @param T type of matrix element
*/ */
public interface MatrixOperations<T> public interface MatrixOperations<T> : WithType<T>
/** /**
* Basic operations on matrices and vectors. * Basic operations on matrices and vectors.
@ -45,6 +44,8 @@ public interface MatrixOperations<T>
public interface LinearSpace<T, out A : Ring<T>> : MatrixOperations<T> { public interface LinearSpace<T, out A : Ring<T>> : MatrixOperations<T> {
public val elementAlgebra: A public val elementAlgebra: A
override val type: SafeType<T> get() = elementAlgebra.type
/** /**
* Produces a matrix with this context and given dimensions. * Produces a matrix with this context and given dimensions.
*/ */
@ -212,4 +213,4 @@ public fun <T : Any> Matrix<T>.asVector(): Point<T> =
* @receiver a buffer. * @receiver a buffer.
* @return the new matrix. * @return the new matrix.
*/ */
public fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) } public fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(type, size, 1) { i, _ -> get(i) }

View File

@ -22,11 +22,27 @@ import space.kscience.kmath.structures.*
* @param l The lower triangular matrix in this decomposition. It may have [LowerTriangular]. * @param l The lower triangular matrix in this decomposition. It may have [LowerTriangular].
* @param u The upper triangular matrix in this decomposition. It may have [UpperTriangular]. * @param u The upper triangular matrix in this decomposition. It may have [UpperTriangular].
*/ */
public data class LupDecomposition<T>( public class LupDecomposition<T>(
public val linearSpace: LinearSpace<T, Ring<T>>,
public val l: Matrix<T>, public val l: Matrix<T>,
public val u: Matrix<T>, public val u: Matrix<T>,
public val pivot: IntBuffer, public val pivot: IntBuffer,
) ) {
public val elementAlgebra: Ring<T> get() = linearSpace.elementAlgebra
public val pivotMatrix: VirtualMatrix<T>
get() = VirtualMatrix(linearSpace.type, l.rowNum, l.colNum) { row, column ->
if (column == pivot[row]) elementAlgebra.one else elementAlgebra.zero
}
public val <T> LupDecomposition<T>.determinant by lazy {
elementAlgebra { (0 until l.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] } }
}
}
public class LupDecompositionAttribute<T>(type: SafeType<LupDecomposition<T>>) : public class LupDecompositionAttribute<T>(type: SafeType<LupDecomposition<T>>) :
@ -168,7 +184,7 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
for (row in col + 1 until m) lu[row, col] /= luDiag for (row in col + 1 until m) lu[row, col] /= luDiag
} }
val l: MatrixWrapper<T> = VirtualMatrix(rowNum, colNum) { i, j -> val l: MatrixWrapper<T> = VirtualMatrix(type, rowNum, colNum) { i, j ->
when { when {
j < i -> lu[i, j] j < i -> lu[i, j]
j == i -> one j == i -> one
@ -176,7 +192,7 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
} }
}.withAttribute(LowerTriangular) }.withAttribute(LowerTriangular)
val u = VirtualMatrix(rowNum, colNum) { i, j -> val u = VirtualMatrix(type, rowNum, colNum) { i, j ->
if (j >= i) lu[i, j] else zero if (j >= i) lu[i, j] else zero
}.withAttribute(UpperTriangular) }.withAttribute(UpperTriangular)
// //
@ -184,7 +200,7 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
// if (j == pivot[i]) one else zero // if (j == pivot[i]) one else zero
// }.withAttribute(Determinant, if (even) one else -one) // }.withAttribute(Determinant, if (even) one else -one)
return LupDecomposition(l, u, pivot.asBuffer()) return LupDecomposition(this@lup, l, u, pivot.asBuffer())
} }
} }
} }

View File

@ -6,16 +6,21 @@
package space.kscience.kmath.linear package space.kscience.kmath.linear
import space.kscience.attributes.FlagAttribute import space.kscience.attributes.FlagAttribute
import space.kscience.attributes.SafeType
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.WithType
import space.kscience.kmath.structures.BufferAccessor2D import space.kscience.kmath.structures.BufferAccessor2D
import space.kscience.kmath.structures.MutableBuffer import space.kscience.kmath.structures.MutableBufferFactory
public class MatrixBuilder<T : Any, out A : Ring<T>>( public class MatrixBuilder<T : Any, out A : Ring<T>>(
public val linearSpace: LinearSpace<T, A>, public val linearSpace: LinearSpace<T, A>,
public val rows: Int, public val rows: Int,
public val columns: Int, public val columns: Int,
) { ) : WithType<T> {
override val type: SafeType<T> get() = linearSpace.type
public operator fun invoke(vararg elements: T): Matrix<T> { public operator fun invoke(vararg elements: T): Matrix<T> {
require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" } require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" }
return linearSpace.buildMatrix(rows, columns) { i, j -> elements[i * columns + j] } return linearSpace.buildMatrix(rows, columns) { i, j -> elements[i * columns + j] }
@ -61,7 +66,7 @@ public fun <T : Any, A : Ring<T>> MatrixBuilder<T, A>.symmetric(
builder: (i: Int, j: Int) -> T, builder: (i: Int, j: Int) -> T,
): Matrix<T> { ): Matrix<T> {
require(columns == rows) { "In order to build symmetric matrix, number of rows $rows should be equal to number of columns $columns" } require(columns == rows) { "In order to build symmetric matrix, number of rows $rows should be equal to number of columns $columns" }
return with(BufferAccessor2D<T?>(rows, rows, MutableBuffer.Companion::boxing)) { return with(BufferAccessor2D<T?>(rows, rows, MutableBufferFactory(type))) {
val cache = factory(rows * rows) { null } val cache = factory(rows * rows) { null }
linearSpace.buildMatrix(rows, rows) { i, j -> linearSpace.buildMatrix(rows, rows) { i, j ->
val cached = cache[i, j] val cached = cache[i, j]

View File

@ -39,7 +39,7 @@ public fun <T : Any, A : Attribute<T>> Matrix<T>.withAttribute(
attribute: A, attribute: A,
attrValue: T, attrValue: T,
): MatrixWrapper<T> = if (this is MatrixWrapper) { ): MatrixWrapper<T> = if (this is MatrixWrapper) {
MatrixWrapper(origin, attributes.withAttribute(attribute,attrValue)) MatrixWrapper(origin, attributes.withAttribute(attribute, attrValue))
} else { } else {
MatrixWrapper(this, Attributes(attribute, attrValue)) MatrixWrapper(this, Attributes(attribute, attrValue))
} }
@ -68,7 +68,7 @@ public fun <T : Any> Matrix<T>.modifyAttributes(modifier: (Attributes) -> Attrib
public fun <T : Any> LinearSpace<T, Ring<T>>.one( public fun <T : Any> LinearSpace<T, Ring<T>>.one(
rows: Int, rows: Int,
columns: Int, columns: Int,
): MatrixWrapper<T> = VirtualMatrix(rows, columns) { i, j -> ): MatrixWrapper<T> = VirtualMatrix(type, rows, columns) { i, j ->
if (i == j) elementAlgebra.one else elementAlgebra.zero if (i == j) elementAlgebra.one else elementAlgebra.zero
}.withAttribute(IsUnit) }.withAttribute(IsUnit)
@ -79,6 +79,6 @@ public fun <T : Any> LinearSpace<T, Ring<T>>.one(
public fun <T : Any> LinearSpace<T, Ring<T>>.zero( public fun <T : Any> LinearSpace<T, Ring<T>>.zero(
rows: Int, rows: Int,
columns: Int, columns: Int,
): MatrixWrapper<T> = VirtualMatrix(rows, columns) { _, _ -> ): MatrixWrapper<T> = VirtualMatrix(type, rows, columns) { _, _ ->
elementAlgebra.zero elementAlgebra.zero
}.withAttribute(IsZero) }.withAttribute(IsZero)

View File

@ -6,10 +6,14 @@
package space.kscience.kmath.linear package space.kscience.kmath.linear
import space.kscience.attributes.Attributes import space.kscience.attributes.Attributes
import space.kscience.attributes.SafeType
public class TransposedMatrix<T>(public val origin: Matrix<T>) : Matrix<T> { public class TransposedMatrix<T>(public val origin: Matrix<T>) : Matrix<T> {
override val type: SafeType<T> get() = origin.type
override val rowNum: Int get() = origin.colNum override val rowNum: Int get() = origin.colNum
override val colNum: Int get() = origin.rowNum override val colNum: Int get() = origin.rowNum
override fun get(i: Int, j: Int): T = origin[j, i] override fun get(i: Int, j: Int): T = origin[j, i]

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.linear package space.kscience.kmath.linear
import space.kscience.attributes.Attributes import space.kscience.attributes.Attributes
import space.kscience.attributes.SafeType
import space.kscience.kmath.nd.ShapeND import space.kscience.kmath.nd.ShapeND
@ -14,7 +15,8 @@ import space.kscience.kmath.nd.ShapeND
* *
* @property generator the function that provides elements. * @property generator the function that provides elements.
*/ */
public class VirtualMatrix<out T : Any>( public class VirtualMatrix<out T>(
override val type: SafeType<T>,
override val rowNum: Int, override val rowNum: Int,
override val colNum: Int, override val colNum: Int,
override val attributes: Attributes = Attributes.EMPTY, override val attributes: Attributes = Attributes.EMPTY,
@ -29,4 +31,4 @@ public class VirtualMatrix<out T : Any>(
public fun <T : Any> MatrixBuilder<T, *>.virtual( public fun <T : Any> MatrixBuilder<T, *>.virtual(
attributes: Attributes = Attributes.EMPTY, attributes: Attributes = Attributes.EMPTY,
generator: (i: Int, j: Int) -> T, generator: (i: Int, j: Int) -> T,
): VirtualMatrix<T> = VirtualMatrix(rows, columns, attributes, generator) ): VirtualMatrix<T> = VirtualMatrix(type, rows, columns, attributes, generator)

View File

@ -3,13 +3,11 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/ */
@file:OptIn(UnstableKMathAPI::class)
@file:Suppress("UnusedReceiverParameter") @file:Suppress("UnusedReceiverParameter")
package space.kscience.kmath.linear package space.kscience.kmath.linear
import space.kscience.attributes.* import space.kscience.attributes.*
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.StructureAttribute import space.kscience.kmath.nd.StructureAttribute
/** /**
@ -51,9 +49,11 @@ public val <T> MatrixOperations<T>.Inverted: Inverted<T> get() = Inverted(safeTy
* *
* @param T the type of matrices' items. * @param T the type of matrices' items.
*/ */
public class Determinant<T> : MatrixAttribute<T> public class Determinant<T>(type: SafeType<T>) :
PolymorphicAttribute<T>(type),
MatrixAttribute<T>
public val <T> MatrixOperations<T>.Determinant: Determinant<T> get() = Determinant() public val <T> MatrixOperations<T>.Determinant: Determinant<T> get() = Determinant(type)
/** /**
* Matrices with this feature are lower triangular ones. * Matrices with this feature are lower triangular ones.

View File

@ -4,6 +4,8 @@
*/ */
@file:OptIn(UnstableKMathAPI::class)
package space.kscience.kmath.misc package space.kscience.kmath.misc
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
@ -22,10 +24,9 @@ public fun <V : Comparable<V>> Buffer<V>.indicesSorted(): IntArray = permSortInd
/** /**
* Create a zero-copy virtual buffer that contains the same elements but in ascending order * Create a zero-copy virtual buffer that contains the same elements but in ascending order
*/ */
@OptIn(UnstableKMathAPI::class)
public fun <V : Comparable<V>> Buffer<V>.sorted(): Buffer<V> { public fun <V : Comparable<V>> Buffer<V>.sorted(): Buffer<V> {
val permutations = indicesSorted() val permutations = indicesSorted()
return VirtualBuffer(size) { this[permutations[it]] } return VirtualBuffer(type, size) { this[permutations[it]] }
} }
@UnstableKMathAPI @UnstableKMathAPI
@ -35,30 +36,27 @@ public fun <V : Comparable<V>> Buffer<V>.indicesSortedDescending(): IntArray =
/** /**
* Create a zero-copy virtual buffer that contains the same elements but in descending order * Create a zero-copy virtual buffer that contains the same elements but in descending order
*/ */
@OptIn(UnstableKMathAPI::class)
public fun <V : Comparable<V>> Buffer<V>.sortedDescending(): Buffer<V> { public fun <V : Comparable<V>> Buffer<V>.sortedDescending(): Buffer<V> {
val permutations = indicesSortedDescending() val permutations = indicesSortedDescending()
return VirtualBuffer(size) { this[permutations[it]] } return VirtualBuffer(type, size) { this[permutations[it]] }
} }
@UnstableKMathAPI @UnstableKMathAPI
public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedBy(selector: (V) -> C): IntArray = public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedBy(selector: (V) -> C): IntArray =
permSortIndicesWith(compareBy { selector(get(it)) }) permSortIndicesWith(compareBy { selector(get(it)) })
@OptIn(UnstableKMathAPI::class)
public fun <V, C : Comparable<C>> Buffer<V>.sortedBy(selector: (V) -> C): Buffer<V> { public fun <V, C : Comparable<C>> Buffer<V>.sortedBy(selector: (V) -> C): Buffer<V> {
val permutations = indicesSortedBy(selector) val permutations = indicesSortedBy(selector)
return VirtualBuffer(size) { this[permutations[it]] } return VirtualBuffer(type, size) { this[permutations[it]] }
} }
@UnstableKMathAPI @UnstableKMathAPI
public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedByDescending(selector: (V) -> C): IntArray = public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedByDescending(selector: (V) -> C): IntArray =
permSortIndicesWith(compareByDescending { selector(get(it)) }) permSortIndicesWith(compareByDescending { selector(get(it)) })
@OptIn(UnstableKMathAPI::class)
public fun <V, C : Comparable<C>> Buffer<V>.sortedByDescending(selector: (V) -> C): Buffer<V> { public fun <V, C : Comparable<C>> Buffer<V>.sortedByDescending(selector: (V) -> C): Buffer<V> {
val permutations = indicesSortedByDescending(selector) val permutations = indicesSortedByDescending(selector)
return VirtualBuffer(size) { this[permutations[it]] } return VirtualBuffer(type, size) { this[permutations[it]] }
} }
@UnstableKMathAPI @UnstableKMathAPI
@ -68,10 +66,9 @@ public fun <V> Buffer<V>.indicesSortedWith(comparator: Comparator<V>): IntArray
/** /**
* Create virtual zero-copy buffer with elements sorted by [comparator] * Create virtual zero-copy buffer with elements sorted by [comparator]
*/ */
@OptIn(UnstableKMathAPI::class)
public fun <V> Buffer<V>.sortedWith(comparator: Comparator<V>): Buffer<V> { public fun <V> Buffer<V>.sortedWith(comparator: Comparator<V>): Buffer<V> {
val permutations = indicesSortedWith(comparator) val permutations = indicesSortedWith(comparator)
return VirtualBuffer(size) { this[permutations[it]] } return VirtualBuffer(type,size) { this[permutations[it]] }
} }
private fun <V> Buffer<V>.permSortIndicesWith(comparator: Comparator<Int>): IntArray { private fun <V> Buffer<V>.permSortIndicesWith(comparator: Comparator<Int>): IntArray {

View File

@ -8,7 +8,7 @@ package space.kscience.kmath.nd
import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import kotlin.reflect.KClass import space.kscience.kmath.structures.MutableBufferFactory
/** /**
* The base interface for all ND-algebra implementations. * The base interface for all ND-algebra implementations.
@ -91,6 +91,8 @@ public interface AlgebraND<T, out C : Algebra<T>> : Algebra<StructureND<T>> {
* @param A the type of group over structure elements. * @param A the type of group over structure elements.
*/ */
public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>, AlgebraND<T, A> { public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>, AlgebraND<T, A> {
override val bufferFactory: MutableBufferFactory<StructureND<T>> get() = MutableBufferFactory<StructureND<T>>()
/** /**
* Element-wise addition. * Element-wise addition.
* *

View File

@ -7,6 +7,8 @@
package space.kscience.kmath.nd package space.kscience.kmath.nd
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
@ -105,6 +107,9 @@ public open class BufferedGroupNDOps<T, out A : Group<T>>(
override val bufferAlgebra: BufferAlgebra<T, A>, override val bufferAlgebra: BufferAlgebra<T, A>,
override val indexerBuilder: (ShapeND) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder, override val indexerBuilder: (ShapeND) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder,
) : GroupOpsND<T, A>, BufferAlgebraND<T, A> { ) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
override val type: SafeType<StructureND<T>> get() = safeTypeOf<StructureND<T>>()
override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it } override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
} }

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.nd package space.kscience.kmath.nd
import space.kscience.attributes.SafeType
import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.BufferFactory
@ -23,6 +24,8 @@ public open class BufferND<out T>(
public open val buffer: Buffer<T>, public open val buffer: Buffer<T>,
) : StructureND<T> { ) : StructureND<T> {
override val type: SafeType<T> get() = buffer.type
@PerformancePitfall @PerformancePitfall
override operator fun get(index: IntArray): T = buffer[indices.offset(index)] override operator fun get(index: IntArray): T = buffer[indices.offset(index)]
@ -36,7 +39,7 @@ public open class BufferND<out T>(
*/ */
public inline fun <reified T> BufferND( public inline fun <reified T> BufferND(
shape: ShapeND, shape: ShapeND,
bufferFactory: BufferFactory<T> = BufferFactory.auto<T>(), bufferFactory: BufferFactory<T> = BufferFactory<T>(),
crossinline initializer: (IntArray) -> T, crossinline initializer: (IntArray) -> T,
): BufferND<T> { ): BufferND<T> {
val strides = Strides(shape) val strides = Strides(shape)
@ -84,10 +87,10 @@ public open class MutableBufferND<T>(
/** /**
* Create a generic [BufferND] using provided [initializer] * Create a generic [BufferND] using provided [initializer]
*/ */
public fun <T> MutableBufferND( public inline fun <reified T> MutableBufferND(
shape: ShapeND, shape: ShapeND,
bufferFactory: MutableBufferFactory<T> = MutableBufferFactory.boxing(), bufferFactory: MutableBufferFactory<T> = MutableBufferFactory(),
initializer: (IntArray) -> T, crossinline initializer: (IntArray) -> T,
): MutableBufferND<T> { ): MutableBufferND<T> {
val strides = Strides(shape) val strides = Strides(shape)
return MutableBufferND(strides, bufferFactory(strides.linearSize) { initializer(strides.index(it)) }) return MutableBufferND(strides, bufferFactory(strides.linearSize) { initializer(strides.index(it)) })

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.nd package space.kscience.kmath.nd
import space.kscience.attributes.SafeType
import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.PerformancePitfall
@ -13,6 +14,8 @@ public class PermutedStructureND<T>(
public val permutation: (IntArray) -> IntArray, public val permutation: (IntArray) -> IntArray,
) : StructureND<T> { ) : StructureND<T> {
override val type: SafeType<T> get() = origin.type
override val shape: ShapeND override val shape: ShapeND
get() = origin.shape get() = origin.shape
@ -32,6 +35,7 @@ public class PermutedMutableStructureND<T>(
public val permutation: (IntArray) -> IntArray, public val permutation: (IntArray) -> IntArray,
) : MutableStructureND<T> { ) : MutableStructureND<T> {
override val type: SafeType<T> get() = origin.type
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override fun set(index: IntArray, value: T) { override fun set(index: IntArray, value: T) {

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.nd package space.kscience.kmath.nd
import space.kscience.attributes.SafeType
import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.operations.asSequence import space.kscience.kmath.operations.asSequence
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
@ -46,6 +47,9 @@ public interface MutableStructure1D<T> : Structure1D<T>, MutableStructureND<T>,
*/ */
@JvmInline @JvmInline
private value class Structure1DWrapper<out T>(val structure: StructureND<T>) : Structure1D<T> { private value class Structure1DWrapper<out T>(val structure: StructureND<T>) : Structure1D<T> {
override val type: SafeType<T> get() = structure.type
override val shape: ShapeND get() = structure.shape override val shape: ShapeND get() = structure.shape
override val size: Int get() = structure.shape[0] override val size: Int get() = structure.shape[0]
@ -60,6 +64,9 @@ private value class Structure1DWrapper<out T>(val structure: StructureND<T>) : S
* A 1D wrapper for a mutable nd-structure * A 1D wrapper for a mutable nd-structure
*/ */
private class MutableStructure1DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure1D<T> { private class MutableStructure1DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure1D<T> {
override val type: SafeType<T> get() = structure.type
override val shape: ShapeND get() = structure.shape override val shape: ShapeND get() = structure.shape
override val size: Int get() = structure.shape[0] override val size: Int get() = structure.shape[0]
@ -79,7 +86,7 @@ private class MutableStructure1DWrapper<T>(val structure: MutableStructureND<T>)
.elements() .elements()
.map(Pair<IntArray, T>::second) .map(Pair<IntArray, T>::second)
.toMutableList() .toMutableList()
.asMutableBuffer() .asMutableBuffer(type)
override fun toString(): String = Buffer.toString(this) override fun toString(): String = Buffer.toString(this)
} }
@ -90,6 +97,9 @@ private class MutableStructure1DWrapper<T>(val structure: MutableStructureND<T>)
*/ */
@JvmInline @JvmInline
private value class Buffer1DWrapper<out T>(val buffer: Buffer<T>) : Structure1D<T> { private value class Buffer1DWrapper<out T>(val buffer: Buffer<T>) : Structure1D<T> {
override val type: SafeType<T> get() = buffer.type
override val shape: ShapeND get() = ShapeND(buffer.size) override val shape: ShapeND get() = ShapeND(buffer.size)
override val size: Int get() = buffer.size override val size: Int get() = buffer.size
@ -102,6 +112,9 @@ private value class Buffer1DWrapper<out T>(val buffer: Buffer<T>) : Structure1D<
} }
internal class MutableBuffer1DWrapper<T>(val buffer: MutableBuffer<T>) : MutableStructure1D<T> { internal class MutableBuffer1DWrapper<T>(val buffer: MutableBuffer<T>) : MutableStructure1D<T> {
override val type: SafeType<T> get() = buffer.type
override val shape: ShapeND get() = ShapeND(buffer.size) override val shape: ShapeND get() = ShapeND(buffer.size)
override val size: Int get() = buffer.size override val size: Int get() = buffer.size

View File

@ -6,13 +6,12 @@
package space.kscience.kmath.nd package space.kscience.kmath.nd
import space.kscience.attributes.Attributes import space.kscience.attributes.Attributes
import space.kscience.attributes.SafeType
import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.MutableBuffer import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.MutableListBuffer
import space.kscience.kmath.structures.VirtualBuffer import space.kscience.kmath.structures.VirtualBuffer
import kotlin.jvm.JvmInline import kotlin.jvm.JvmInline
import kotlin.reflect.KClass
/** /**
* A structure that is guaranteed to be two-dimensional. * A structure that is guaranteed to be two-dimensional.
@ -33,18 +32,18 @@ public interface Structure2D<out T> : StructureND<T> {
override val shape: ShapeND get() = ShapeND(rowNum, colNum) override val shape: ShapeND get() = ShapeND(rowNum, colNum)
/** /**
* The buffer of rows of this structure. It gets elements from the structure dynamically. * The buffer of rows for this structure. It gets elements from the structure dynamically.
*/ */
@PerformancePitfall @PerformancePitfall
public val rows: List<Buffer<T>> public val rows: List<Buffer<T>>
get() = List(rowNum) { i -> VirtualBuffer(colNum) { j -> get(i, j) } } get() = List(rowNum) { i -> VirtualBuffer(type, colNum) { j -> get(i, j) } }
/** /**
* The buffer of columns of this structure. It gets elements from the structure dynamically. * The buffer of columns for this structure. It gets elements from the structure dynamically.
*/ */
@PerformancePitfall @PerformancePitfall
public val columns: List<Buffer<T>> public val columns: List<Buffer<T>>
get() = List(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } } get() = List(colNum) { j -> VirtualBuffer(type, rowNum) { i -> get(i, j) } }
/** /**
* Retrieves an element from the structure by two indices. * Retrieves an element from the structure by two indices.
@ -88,14 +87,14 @@ public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
*/ */
@PerformancePitfall @PerformancePitfall
override val rows: List<MutableBuffer<T>> override val rows: List<MutableBuffer<T>>
get() = List(rowNum) { i -> MutableListBuffer(colNum) { j -> get(i, j) } } get() = List(rowNum) { i -> MutableBuffer(type, colNum) { j -> get(i, j) } }
/** /**
* The buffer of columns of this structure. It gets elements from the structure dynamically. * The buffer of columns for this structure. It gets elements from the structure dynamically.
*/ */
@PerformancePitfall @PerformancePitfall
override val columns: List<MutableBuffer<T>> override val columns: List<MutableBuffer<T>>
get() = List(colNum) { j -> MutableListBuffer(rowNum) { i -> get(i, j) } } get() = List(colNum) { j -> MutableBuffer(type, rowNum) { i -> get(i, j) } }
} }
/** /**
@ -103,6 +102,9 @@ public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
*/ */
@JvmInline @JvmInline
private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : Structure2D<T> { private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : Structure2D<T> {
override val type: SafeType<T> get() = structure.type
override val shape: ShapeND get() = structure.shape override val shape: ShapeND get() = structure.shape
override val rowNum: Int get() = shape[0] override val rowNum: Int get() = shape[0]
@ -122,6 +124,9 @@ private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : S
* A 2D wrapper for a mutable nd-structure * A 2D wrapper for a mutable nd-structure
*/ */
private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure2D<T> { private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure2D<T> {
override val type: SafeType<T> get() = structure.type
override val shape: ShapeND get() = structure.shape override val shape: ShapeND get() = structure.shape
override val rowNum: Int get() = shape[0] override val rowNum: Int get() = shape[0]

View File

@ -12,6 +12,7 @@ import space.kscience.attributes.SafeType
import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.linear.LinearSpace import space.kscience.kmath.linear.LinearSpace
import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.WithType
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
@ -28,9 +29,9 @@ public interface StructureAttribute<T> : Attribute<T>
* *
* @param T the type of items. * @param T the type of items.
*/ */
public interface StructureND<out T> : AttributeContainer, WithShape { public interface StructureND<out T> : AttributeContainer, WithShape, WithType<T> {
/** /**
* The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions of * The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions for
* this structure. * this structure.
*/ */
override val shape: ShapeND override val shape: ShapeND
@ -117,56 +118,60 @@ public interface StructureND<out T> : AttributeContainer, WithShape {
return "$className(shape=${structure.shape}, buffer=$bufferRepr)" return "$className(shape=${structure.shape}, buffer=$bufferRepr)"
} }
/** }
}
/**
* Creates a NDStructure with explicit buffer factory. * Creates a NDStructure with explicit buffer factory.
* *
* Strides should be reused if possible. * Strides should be reused if possible.
*/ */
public fun <T> buffered( public fun <T> BufferND(
type: SafeType<T>,
strides: Strides, strides: Strides,
initializer: (IntArray) -> T, initializer: (IntArray) -> T,
): BufferND<T> = BufferND(strides, Buffer.boxing(strides.linearSize) { i -> initializer(strides.index(i)) }) ): BufferND<T> = BufferND(strides, Buffer(type, strides.linearSize) { i -> initializer(strides.index(i)) })
public fun <T> buffered( public fun <T> BufferND(
type: SafeType<T>,
shape: ShapeND, shape: ShapeND,
initializer: (IntArray) -> T, initializer: (IntArray) -> T,
): BufferND<T> = buffered(ColumnStrides(shape), initializer) ): BufferND<T> = BufferND(type, ColumnStrides(shape), initializer)
/** /**
* Inline create NDStructure with non-boxing buffer implementation if it is possible * Inline create NDStructure with non-boxing buffer implementation if it is possible
*/ */
public inline fun <reified T : Any> auto( public inline fun <reified T : Any> BufferND(
strides: Strides, strides: Strides,
crossinline initializer: (IntArray) -> T, crossinline initializer: (IntArray) -> T,
): BufferND<T> = BufferND(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) }) ): BufferND<T> = BufferND(strides, Buffer(strides.linearSize) { i -> initializer(strides.index(i)) })
public inline fun <T : Any> auto( public inline fun <T : Any> BufferND(
type: SafeType<T>, type: SafeType<T>,
strides: Strides, strides: Strides,
crossinline initializer: (IntArray) -> T, crossinline initializer: (IntArray) -> T,
): BufferND<T> = BufferND(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) }) ): BufferND<T> = BufferND(strides, Buffer(type, strides.linearSize) { i -> initializer(strides.index(i)) })
public inline fun <reified T : Any> auto( public inline fun <reified T : Any> BufferND(
shape: ShapeND, shape: ShapeND,
crossinline initializer: (IntArray) -> T, crossinline initializer: (IntArray) -> T,
): BufferND<T> = auto(ColumnStrides(shape), initializer) ): BufferND<T> = BufferND(ColumnStrides(shape), initializer)
@JvmName("autoVarArg") @JvmName("autoVarArg")
public inline fun <reified T : Any> auto( public inline fun <reified T : Any> BufferND(
vararg shape: Int, vararg shape: Int,
crossinline initializer: (IntArray) -> T, crossinline initializer: (IntArray) -> T,
): BufferND<T> = ): BufferND<T> = BufferND(ColumnStrides(ShapeND(shape)), initializer)
auto(ColumnStrides(ShapeND(shape)), initializer)
public inline fun <T : Any> auto( public inline fun <T : Any> BufferND(
type: SafeType<T>, type: SafeType<T>,
vararg shape: Int, vararg shape: Int,
crossinline initializer: (IntArray) -> T, crossinline initializer: (IntArray) -> T,
): BufferND<T> = auto(type, ColumnStrides(ShapeND(shape)), initializer) ): BufferND<T> = BufferND(type, ColumnStrides(ShapeND(shape)), initializer)
}
}
/** /**
* Indicates whether some [StructureND] is equal to another one. * Indicates whether some [StructureND] is equal to another one.

View File

@ -5,10 +5,13 @@
package space.kscience.kmath.nd package space.kscience.kmath.nd
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
public open class VirtualStructureND<T>( public open class VirtualStructureND<T>(
override val type: SafeType<T>,
override val shape: ShapeND, override val shape: ShapeND,
public val producer: (IntArray) -> T, public val producer: (IntArray) -> T,
) : StructureND<T> { ) : StructureND<T> {
@ -24,10 +27,10 @@ public open class VirtualStructureND<T>(
public class VirtualDoubleStructureND( public class VirtualDoubleStructureND(
shape: ShapeND, shape: ShapeND,
producer: (IntArray) -> Double, producer: (IntArray) -> Double,
) : VirtualStructureND<Double>(shape, producer) ) : VirtualStructureND<Double>(safeTypeOf(), shape, producer)
@UnstableKMathAPI @UnstableKMathAPI
public class VirtualIntStructureND( public class VirtualIntStructureND(
shape: ShapeND, shape: ShapeND,
producer: (IntArray) -> Int, producer: (IntArray) -> Int,
) : VirtualStructureND<Int>(shape, producer) ) : VirtualStructureND<Int>(safeTypeOf(), shape, producer)

View File

@ -10,7 +10,7 @@ import space.kscience.kmath.PerformancePitfall
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
public fun <T> StructureND<T>.roll(axis: Int, step: Int = 1): StructureND<T> { public fun <T> StructureND<T>.roll(axis: Int, step: Int = 1): StructureND<T> {
require(axis in shape.indices) { "Axis $axis is outside of shape dimensions: [0, ${shape.size})" } require(axis in shape.indices) { "Axis $axis is outside of shape dimensions: [0, ${shape.size})" }
return VirtualStructureND(shape) { index -> return VirtualStructureND(type, shape) { index ->
val newIndex: IntArray = IntArray(index.size) { indexAxis -> val newIndex: IntArray = IntArray(index.size) { indexAxis ->
if (indexAxis == axis) { if (indexAxis == axis) {
(index[indexAxis] + step).mod(shape[indexAxis]) (index[indexAxis] + step).mod(shape[indexAxis])
@ -26,7 +26,7 @@ public fun <T> StructureND<T>.roll(axis: Int, step: Int = 1): StructureND<T> {
public fun <T> StructureND<T>.roll(pair: Pair<Int, Int>, vararg others: Pair<Int, Int>): StructureND<T> { public fun <T> StructureND<T>.roll(pair: Pair<Int, Int>, vararg others: Pair<Int, Int>): StructureND<T> {
val axisMap: Map<Int, Int> = mapOf(pair, *others) val axisMap: Map<Int, Int> = mapOf(pair, *others)
require(axisMap.keys.all { it in shape.indices }) { "Some of axes ${axisMap.keys} is outside of shape dimensions: [0, ${shape.size})" } require(axisMap.keys.all { it in shape.indices }) { "Some of axes ${axisMap.keys} is outside of shape dimensions: [0, ${shape.size})" }
return VirtualStructureND(shape) { index -> return VirtualStructureND(type, shape) { index ->
val newIndex: IntArray = IntArray(index.size) { indexAxis -> val newIndex: IntArray = IntArray(index.size) { indexAxis ->
val offset = axisMap[indexAxis] ?: 0 val offset = axisMap[indexAxis] ?: 0
(index[indexAxis] + offset).mod(shape[indexAxis]) (index[indexAxis] + offset).mod(shape[indexAxis])

View File

@ -5,22 +5,32 @@
package space.kscience.kmath.operations package space.kscience.kmath.operations
import space.kscience.attributes.SafeType
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.operations.Ring.Companion.optimizedPower import space.kscience.kmath.operations.Ring.Companion.optimizedPower
import space.kscience.kmath.structures.MutableBufferFactory import space.kscience.kmath.structures.MutableBufferFactory
import kotlin.reflect.KType
/**
* An interface containing [type] for dynamic type checking.
*/
public interface WithType<out T> {
public val type: SafeType<T>
}
/** /**
* Represents an algebraic structure. * Represents an algebraic structure.
* *
* @param T the type of element which Algebra operates on. * @param T the type of element which Algebra operates on.
*/ */
public interface Algebra<T> { public interface Algebra<T> : WithType<T> {
/** /**
* Provide a factory for buffers, associated with this [Algebra] * Provide a factory for buffers, associated with this [Algebra]
*/ */
public val bufferFactory: MutableBufferFactory<T> get() = MutableBufferFactory.boxing() public val bufferFactory: MutableBufferFactory<T>
override val type: SafeType<T> get() = bufferFactory.type
/** /**
* Wraps a raw string to [T] object. This method is designed for three purposes: * Wraps a raw string to [T] object. This method is designed for three purposes:
@ -265,7 +275,7 @@ public interface Ring<T> : Group<T>, RingOps<T> {
*/ */
public fun power(arg: T, pow: UInt): T = optimizedPower(arg, pow) public fun power(arg: T, pow: UInt): T = optimizedPower(arg, pow)
public companion object{ public companion object {
/** /**
* Raises [arg] to the non-negative integer power [exponent]. * Raises [arg] to the non-negative integer power [exponent].
* *
@ -348,7 +358,7 @@ public interface Field<T> : Ring<T>, FieldOps<T>, ScaleOperations<T>, NumericAlg
public fun power(arg: T, pow: Int): T = optimizedPower(arg, pow) public fun power(arg: T, pow: Int): T = optimizedPower(arg, pow)
public companion object{ public companion object {
/** /**
* Raises [arg] to the integer power [exponent]. * Raises [arg] to the integer power [exponent].
* *
@ -361,7 +371,11 @@ public interface Field<T> : Ring<T>, FieldOps<T>, ScaleOperations<T>, NumericAlg
* @author Iaroslav Postovalov, Evgeniy Zhelenskiy * @author Iaroslav Postovalov, Evgeniy Zhelenskiy
*/ */
private fun <T> Field<T>.optimizedPower(arg: T, exponent: Int): T = when { private fun <T> Field<T>.optimizedPower(arg: T, exponent: Int): T = when {
exponent < 0 -> one / (this as Ring<T>).optimizedPower(arg, if (exponent == Int.MIN_VALUE) Int.MAX_VALUE.toUInt().inc() else (-exponent).toUInt()) exponent < 0 -> one / (this as Ring<T>).optimizedPower(
arg,
if (exponent == Int.MIN_VALUE) Int.MAX_VALUE.toUInt().inc() else (-exponent).toUInt()
)
else -> (this as Ring<T>).optimizedPower(arg, exponent.toUInt()) else -> (this as Ring<T>).optimizedPower(arg, exponent.toUInt())
} }
} }

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.operations package space.kscience.kmath.operations
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.BufferedRingOpsND import space.kscience.kmath.nd.BufferedRingOpsND
import space.kscience.kmath.operations.BigInt.Companion.BASE import space.kscience.kmath.operations.BigInt.Companion.BASE
@ -26,6 +28,9 @@ private typealias TBase = ULong
*/ */
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public object BigIntField : Field<BigInt>, NumbersAddOps<BigInt>, ScaleOperations<BigInt> { public object BigIntField : Field<BigInt>, NumbersAddOps<BigInt>, ScaleOperations<BigInt> {
override val type: SafeType<BigInt> = safeTypeOf()
override val zero: BigInt = BigInt.ZERO override val zero: BigInt = BigInt.ZERO
override val one: BigInt = BigInt.ONE override val one: BigInt = BigInt.ONE
@ -528,10 +533,10 @@ public fun String.parseBigInteger(): BigInt? {
public val BigInt.algebra: BigIntField get() = BigIntField public val BigInt.algebra: BigIntField get() = BigIntField
public inline fun BigInt.Companion.buffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> = public inline fun BigInt.Companion.buffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
Buffer.boxing(size, initializer) Buffer(size, initializer)
public inline fun BigInt.Companion.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> = public inline fun BigInt.Companion.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
Buffer.boxing(size, initializer) Buffer(size, initializer)
public val BigIntField.nd: BufferedRingOpsND<BigInt, BigIntField> public val BigIntField.nd: BufferedRingOpsND<BigInt, BigIntField>
get() = BufferedRingOpsND(BufferRingOps(BigIntField)) get() = BufferedRingOpsND(BufferRingOps(BigIntField))

View File

@ -5,10 +5,11 @@
package space.kscience.kmath.operations package space.kscience.kmath.operations
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.MutableBuffer import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.MutableBufferFactory import space.kscience.kmath.structures.MutableBufferFactory
import kotlin.reflect.KType
public interface WithSize { public interface WithSize {
public val size: Int public val size: Int
@ -20,6 +21,8 @@ public interface WithSize {
public interface BufferAlgebra<T, out A : Algebra<T>> : Algebra<Buffer<T>> { public interface BufferAlgebra<T, out A : Algebra<T>> : Algebra<Buffer<T>> {
public val elementAlgebra: A public val elementAlgebra: A
override val type: SafeType<Buffer<T>> get() = safeTypeOf<Buffer<T>>()
public val elementBufferFactory: MutableBufferFactory<T> get() = elementAlgebra.bufferFactory public val elementBufferFactory: MutableBufferFactory<T> get() = elementAlgebra.bufferFactory
public fun buffer(size: Int, vararg elements: T): Buffer<T> { public fun buffer(size: Int, vararg elements: T): Buffer<T> {

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.operations package space.kscience.kmath.operations
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
@ -72,6 +74,8 @@ public interface LogicAlgebra<T : Any> : Algebra<T> {
@Suppress("EXTENSION_SHADOWED_BY_MEMBER") @Suppress("EXTENSION_SHADOWED_BY_MEMBER")
public object BooleanAlgebra : LogicAlgebra<Boolean> { public object BooleanAlgebra : LogicAlgebra<Boolean> {
override val type: SafeType<Boolean> get() = safeTypeOf()
override fun const(boolean: Boolean): Boolean = boolean override fun const(boolean: Boolean): Boolean = boolean
override fun Boolean.not(): Boolean = !this override fun Boolean.not(): Boolean = !this

View File

@ -8,7 +8,10 @@ package space.kscience.kmath.operations
import space.kscience.kmath.operations.Int16Field.div import space.kscience.kmath.operations.Int16Field.div
import space.kscience.kmath.operations.Int32Field.div import space.kscience.kmath.operations.Int32Field.div
import space.kscience.kmath.operations.Int64Field.div import space.kscience.kmath.operations.Int64Field.div
import space.kscience.kmath.structures.* import space.kscience.kmath.structures.Int16
import space.kscience.kmath.structures.Int32
import space.kscience.kmath.structures.Int64
import space.kscience.kmath.structures.MutableBufferFactory
import kotlin.math.roundToInt import kotlin.math.roundToInt
import kotlin.math.roundToLong import kotlin.math.roundToLong
@ -22,7 +25,7 @@ import kotlin.math.roundToLong
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER") @Suppress("EXTENSION_SHADOWED_BY_MEMBER")
public object Int16Field : Field<Int16>, Norm<Int16, Int16>, NumericAlgebra<Int16> { public object Int16Field : Field<Int16>, Norm<Int16, Int16>, NumericAlgebra<Int16> {
override val bufferFactory: MutableBufferFactory<Int16> = MutableBufferFactory(::Int16Buffer) override val bufferFactory: MutableBufferFactory<Int16> = MutableBufferFactory<Int16>()
override val zero: Int16 get() = 0 override val zero: Int16 get() = 0
override val one: Int16 get() = 1 override val one: Int16 get() = 1
@ -45,7 +48,8 @@ public object Int16Field : Field<Int16>, Norm<Int16, Int16>, NumericAlgebra<Int1
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER") @Suppress("EXTENSION_SHADOWED_BY_MEMBER")
public object Int32Field : Field<Int32>, Norm<Int32, Int32>, NumericAlgebra<Int32> { public object Int32Field : Field<Int32>, Norm<Int32, Int32>, NumericAlgebra<Int32> {
override val bufferFactory: MutableBufferFactory<Int> = MutableBufferFactory(::Int32Buffer) override val bufferFactory: MutableBufferFactory<Int> = MutableBufferFactory()
override val zero: Int get() = 0 override val zero: Int get() = 0
override val one: Int get() = 1 override val one: Int get() = 1
@ -68,7 +72,7 @@ public object Int32Field : Field<Int32>, Norm<Int32, Int32>, NumericAlgebra<Int3
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER") @Suppress("EXTENSION_SHADOWED_BY_MEMBER")
public object Int64Field : Field<Int64>, Norm<Int64, Int64>, NumericAlgebra<Int64> { public object Int64Field : Field<Int64>, Norm<Int64, Int64>, NumericAlgebra<Int64> {
override val bufferFactory: MutableBufferFactory<Int64> = MutableBufferFactory(::Int64Buffer) override val bufferFactory: MutableBufferFactory<Int64> = MutableBufferFactory()
override val zero: Int64 get() = 0L override val zero: Int64 get() = 0L
override val one: Int64 get() = 1L override val one: Int64 get() = 1L

View File

@ -67,7 +67,7 @@ public interface ExtendedField<T> : ExtendedFieldOps<T>, Field<T>, NumericAlgebr
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER") @Suppress("EXTENSION_SHADOWED_BY_MEMBER")
public object Float64Field : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> { public object Float64Field : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> {
override val bufferFactory: MutableBufferFactory<Double> = MutableBufferFactory(::Float64Buffer) override val bufferFactory: MutableBufferFactory<Double> = MutableBufferFactory()
override val zero: Double get() = 0.0 override val zero: Double get() = 0.0
override val one: Double get() = 1.0 override val one: Double get() = 1.0

View File

@ -8,8 +8,8 @@ package space.kscience.kmath.structures
import space.kscience.attributes.SafeType import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.operations.WithSize import space.kscience.kmath.operations.WithSize
import space.kscience.kmath.operations.WithType
import space.kscience.kmath.operations.asSequence import space.kscience.kmath.operations.asSequence
import kotlin.jvm.JvmInline
import kotlin.reflect.typeOf import kotlin.reflect.typeOf
/** /**
@ -17,35 +17,46 @@ import kotlin.reflect.typeOf
* *
* @param T the type of buffer. * @param T the type of buffer.
*/ */
public fun interface BufferFactory<T> { public interface BufferFactory<T> : WithType<T> {
public operator fun invoke(size: Int, builder: (Int) -> T): Buffer<T> public operator fun invoke(size: Int, builder: (Int) -> T): Buffer<T>
public companion object{
public inline fun <reified T> auto(): BufferFactory<T> =
BufferFactory(Buffer.Companion::auto)
public fun <T> boxing(): BufferFactory<T> =
BufferFactory(Buffer.Companion::boxing)
}
} }
/**
* Create a [BufferFactory] for given [type], using primitive storage if possible
*/
public fun <T> BufferFactory(type: SafeType<T>): BufferFactory<T> = object : BufferFactory<T> {
override val type: SafeType<T> = type
override fun invoke(size: Int, builder: (Int) -> T): Buffer<T> = Buffer(type, size, builder)
}
/**
* Create [BufferFactory] using the reified type
*/
public inline fun <reified T> BufferFactory(): BufferFactory<T> = BufferFactory(safeTypeOf())
/** /**
* Function that produces [MutableBuffer] from its size and function that supplies values. * Function that produces [MutableBuffer] from its size and function that supplies values.
* *
* @param T the type of buffer. * @param T the type of buffer.
*/ */
public fun interface MutableBufferFactory<T> : BufferFactory<T> { public interface MutableBufferFactory<T> : BufferFactory<T> {
override fun invoke(size: Int, builder: (Int) -> T): MutableBuffer<T> override fun invoke(size: Int, builder: (Int) -> T): MutableBuffer<T>
public companion object {
public inline fun <reified T : Any> auto(): MutableBufferFactory<T> =
MutableBufferFactory(MutableBuffer.Companion::auto)
public fun <T> boxing(): MutableBufferFactory<T> =
MutableBufferFactory(MutableBuffer.Companion::boxing)
}
} }
/**
* Create a [MutableBufferFactory] for given [type], using primitive storage if possible
*/
public fun <T> MutableBufferFactory(type: SafeType<T>): MutableBufferFactory<T> = object : MutableBufferFactory<T> {
override val type: SafeType<T> = type
override fun invoke(size: Int, builder: (Int) -> T): MutableBuffer<T> = MutableBuffer(type, size, builder)
}
/**
* Create [BufferFactory] using the reified type
*/
public inline fun <reified T> MutableBufferFactory(): MutableBufferFactory<T> = MutableBufferFactory(safeTypeOf())
/** /**
* A generic read-only random-access structure for both primitives and objects. * A generic read-only random-access structure for both primitives and objects.
* *
@ -53,14 +64,14 @@ public fun interface MutableBufferFactory<T> : BufferFactory<T> {
* *
* @param T the type of elements contained in the buffer. * @param T the type of elements contained in the buffer.
*/ */
public interface Buffer<out T> : WithSize { public interface Buffer<out T> : WithSize, WithType<T> {
/** /**
* The size of this buffer. * The size of this buffer.
*/ */
override val size: Int override val size: Int
/** /**
* Gets element at given index. * Gets an element at given index.
*/ */
public operator fun get(index: Int): T public operator fun get(index: Int): T
@ -87,41 +98,48 @@ public interface Buffer<out T> : WithSize {
return true return true
} }
/**
* Creates a [ListBuffer] of given type [T] with given [size]. Each element is calculated by calling the
* specified [initializer] function.
*/
public inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> =
List(size, initializer).asBuffer()
/**
* Creates a [Buffer] of given [type]. If the type is primitive, specialized buffers are used ([Int32Buffer],
* [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
*
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
*/
@Suppress("UNCHECKED_CAST")
public inline fun <T> auto(type: SafeType<T>, size: Int, initializer: (Int) -> T): Buffer<T> =
when (type.kType) {
typeOf<Double>() -> MutableBuffer.double(size) { initializer(it) as Double } as Buffer<T>
typeOf<Short>() -> MutableBuffer.short(size) { initializer(it) as Short } as Buffer<T>
typeOf<Int>() -> MutableBuffer.int(size) { initializer(it) as Int } as Buffer<T>
typeOf<Long>() -> MutableBuffer.long(size) { initializer(it) as Long } as Buffer<T>
typeOf<Float>() -> MutableBuffer.float(size) { initializer(it) as Float } as Buffer<T>
else -> boxing(size, initializer)
} }
}
/** /**
* Creates a [Buffer] of given type [T]. If the type is primitive, specialized buffers are used ([Int32Buffer], * Creates a [Buffer] of given type [T]. If the type is primitive, specialized buffers are used ([Int32Buffer],
* [Float64Buffer], etc.), [ListBuffer] is returned otherwise. * [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
* *
* The [size] is specified, and each element is calculated by calling the specified [initializer] function. * The [size] is specified, and each element is calculated by calling the specified [initializer] function.
*/ */
public inline fun <reified T> auto(size: Int, initializer: (Int) -> T): Buffer<T> = @Suppress("UNCHECKED_CAST")
auto(safeTypeOf<T>(), size, initializer) public inline fun <reified T> Buffer(size: Int, initializer: (Int) -> T): Buffer<T> {
val type = safeTypeOf<T>()
return when (type.kType) {
typeOf<Double>() -> MutableBuffer.double(size) { initializer(it) as Double } as Buffer<T>
typeOf<Short>() -> MutableBuffer.short(size) { initializer(it) as Short } as Buffer<T>
typeOf<Int>() -> MutableBuffer.int(size) { initializer(it) as Int } as Buffer<T>
typeOf<Long>() -> MutableBuffer.long(size) { initializer(it) as Long } as Buffer<T>
typeOf<Float>() -> MutableBuffer.float(size) { initializer(it) as Float } as Buffer<T>
else -> List(size, initializer).asBuffer(type)
} }
} }
/**
* Creates a [Buffer] of given [type]. If the type is primitive, specialized buffers are used ([Int32Buffer],
* [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
*
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
*/
@Suppress("UNCHECKED_CAST")
public fun <T> Buffer(
type: SafeType<T>,
size: Int,
initializer: (Int) -> T,
): Buffer<T> = when (type.kType) {
typeOf<Double>() -> MutableBuffer.double(size) { initializer(it) as Double } as Buffer<T>
typeOf<Short>() -> MutableBuffer.short(size) { initializer(it) as Short } as Buffer<T>
typeOf<Int>() -> MutableBuffer.int(size) { initializer(it) as Int } as Buffer<T>
typeOf<Long>() -> MutableBuffer.long(size) { initializer(it) as Long } as Buffer<T>
typeOf<Float>() -> MutableBuffer.float(size) { initializer(it) as Float } as Buffer<T>
else -> List(size, initializer).asBuffer(type)
}
/** /**
* Returns an [IntRange] of the valid indices for this [Buffer]. * Returns an [IntRange] of the valid indices for this [Buffer].
*/ */
@ -144,28 +162,17 @@ public fun <T> Buffer<T>.last(): T {
return get(size - 1) return get(size - 1)
} }
/**
* Immutable wrapper for [MutableBuffer].
*
* @param T the type of elements contained in the buffer.
* @property buffer The underlying buffer.
*/
@JvmInline
public value class ReadOnlyBuffer<T>(public val buffer: MutableBuffer<T>) : Buffer<T> {
override val size: Int get() = buffer.size
override operator fun get(index: Int): T = buffer[index]
override operator fun iterator(): Iterator<T> = buffer.iterator()
}
/** /**
* A buffer with content calculated on-demand. The calculated content is not stored, so it is recalculated on each call. * A buffer with content calculated on-demand. The calculated content is not stored, so it is recalculated on each call.
* Useful when one needs single element from the buffer. * Useful when one needs a single element from the buffer.
* *
* @param T the type of elements provided by the buffer. * @param T the type of elements provided by the buffer.
*/ */
public class VirtualBuffer<out T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> { public class VirtualBuffer<out T>(
override val type: SafeType<T>,
override val size: Int,
private val generator: (Int) -> T,
) : Buffer<T> {
override operator fun get(index: Int): T { override operator fun get(index: Int): T {
if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index") if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index")
return generator(index) return generator(index)
@ -177,6 +184,9 @@ public class VirtualBuffer<out T>(override val size: Int, private val generator:
} }
/** /**
* Convert this buffer to read-only buffer. * Inline builder for [VirtualBuffer]
*/ */
public fun <T> Buffer<T>.asReadOnly(): Buffer<T> = if (this is MutableBuffer) ReadOnlyBuffer(this) else this public inline fun <reified T> VirtualBuffer(
size: Int,
noinline generator: (Int) -> T,
): VirtualBuffer<T> = VirtualBuffer(safeTypeOf(), size, generator)

View File

@ -5,7 +5,13 @@
package space.kscience.kmath.structures package space.kscience.kmath.structures
import space.kscience.kmath.nd.* import space.kscience.attributes.SafeType
import space.kscience.kmath.nd.BufferND
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.nd.as2D
import kotlin.collections.component1
import kotlin.collections.component2
/** /**
* A context that allows to operate on a [MutableBuffer] as on 2d array * A context that allows to operate on a [MutableBuffer] as on 2d array
@ -27,11 +33,13 @@ internal class BufferAccessor2D<T>(
fun create(mat: Structure2D<T>): MutableBuffer<T> = create { i, j -> mat[i, j] } fun create(mat: Structure2D<T>): MutableBuffer<T> = create { i, j -> mat[i, j] }
//TODO optimize wrapper //TODO optimize wrapper
fun MutableBuffer<T>.toStructure2D(): Structure2D<T> = StructureND.buffered( fun MutableBuffer<T>.toStructure2D(): Structure2D<T> = BufferND(
ColumnStrides(ShapeND(rowNum, colNum)) type, ShapeND(rowNum, colNum)
) { (i, j) -> get(i, j) }.as2D() ) { (i, j) -> get(i, j) }.as2D()
inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> { inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> {
override val type: SafeType<T> get() = buffer.type
override val size: Int get() = colNum override val size: Int get() = colNum
override operator fun get(index: Int): T = buffer[rowIndex, index] override operator fun get(index: Int): T = buffer[rowIndex, index]

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.structures package space.kscience.kmath.structures
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import kotlin.experimental.and import kotlin.experimental.and
/** /**
@ -57,6 +59,9 @@ public class FlaggedDoubleBuffer(
public val values: DoubleArray, public val values: DoubleArray,
public val flags: ByteArray public val flags: ByteArray
) : FlaggedBuffer<Double?>, Buffer<Double?> { ) : FlaggedBuffer<Double?>, Buffer<Double?> {
override val type: SafeType<Double?> = safeTypeOf()
init { init {
require(values.size == flags.size) { "Values and flags must have the same dimensions" } require(values.size == flags.size) { "Values and flags must have the same dimensions" }
} }

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.structures package space.kscience.kmath.structures
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import kotlin.jvm.JvmInline import kotlin.jvm.JvmInline
/** /**
@ -15,6 +17,9 @@ import kotlin.jvm.JvmInline
*/ */
@JvmInline @JvmInline
public value class Float32Buffer(public val array: FloatArray) : PrimitiveBuffer<Float> { public value class Float32Buffer(public val array: FloatArray) : PrimitiveBuffer<Float> {
override val type: SafeType<Float32> get() = safeTypeOf<Float32>()
override val size: Int get() = array.size override val size: Int get() = array.size
override operator fun get(index: Int): Float = array[index] override operator fun get(index: Int): Float = array[index]

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.structures package space.kscience.kmath.structures
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.operations.BufferTransform import space.kscience.kmath.operations.BufferTransform
import kotlin.jvm.JvmInline import kotlin.jvm.JvmInline
@ -15,6 +17,9 @@ import kotlin.jvm.JvmInline
*/ */
@JvmInline @JvmInline
public value class Float64Buffer(public val array: DoubleArray) : PrimitiveBuffer<Double> { public value class Float64Buffer(public val array: DoubleArray) : PrimitiveBuffer<Double> {
override val type: SafeType<Double> get() = safeTypeOf()
override val size: Int get() = array.size override val size: Int get() = array.size
override operator fun get(index: Int): Double = array[index] override operator fun get(index: Int): Double = array[index]

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.structures package space.kscience.kmath.structures
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import kotlin.jvm.JvmInline import kotlin.jvm.JvmInline
/** /**
@ -14,6 +16,8 @@ import kotlin.jvm.JvmInline
*/ */
@JvmInline @JvmInline
public value class Int16Buffer(public val array: ShortArray) : MutableBuffer<Short> { public value class Int16Buffer(public val array: ShortArray) : MutableBuffer<Short> {
override val type: SafeType<Short> get() = safeTypeOf()
override val size: Int get() = array.size override val size: Int get() = array.size
override operator fun get(index: Int): Short = array[index] override operator fun get(index: Int): Short = array[index]

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.structures package space.kscience.kmath.structures
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import kotlin.jvm.JvmInline import kotlin.jvm.JvmInline
/** /**
@ -14,6 +16,9 @@ import kotlin.jvm.JvmInline
*/ */
@JvmInline @JvmInline
public value class Int32Buffer(public val array: IntArray) : PrimitiveBuffer<Int> { public value class Int32Buffer(public val array: IntArray) : PrimitiveBuffer<Int> {
override val type: SafeType<Int> get() = safeTypeOf()
override val size: Int get() = array.size override val size: Int get() = array.size
override operator fun get(index: Int): Int = array[index] override operator fun get(index: Int): Int = array[index]

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.structures package space.kscience.kmath.structures
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import kotlin.jvm.JvmInline import kotlin.jvm.JvmInline
/** /**
@ -14,6 +16,9 @@ import kotlin.jvm.JvmInline
*/ */
@JvmInline @JvmInline
public value class Int64Buffer(public val array: LongArray) : PrimitiveBuffer<Long> { public value class Int64Buffer(public val array: LongArray) : PrimitiveBuffer<Long> {
override val type: SafeType<Long> get() = safeTypeOf()
override val size: Int get() = array.size override val size: Int get() = array.size
override operator fun get(index: Int): Long = array[index] override operator fun get(index: Int): Long = array[index]

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.structures package space.kscience.kmath.structures
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import kotlin.jvm.JvmInline import kotlin.jvm.JvmInline
/** /**
@ -14,6 +16,9 @@ import kotlin.jvm.JvmInline
*/ */
@JvmInline @JvmInline
public value class Int8Buffer(public val array: ByteArray) : MutableBuffer<Byte> { public value class Int8Buffer(public val array: ByteArray) : MutableBuffer<Byte> {
override val type: SafeType<Byte> get() = safeTypeOf()
override val size: Int get() = array.size override val size: Int get() = array.size
override operator fun get(index: Int): Byte = array[index] override operator fun get(index: Int): Byte = array[index]

View File

@ -5,7 +5,8 @@
package space.kscience.kmath.structures package space.kscience.kmath.structures
import kotlin.jvm.JvmInline import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
/** /**
* [Buffer] implementation over [List]. * [Buffer] implementation over [List].
@ -13,9 +14,7 @@ import kotlin.jvm.JvmInline
* @param T the type of elements contained in the buffer. * @param T the type of elements contained in the buffer.
* @property list The underlying list. * @property list The underlying list.
*/ */
public class ListBuffer<T>(public val list: List<T>) : Buffer<T> { public class ListBuffer<T>(override val type: SafeType<T>, public val list: List<T>) : Buffer<T> {
public constructor(size: Int, initializer: (Int) -> T) : this(List(size, initializer))
override val size: Int get() = list.size override val size: Int get() = list.size
@ -29,7 +28,12 @@ public class ListBuffer<T>(public val list: List<T>) : Buffer<T> {
/** /**
* Returns an [ListBuffer] that wraps the original list. * Returns an [ListBuffer] that wraps the original list.
*/ */
public fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this) public fun <T> List<T>.asBuffer(type: SafeType<T>): ListBuffer<T> = ListBuffer(type, this)
/**
* Returns an [ListBuffer] that wraps the original list.
*/
public inline fun <reified T> List<T>.asBuffer(): ListBuffer<T> = asBuffer(safeTypeOf())
/** /**
* [MutableBuffer] implementation over [MutableList]. * [MutableBuffer] implementation over [MutableList].
@ -37,10 +41,7 @@ public fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
* @param T the type of elements contained in the buffer. * @param T the type of elements contained in the buffer.
* @property list The underlying list. * @property list The underlying list.
*/ */
@JvmInline public class MutableListBuffer<T>(override val type: SafeType<T>, public val list: MutableList<T>) : MutableBuffer<T> {
public value class MutableListBuffer<T>(public val list: MutableList<T>) : MutableBuffer<T> {
public constructor(size: Int, initializer: (Int) -> T) : this(MutableList(size, initializer))
override val size: Int get() = list.size override val size: Int get() = list.size
@ -51,10 +52,24 @@ public value class MutableListBuffer<T>(public val list: MutableList<T>) : Mutab
} }
override operator fun iterator(): Iterator<T> = list.iterator() override operator fun iterator(): Iterator<T> = list.iterator()
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list)) override fun copy(): MutableBuffer<T> = MutableListBuffer(type, ArrayList(list))
override fun toString(): String = Buffer.toString(this)
} }
/** /**
* Returns an [MutableListBuffer] that wraps the original list. * Returns an [MutableListBuffer] that wraps the original list.
*/ */
public fun <T> MutableList<T>.asMutableBuffer(): MutableListBuffer<T> = MutableListBuffer(this) public fun <T> MutableList<T>.asMutableBuffer(type: SafeType<T>): MutableListBuffer<T> = MutableListBuffer(
type,
this
)
/**
* Returns an [MutableListBuffer] that wraps the original list.
*/
public inline fun <reified T> MutableList<T>.asMutableBuffer(): MutableListBuffer<T> = MutableListBuffer(
safeTypeOf(),
this
)

View File

@ -61,41 +61,35 @@ public interface MutableBuffer<T> : Buffer<T> {
*/ */
public inline fun float(size: Int, initializer: (Int) -> Float): Float32Buffer = public inline fun float(size: Int, initializer: (Int) -> Float): Float32Buffer =
Float32Buffer(size, initializer) Float32Buffer(size, initializer)
}
}
/** /**
* Create a boxing mutable buffer of given type
*/
public inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
MutableListBuffer(MutableList(size, initializer))
/**
* Creates a [MutableBuffer] of given [type]. If the type is primitive, specialized buffers are used * Creates a [MutableBuffer] of given [type]. If the type is primitive, specialized buffers are used
* ([Int32Buffer], [Float64Buffer], etc.), [ListBuffer] is returned otherwise. * ([Int32Buffer], [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
* *
* The [size] is specified, and each element is calculated by calling the specified [initializer] function. * The [size] is specified, and each element is calculated by calling the specified [initializer] function.
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
public inline fun <T> auto(type: SafeType<T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> = public inline fun <T> MutableBuffer(type: SafeType<T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> =
when (type.kType) { when (type.kType) {
typeOf<Double>() -> double(size) { initializer(it) as Double } as MutableBuffer<T> typeOf<Double>() -> MutableBuffer.double(size) { initializer(it) as Double } as MutableBuffer<T>
typeOf<Short>() -> short(size) { initializer(it) as Short } as MutableBuffer<T> typeOf<Short>() -> MutableBuffer.short(size) { initializer(it) as Short } as MutableBuffer<T>
typeOf<Int>() -> int(size) { initializer(it) as Int } as MutableBuffer<T> typeOf<Int>() -> MutableBuffer.int(size) { initializer(it) as Int } as MutableBuffer<T>
typeOf<Float>() -> float(size) { initializer(it) as Float } as MutableBuffer<T> typeOf<Float>() -> MutableBuffer.float(size) { initializer(it) as Float } as MutableBuffer<T>
typeOf<Long>() -> long(size) { initializer(it) as Long } as MutableBuffer<T> typeOf<Long>() -> MutableBuffer.long(size) { initializer(it) as Long } as MutableBuffer<T>
else -> boxing(size, initializer) else -> MutableListBuffer(type, MutableList(size, initializer))
} }
/** /**
* Creates a [MutableBuffer] of given type [T]. If the type is primitive, specialized buffers are used * Creates a [MutableBuffer] of given type [T]. If the type is primitive, specialized buffers are used
* ([Int32Buffer], [Float64Buffer], etc.), [ListBuffer] is returned otherwise. * ([Int32Buffer], [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
* *
* The [size] is specified, and each element is calculated by calling the specified [initializer] function. * The [size] is specified, and each element is calculated by calling the specified [initializer] function.
*/ */
public inline fun <reified T> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> = public inline fun <reified T> MutableBuffer(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
auto(safeTypeOf<T>(), size, initializer) MutableBuffer(safeTypeOf<T>(), size, initializer)
}
}
public sealed interface PrimitiveBuffer<T> : MutableBuffer<T> public sealed interface PrimitiveBuffer<T> : MutableBuffer<T>

View File

@ -40,7 +40,7 @@ class DoubleLUSolverTest {
//Check determinant //Check determinant
assertEquals(7.0, lup.determinant) assertEquals(7.0, lup.determinant)
assertMatrixEquals(lup.p dot matrix, lup.l dot lup.u) assertMatrixEquals(lup.pivotMatrix dot matrix, lup.l dot lup.u)
} }
@Test @Test

View File

@ -19,19 +19,13 @@ import org.ejml.sparse.csc.factory.DecompositionFactory_DSCC
import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC
import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC
import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.linear.* import space.kscience.kmath.linear.*
import space.kscience.kmath.linear.Matrix import space.kscience.kmath.linear.Matrix
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.StructureFeature import space.kscience.kmath.nd.StructureFeature
import space.kscience.kmath.structures.Float64
import space.kscience.kmath.structures.Float32
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.operations.Float32Field import space.kscience.kmath.operations.Float32Field
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.operations.FloatField
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Float64Buffer
import space.kscience.kmath.structures.Float32Buffer
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.FloatBuffer import space.kscience.kmath.structures.FloatBuffer
import kotlin.reflect.KClass import kotlin.reflect.KClass

View File

@ -147,12 +147,12 @@ public fun <V : Any, A : Field<V>> Histogram.Companion.uniformNDFromRanges(
bufferFactory: BufferFactory<V> = valueAlgebraND.elementAlgebra.bufferFactory, bufferFactory: BufferFactory<V> = valueAlgebraND.elementAlgebra.bufferFactory,
): UniformHistogramGroupND<V, A> = UniformHistogramGroupND( ): UniformHistogramGroupND<V, A> = UniformHistogramGroupND(
valueAlgebraND, valueAlgebraND,
ListBuffer( DoubleBuffer(
ranges ranges
.map(Pair<ClosedFloatingPointRange<Double>, Int>::first) .map(Pair<ClosedFloatingPointRange<Double>, Int>::first)
.map(ClosedFloatingPointRange<Double>::start) .map(ClosedFloatingPointRange<Double>::start)
), ),
ListBuffer( DoubleBuffer(
ranges ranges
.map(Pair<ClosedFloatingPointRange<Double>, Int>::first) .map(Pair<ClosedFloatingPointRange<Double>, Int>::first)
.map(ClosedFloatingPointRange<Double>::endInclusive) .map(ClosedFloatingPointRange<Double>::endInclusive)