From 46eacbb7506d40962fa91f70cc0a047d63ed5a17 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 3 Nov 2023 09:56:19 +0300 Subject: [PATCH] 0.4 WIP --- CHANGELOG.md | 9 +- .../space/kscience/attributes/SafeType.kt | 2 +- .../kscience/kmath/data/XYColumnarData.kt | 6 +- .../kscience/kmath/expressions/DSAlgebra.kt | 4 +- .../FunctionalExpressionAlgebra.kt | 4 + .../kscience/kmath/expressions/MstAlgebra.kt | 25 ++- .../kmath/expressions/SimpleAutoDiff.kt | 32 ++-- .../kscience/kmath/linear/LinearSpace.kt | 9 +- .../kscience/kmath/linear/LupDecomposition.kt | 26 ++- .../kscience/kmath/linear/MatrixBuilder.kt | 11 +- .../kscience/kmath/linear/MatrixWrapper.kt | 6 +- .../space/kscience/kmath/linear/Transposed.kt | 4 + .../kscience/kmath/linear/VirtualMatrix.kt | 6 +- .../kscience/kmath/linear/matrixAttributes.kt | 8 +- .../space/kscience/kmath/misc/sorting.kt | 17 +- .../space/kscience/kmath/nd/AlgebraND.kt | 4 +- .../kscience/kmath/nd/BufferAlgebraND.kt | 5 + .../space/kscience/kmath/nd/BufferND.kt | 11 +- .../kscience/kmath/nd/PermutedStructureND.kt | 4 + .../space/kscience/kmath/nd/Structure1D.kt | 15 +- .../space/kscience/kmath/nd/Structure2D.kt | 23 ++- .../space/kscience/kmath/nd/StructureND.kt | 105 ++++++------ .../kscience/kmath/nd/VirtualStructureND.kt | 7 +- .../space/kscience/kmath/nd/operationsND.kt | 4 +- .../kscience/kmath/operations/Algebra.kt | 26 ++- .../space/kscience/kmath/operations/BigInt.kt | 9 +- .../kmath/operations/BufferAlgebra.kt | 5 +- .../kscience/kmath/operations/LogicAlgebra.kt | 4 + .../kmath/operations/integerFields.kt | 12 +- .../kscience/kmath/operations/numbers.kt | 2 +- .../space/kscience/kmath/structures/Buffer.kt | 154 ++++++++++-------- .../kmath/structures/BufferAccessor2D.kt | 14 +- .../kmath/structures/FlaggedBuffer.kt | 5 + .../kmath/structures/Float32Buffer.kt | 5 + .../kmath/structures/Float64Buffer.kt | 5 + .../kscience/kmath/structures/Int16Buffer.kt | 4 + .../kscience/kmath/structures/Int32Buffer.kt | 5 + .../kscience/kmath/structures/Int64Buffer.kt | 5 + .../kscience/kmath/structures/Int8Buffer.kt | 5 + .../kscience/kmath/structures/ListBuffer.kt | 37 +++-- .../kmath/structures/MutableBuffer.kt | 60 +++---- .../kmath/linear/DoubleLUSolverTest.kt | 2 +- .../space/kscience/kmath/ejml/_generated.kt | 10 +- .../histogram/UniformHistogramGroupND.kt | 4 +- 44 files changed, 451 insertions(+), 269 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a1b23d817..ef5e0490c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,10 +3,11 @@ ## Unreleased ### Added -- Integer division algebras -- Float32 geometries -- New Attributes-kt module that could be used as stand-alone. It declares type-safe attributes containers. -- Explicit `mutableStructureND` builders for mutable structures +- Explicit `SafeType` for algebras and buffers. +- Integer division algebras. +- Float32 geometries. +- New Attributes-kt module that could be used as stand-alone. It declares. type-safe attributes containers. +- Explicit `mutableStructureND` builders for mutable structures. ### Changed - Default naming for algebra and buffers now uses IntXX/FloatXX notation instead of Java types. diff --git a/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/SafeType.kt b/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/SafeType.kt index 7225d1a6b..0fac2477c 100644 --- a/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/SafeType.kt +++ b/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/SafeType.kt @@ -16,7 +16,7 @@ import kotlin.reflect.typeOf * @param kType raw [KType] */ @JvmInline -public value class SafeType @PublishedApi internal constructor(public val kType: KType) +public value class SafeType @PublishedApi internal constructor(public val kType: KType) public inline fun safeTypeOf(): SafeType = SafeType(typeOf()) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYColumnarData.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYColumnarData.kt index 23cbcd0c7..b06ce1349 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYColumnarData.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYColumnarData.kt @@ -52,10 +52,10 @@ public interface XYColumnarData : ColumnarData { * Create two-column data from a list of row-objects (zero-copy) */ @UnstableKMathAPI - public fun ofList( + public inline fun ofList( list: List, - xConverter: (I) -> X, - yConverter: (I) -> Y, + noinline xConverter: (I) -> X, + noinline yConverter: (I) -> Y, ): XYColumnarData = object : XYColumnarData { override val size: Int get() = list.size diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt index 63f96b0af..817f38ff0 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt @@ -13,8 +13,6 @@ import space.kscience.kmath.structures.MutableBufferFactory import space.kscience.kmath.structures.asBuffer import kotlin.math.max import kotlin.math.min -import kotlin.reflect.KType -import kotlin.reflect.typeOf /** * Class representing both the value and the differentials of a function. @@ -84,6 +82,8 @@ public abstract class DSAlgebra>( bindings: Map, ) : ExpressionAlgebra>, SymbolIndexer { + override val bufferFactory: MutableBufferFactory> = MutableBufferFactory() + /** * Get the compiler for number of free parameters and order. * diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 8e37f3d32..5b4dcd638 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -6,6 +6,7 @@ package space.kscience.kmath.expressions import space.kscience.kmath.operations.* +import space.kscience.kmath.structures.MutableBufferFactory import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -17,6 +18,8 @@ import kotlin.contracts.contract public abstract class FunctionalExpressionAlgebra>( public val algebra: A, ) : ExpressionAlgebra> { + override val bufferFactory: MutableBufferFactory> = MutableBufferFactory>() + /** * Builds an Expression of constant expression that does not depend on arguments. */ @@ -49,6 +52,7 @@ public abstract class FunctionalExpressionAlgebra>( public open class FunctionalExpressionGroup>( algebra: A, ) : FunctionalExpressionAlgebra(algebra), Group> { + override val zero: Expression get() = const(algebra.zero) override fun Expression.unaryMinus(): Expression = diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt index c894cf00a..c568eece0 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt @@ -7,11 +7,14 @@ package space.kscience.kmath.expressions import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.operations.* +import space.kscience.kmath.structures.MutableBufferFactory /** * [Algebra] over [MST] nodes. */ public object MstNumericAlgebra : NumericAlgebra { + override val bufferFactory: MutableBufferFactory = MutableBufferFactory() + override fun number(value: Number): MST.Numeric = MST.Numeric(value) override fun bindSymbolOrNull(value: String): Symbol = StringSymbol(value) override fun bindSymbol(value: String): Symbol = bindSymbolOrNull(value) @@ -27,6 +30,9 @@ public object MstNumericAlgebra : NumericAlgebra { * [Group] over [MST] nodes. */ public object MstGroup : Group, NumericAlgebra, ScaleOperations { + + override val bufferFactory: MutableBufferFactory = MutableBufferFactory() + override val zero: MST.Numeric = number(0.0) override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value) @@ -57,7 +63,11 @@ public object MstGroup : Group, NumericAlgebra, ScaleOperations { @Suppress("OVERRIDE_BY_INLINE") @OptIn(UnstableKMathAPI::class) public object MstRing : Ring, NumbersAddOps, ScaleOperations { + + override val bufferFactory: MutableBufferFactory = MutableBufferFactory() + override inline val zero: MST.Numeric get() = MstGroup.zero + override val one: MST.Numeric = number(1.0) override fun number(value: Number): MST.Numeric = MstGroup.number(value) @@ -87,7 +97,11 @@ public object MstRing : Ring, NumbersAddOps, ScaleOperations { @Suppress("OVERRIDE_BY_INLINE") @OptIn(UnstableKMathAPI::class) public object MstField : Field, NumbersAddOps, ScaleOperations { + + override val bufferFactory: MutableBufferFactory = MutableBufferFactory() + override inline val zero: MST.Numeric get() = MstRing.zero + override inline val one: MST.Numeric get() = MstRing.one override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) @@ -117,7 +131,11 @@ public object MstField : Field, NumbersAddOps, ScaleOperations { */ @Suppress("OVERRIDE_BY_INLINE") public object MstExtendedField : ExtendedField, NumericAlgebra { + + override val bufferFactory: MutableBufferFactory = MutableBufferFactory() + override inline val zero: MST.Numeric get() = MstField.zero + override inline val one: MST.Numeric get() = MstField.one override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) @@ -164,6 +182,9 @@ public object MstExtendedField : ExtendedField, NumericAlgebra { */ @UnstableKMathAPI public object MstLogicAlgebra : LogicAlgebra { + + override val bufferFactory: MutableBufferFactory = MutableBufferFactory() + override fun bindSymbolOrNull(value: String): MST = super.bindSymbolOrNull(value) ?: StringSymbol(value) override fun const(boolean: Boolean): Symbol = if (boolean) { @@ -176,7 +197,7 @@ public object MstLogicAlgebra : LogicAlgebra { override fun MST.and(other: MST): MST = MST.Binary(Boolean::and.name, this, other) - override fun MST.or(other: MST): MST = MST.Binary(Boolean::or.name, this, other) + override fun MST.or(other: MST): MST = MST.Binary(Boolean::or.name, this, other) - override fun MST.xor(other: MST): MST = MST.Binary(Boolean::xor.name, this, other) + override fun MST.xor(other: MST): MST = MST.Binary(Boolean::xor.name, this, other) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt index 2bb5043b7..6d22b18dc 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -5,9 +5,11 @@ package space.kscience.kmath.expressions +import space.kscience.attributes.SafeType import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.linear.Point import space.kscience.kmath.operations.* +import space.kscience.kmath.structures.MutableBufferFactory import space.kscience.kmath.structures.asBuffer import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -30,9 +32,10 @@ public open class AutoDiffValue(public val value: T) */ public class DerivationResult( public val value: T, + override val type: SafeType, private val derivativeValues: Map, public val context: Field, -) { +) : WithType { /** * Returns derivative of [variable] or returns [Ring.zero] in [context]. */ @@ -49,19 +52,23 @@ public class DerivationResult( */ public fun DerivationResult.grad(vararg variables: Symbol): Point { check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" } - return variables.map(::derivative).asBuffer() + return variables.map(::derivative).asBuffer(type) } /** - * Represents field in context of which functions can be derived. + * Represents field. Function derivatives could be computed in this field */ @OptIn(UnstableKMathAPI::class) public open class SimpleAutoDiffField>( - public val context: F, + public val algebra: F, bindings: Map, ) : Field>, ExpressionAlgebra>, NumbersAddOps> { - override val zero: AutoDiffValue get() = const(context.zero) - override val one: AutoDiffValue get() = const(context.one) + + override val bufferFactory: MutableBufferFactory> = MutableBufferFactory>() + + override val zero: AutoDiffValue get() = const(algebra.zero) + + override val one: AutoDiffValue get() = const(algebra.one) // this stack contains pairs of blocks and values to apply them to private var stack: Array = arrayOfNulls(8) @@ -69,7 +76,7 @@ public open class SimpleAutoDiffField>( private val derivatives: MutableMap, T> = hashMapOf() private val bindings: Map> = bindings.entries.associate { - it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero) + it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, algebra.zero) } /** @@ -92,7 +99,7 @@ public open class SimpleAutoDiffField>( override fun bindSymbolOrNull(value: String): AutoDiffValue? = bindings[value] private fun getDerivative(variable: AutoDiffValue): T = - (variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero + (variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: algebra.zero private fun setDerivative(variable: AutoDiffValue, value: T) { if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value @@ -103,7 +110,7 @@ public open class SimpleAutoDiffField>( while (sp > 0) { val value = stack[--sp] val block = stack[--sp] as F.(Any?) -> Unit - context.block(value) + algebra.block(value) } } @@ -130,7 +137,6 @@ public open class SimpleAutoDiffField>( * } * ``` */ - @Suppress("UNCHECKED_CAST") public fun derive(value: R, block: F.(R) -> Unit): R { // save block to stack for backward pass if (sp >= stack.size) stack = stack.copyOf(stack.size * 2) @@ -142,9 +148,9 @@ public open class SimpleAutoDiffField>( internal fun differentiate(function: SimpleAutoDiffField.() -> AutoDiffValue): DerivationResult { val result = function() - result.d = context.one // computing derivative w.r.t result + result.d = algebra.one // computing derivative w.r.t result runBackwardPass() - return DerivationResult(result.value, bindings.mapValues { it.value.d }, context) + return DerivationResult(result.value, algebra.type, bindings.mapValues { it.value.d }, algebra) } // // Overloads for Double constants @@ -194,7 +200,7 @@ public open class SimpleAutoDiffField>( public inline fun > SimpleAutoDiffField.const(block: F.() -> T): AutoDiffValue { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return const(context.block()) + return const(algebra.block()) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt index 0780be29a..ddf4f17d4 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt @@ -10,10 +10,9 @@ import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.nd.* import space.kscience.kmath.operations.BufferRingOps import space.kscience.kmath.operations.Ring +import space.kscience.kmath.operations.WithType import space.kscience.kmath.operations.invoke import space.kscience.kmath.structures.Buffer -import kotlin.reflect.KType -import kotlin.reflect.typeOf /** * Alias for [Structure2D] with more familiar name. @@ -34,7 +33,7 @@ public typealias Point = Buffer * A marker interface for algebras that operate on matrices * @param T type of matrix element */ -public interface MatrixOperations +public interface MatrixOperations : WithType /** * Basic operations on matrices and vectors. @@ -45,6 +44,8 @@ public interface MatrixOperations public interface LinearSpace> : MatrixOperations { public val elementAlgebra: A + override val type: SafeType get() = elementAlgebra.type + /** * Produces a matrix with this context and given dimensions. */ @@ -212,4 +213,4 @@ public fun Matrix.asVector(): Point = * @receiver a buffer. * @return the new matrix. */ -public fun Point.asMatrix(): VirtualMatrix = VirtualMatrix(size, 1) { i, _ -> get(i) } \ No newline at end of file +public fun Point.asMatrix(): VirtualMatrix = VirtualMatrix(type, size, 1) { i, _ -> get(i) } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt index 9fe8397d8..1ac2ca5c6 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt @@ -22,11 +22,27 @@ import space.kscience.kmath.structures.* * @param l The lower triangular matrix in this decomposition. It may have [LowerTriangular]. * @param u The upper triangular matrix in this decomposition. It may have [UpperTriangular]. */ -public data class LupDecomposition( +public class LupDecomposition( + public val linearSpace: LinearSpace>, public val l: Matrix, public val u: Matrix, public val pivot: IntBuffer, -) +) { + public val elementAlgebra: Ring get() = linearSpace.elementAlgebra + + public val pivotMatrix: VirtualMatrix + get() = VirtualMatrix(linearSpace.type, l.rowNum, l.colNum) { row, column -> + if (column == pivot[row]) elementAlgebra.one else elementAlgebra.zero + } + + public val LupDecomposition.determinant by lazy { + elementAlgebra { (0 until l.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] } } + } + +} + + + public class LupDecompositionAttribute(type: SafeType>) : @@ -168,7 +184,7 @@ public fun > LinearSpace>.lup( for (row in col + 1 until m) lu[row, col] /= luDiag } - val l: MatrixWrapper = VirtualMatrix(rowNum, colNum) { i, j -> + val l: MatrixWrapper = VirtualMatrix(type, rowNum, colNum) { i, j -> when { j < i -> lu[i, j] j == i -> one @@ -176,7 +192,7 @@ public fun > LinearSpace>.lup( } }.withAttribute(LowerTriangular) - val u = VirtualMatrix(rowNum, colNum) { i, j -> + val u = VirtualMatrix(type, rowNum, colNum) { i, j -> if (j >= i) lu[i, j] else zero }.withAttribute(UpperTriangular) // @@ -184,7 +200,7 @@ public fun > LinearSpace>.lup( // if (j == pivot[i]) one else zero // }.withAttribute(Determinant, if (even) one else -one) - return LupDecomposition(l, u, pivot.asBuffer()) + return LupDecomposition(this@lup, l, u, pivot.asBuffer()) } } } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixBuilder.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixBuilder.kt index 7b6c36dfc..9543555e0 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixBuilder.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixBuilder.kt @@ -6,16 +6,21 @@ package space.kscience.kmath.linear import space.kscience.attributes.FlagAttribute +import space.kscience.attributes.SafeType import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.operations.Ring +import space.kscience.kmath.operations.WithType import space.kscience.kmath.structures.BufferAccessor2D -import space.kscience.kmath.structures.MutableBuffer +import space.kscience.kmath.structures.MutableBufferFactory public class MatrixBuilder>( public val linearSpace: LinearSpace, public val rows: Int, public val columns: Int, -) { +) : WithType { + + override val type: SafeType get() = linearSpace.type + public operator fun invoke(vararg elements: T): Matrix { require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" } return linearSpace.buildMatrix(rows, columns) { i, j -> elements[i * columns + j] } @@ -61,7 +66,7 @@ public fun > MatrixBuilder.symmetric( builder: (i: Int, j: Int) -> T, ): Matrix { require(columns == rows) { "In order to build symmetric matrix, number of rows $rows should be equal to number of columns $columns" } - return with(BufferAccessor2D(rows, rows, MutableBuffer.Companion::boxing)) { + return with(BufferAccessor2D(rows, rows, MutableBufferFactory(type))) { val cache = factory(rows * rows) { null } linearSpace.buildMatrix(rows, rows) { i, j -> val cached = cache[i, j] diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt index a129b04f0..9d89f7636 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt @@ -39,7 +39,7 @@ public fun > Matrix.withAttribute( attribute: A, attrValue: T, ): MatrixWrapper = if (this is MatrixWrapper) { - MatrixWrapper(origin, attributes.withAttribute(attribute,attrValue)) + MatrixWrapper(origin, attributes.withAttribute(attribute, attrValue)) } else { MatrixWrapper(this, Attributes(attribute, attrValue)) } @@ -68,7 +68,7 @@ public fun Matrix.modifyAttributes(modifier: (Attributes) -> Attrib public fun LinearSpace>.one( rows: Int, columns: Int, -): MatrixWrapper = VirtualMatrix(rows, columns) { i, j -> +): MatrixWrapper = VirtualMatrix(type, rows, columns) { i, j -> if (i == j) elementAlgebra.one else elementAlgebra.zero }.withAttribute(IsUnit) @@ -79,6 +79,6 @@ public fun LinearSpace>.one( public fun LinearSpace>.zero( rows: Int, columns: Int, -): MatrixWrapper = VirtualMatrix(rows, columns) { _, _ -> +): MatrixWrapper = VirtualMatrix(type, rows, columns) { _, _ -> elementAlgebra.zero }.withAttribute(IsZero) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/Transposed.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/Transposed.kt index 3e93a2a23..2cc566f92 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/Transposed.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/Transposed.kt @@ -6,10 +6,14 @@ package space.kscience.kmath.linear import space.kscience.attributes.Attributes +import space.kscience.attributes.SafeType public class TransposedMatrix(public val origin: Matrix) : Matrix { + override val type: SafeType get() = origin.type + override val rowNum: Int get() = origin.colNum + override val colNum: Int get() = origin.rowNum override fun get(i: Int, j: Int): T = origin[j, i] diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/VirtualMatrix.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/VirtualMatrix.kt index 97de82c7a..9cb7c5771 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/VirtualMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/VirtualMatrix.kt @@ -6,6 +6,7 @@ package space.kscience.kmath.linear import space.kscience.attributes.Attributes +import space.kscience.attributes.SafeType import space.kscience.kmath.nd.ShapeND @@ -14,7 +15,8 @@ import space.kscience.kmath.nd.ShapeND * * @property generator the function that provides elements. */ -public class VirtualMatrix( +public class VirtualMatrix( + override val type: SafeType, override val rowNum: Int, override val colNum: Int, override val attributes: Attributes = Attributes.EMPTY, @@ -29,4 +31,4 @@ public class VirtualMatrix( public fun MatrixBuilder.virtual( attributes: Attributes = Attributes.EMPTY, generator: (i: Int, j: Int) -> T, -): VirtualMatrix = VirtualMatrix(rows, columns, attributes, generator) +): VirtualMatrix = VirtualMatrix(type, rows, columns, attributes, generator) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/matrixAttributes.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/matrixAttributes.kt index 5bb35edef..8787d0e09 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/matrixAttributes.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/matrixAttributes.kt @@ -3,13 +3,11 @@ * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. */ -@file:OptIn(UnstableKMathAPI::class) @file:Suppress("UnusedReceiverParameter") package space.kscience.kmath.linear import space.kscience.attributes.* -import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.nd.StructureAttribute /** @@ -51,9 +49,11 @@ public val MatrixOperations.Inverted: Inverted get() = Inverted(safeTy * * @param T the type of matrices' items. */ -public class Determinant : MatrixAttribute +public class Determinant(type: SafeType) : + PolymorphicAttribute(type), + MatrixAttribute -public val MatrixOperations.Determinant: Determinant get() = Determinant() +public val MatrixOperations.Determinant: Determinant get() = Determinant(type) /** * Matrices with this feature are lower triangular ones. diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/sorting.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/sorting.kt index f369f0614..0c0ef82fb 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/sorting.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/sorting.kt @@ -4,6 +4,8 @@ */ +@file:OptIn(UnstableKMathAPI::class) + package space.kscience.kmath.misc import space.kscience.kmath.UnstableKMathAPI @@ -22,10 +24,9 @@ public fun > Buffer.indicesSorted(): IntArray = permSortInd /** * Create a zero-copy virtual buffer that contains the same elements but in ascending order */ -@OptIn(UnstableKMathAPI::class) public fun > Buffer.sorted(): Buffer { val permutations = indicesSorted() - return VirtualBuffer(size) { this[permutations[it]] } + return VirtualBuffer(type, size) { this[permutations[it]] } } @UnstableKMathAPI @@ -35,30 +36,27 @@ public fun > Buffer.indicesSortedDescending(): IntArray = /** * Create a zero-copy virtual buffer that contains the same elements but in descending order */ -@OptIn(UnstableKMathAPI::class) public fun > Buffer.sortedDescending(): Buffer { val permutations = indicesSortedDescending() - return VirtualBuffer(size) { this[permutations[it]] } + return VirtualBuffer(type, size) { this[permutations[it]] } } @UnstableKMathAPI public fun > Buffer.indicesSortedBy(selector: (V) -> C): IntArray = permSortIndicesWith(compareBy { selector(get(it)) }) -@OptIn(UnstableKMathAPI::class) public fun > Buffer.sortedBy(selector: (V) -> C): Buffer { val permutations = indicesSortedBy(selector) - return VirtualBuffer(size) { this[permutations[it]] } + return VirtualBuffer(type, size) { this[permutations[it]] } } @UnstableKMathAPI public fun > Buffer.indicesSortedByDescending(selector: (V) -> C): IntArray = permSortIndicesWith(compareByDescending { selector(get(it)) }) -@OptIn(UnstableKMathAPI::class) public fun > Buffer.sortedByDescending(selector: (V) -> C): Buffer { val permutations = indicesSortedByDescending(selector) - return VirtualBuffer(size) { this[permutations[it]] } + return VirtualBuffer(type, size) { this[permutations[it]] } } @UnstableKMathAPI @@ -68,10 +66,9 @@ public fun Buffer.indicesSortedWith(comparator: Comparator): IntArray /** * Create virtual zero-copy buffer with elements sorted by [comparator] */ -@OptIn(UnstableKMathAPI::class) public fun Buffer.sortedWith(comparator: Comparator): Buffer { val permutations = indicesSortedWith(comparator) - return VirtualBuffer(size) { this[permutations[it]] } + return VirtualBuffer(type,size) { this[permutations[it]] } } private fun Buffer.permSortIndicesWith(comparator: Comparator): IntArray { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt index 0611225c1..88aa52387 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt @@ -8,7 +8,7 @@ package space.kscience.kmath.nd import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.operations.* -import kotlin.reflect.KClass +import space.kscience.kmath.structures.MutableBufferFactory /** * The base interface for all ND-algebra implementations. @@ -91,6 +91,8 @@ public interface AlgebraND> : Algebra> { * @param A the type of group over structure elements. */ public interface GroupOpsND> : GroupOps>, AlgebraND { + override val bufferFactory: MutableBufferFactory> get() = MutableBufferFactory>() + /** * Element-wise addition. * diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt index 78bc83826..365d013f9 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt @@ -7,6 +7,8 @@ package space.kscience.kmath.nd +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.operations.* @@ -105,6 +107,9 @@ public open class BufferedGroupNDOps>( override val bufferAlgebra: BufferAlgebra, override val indexerBuilder: (ShapeND) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder, ) : GroupOpsND, BufferAlgebraND { + + override val type: SafeType> get() = safeTypeOf>() + override fun StructureND.unaryMinus(): StructureND = map { -it } } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt index 461a82e0a..1caf0c662 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.nd +import space.kscience.attributes.SafeType import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory @@ -23,6 +24,8 @@ public open class BufferND( public open val buffer: Buffer, ) : StructureND { + override val type: SafeType get() = buffer.type + @PerformancePitfall override operator fun get(index: IntArray): T = buffer[indices.offset(index)] @@ -36,7 +39,7 @@ public open class BufferND( */ public inline fun BufferND( shape: ShapeND, - bufferFactory: BufferFactory = BufferFactory.auto(), + bufferFactory: BufferFactory = BufferFactory(), crossinline initializer: (IntArray) -> T, ): BufferND { val strides = Strides(shape) @@ -84,10 +87,10 @@ public open class MutableBufferND( /** * Create a generic [BufferND] using provided [initializer] */ -public fun MutableBufferND( +public inline fun MutableBufferND( shape: ShapeND, - bufferFactory: MutableBufferFactory = MutableBufferFactory.boxing(), - initializer: (IntArray) -> T, + bufferFactory: MutableBufferFactory = MutableBufferFactory(), + crossinline initializer: (IntArray) -> T, ): MutableBufferND { val strides = Strides(shape) return MutableBufferND(strides, bufferFactory(strides.linearSize) { initializer(strides.index(it)) }) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/PermutedStructureND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/PermutedStructureND.kt index 6c35e2f44..56ccf2df8 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/PermutedStructureND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/PermutedStructureND.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.nd +import space.kscience.attributes.SafeType import space.kscience.kmath.PerformancePitfall @@ -13,6 +14,8 @@ public class PermutedStructureND( public val permutation: (IntArray) -> IntArray, ) : StructureND { + override val type: SafeType get() = origin.type + override val shape: ShapeND get() = origin.shape @@ -32,6 +35,7 @@ public class PermutedMutableStructureND( public val permutation: (IntArray) -> IntArray, ) : MutableStructureND { + override val type: SafeType get() = origin.type @OptIn(PerformancePitfall::class) override fun set(index: IntArray, value: T) { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure1D.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure1D.kt index 984b5ad0f..ee9eff8dc 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure1D.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure1D.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.nd +import space.kscience.attributes.SafeType import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.operations.asSequence import space.kscience.kmath.structures.Buffer @@ -46,6 +47,9 @@ public interface MutableStructure1D : Structure1D, MutableStructureND, */ @JvmInline private value class Structure1DWrapper(val structure: StructureND) : Structure1D { + + override val type: SafeType get() = structure.type + override val shape: ShapeND get() = structure.shape override val size: Int get() = structure.shape[0] @@ -60,6 +64,9 @@ private value class Structure1DWrapper(val structure: StructureND) : S * A 1D wrapper for a mutable nd-structure */ private class MutableStructure1DWrapper(val structure: MutableStructureND) : MutableStructure1D { + + override val type: SafeType get() = structure.type + override val shape: ShapeND get() = structure.shape override val size: Int get() = structure.shape[0] @@ -79,7 +86,7 @@ private class MutableStructure1DWrapper(val structure: MutableStructureND) .elements() .map(Pair::second) .toMutableList() - .asMutableBuffer() + .asMutableBuffer(type) override fun toString(): String = Buffer.toString(this) } @@ -90,6 +97,9 @@ private class MutableStructure1DWrapper(val structure: MutableStructureND) */ @JvmInline private value class Buffer1DWrapper(val buffer: Buffer) : Structure1D { + + override val type: SafeType get() = buffer.type + override val shape: ShapeND get() = ShapeND(buffer.size) override val size: Int get() = buffer.size @@ -102,6 +112,9 @@ private value class Buffer1DWrapper(val buffer: Buffer) : Structure1D< } internal class MutableBuffer1DWrapper(val buffer: MutableBuffer) : MutableStructure1D { + + override val type: SafeType get() = buffer.type + override val shape: ShapeND get() = ShapeND(buffer.size) override val size: Int get() = buffer.size diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure2D.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure2D.kt index eff58acc3..49e8c7dd7 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure2D.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure2D.kt @@ -6,13 +6,12 @@ package space.kscience.kmath.nd import space.kscience.attributes.Attributes +import space.kscience.attributes.SafeType import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.MutableBuffer -import space.kscience.kmath.structures.MutableListBuffer import space.kscience.kmath.structures.VirtualBuffer import kotlin.jvm.JvmInline -import kotlin.reflect.KClass /** * A structure that is guaranteed to be two-dimensional. @@ -33,18 +32,18 @@ public interface Structure2D : StructureND { override val shape: ShapeND get() = ShapeND(rowNum, colNum) /** - * The buffer of rows of this structure. It gets elements from the structure dynamically. + * The buffer of rows for this structure. It gets elements from the structure dynamically. */ @PerformancePitfall public val rows: List> - get() = List(rowNum) { i -> VirtualBuffer(colNum) { j -> get(i, j) } } + get() = List(rowNum) { i -> VirtualBuffer(type, colNum) { j -> get(i, j) } } /** - * The buffer of columns of this structure. It gets elements from the structure dynamically. + * The buffer of columns for this structure. It gets elements from the structure dynamically. */ @PerformancePitfall public val columns: List> - get() = List(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } } + get() = List(colNum) { j -> VirtualBuffer(type, rowNum) { i -> get(i, j) } } /** * Retrieves an element from the structure by two indices. @@ -88,14 +87,14 @@ public interface MutableStructure2D : Structure2D, MutableStructureND { */ @PerformancePitfall override val rows: List> - get() = List(rowNum) { i -> MutableListBuffer(colNum) { j -> get(i, j) } } + get() = List(rowNum) { i -> MutableBuffer(type, colNum) { j -> get(i, j) } } /** - * The buffer of columns of this structure. It gets elements from the structure dynamically. + * The buffer of columns for this structure. It gets elements from the structure dynamically. */ @PerformancePitfall override val columns: List> - get() = List(colNum) { j -> MutableListBuffer(rowNum) { i -> get(i, j) } } + get() = List(colNum) { j -> MutableBuffer(type, rowNum) { i -> get(i, j) } } } /** @@ -103,6 +102,9 @@ public interface MutableStructure2D : Structure2D, MutableStructureND { */ @JvmInline private value class Structure2DWrapper(val structure: StructureND) : Structure2D { + + override val type: SafeType get() = structure.type + override val shape: ShapeND get() = structure.shape override val rowNum: Int get() = shape[0] @@ -122,6 +124,9 @@ private value class Structure2DWrapper(val structure: StructureND) : S * A 2D wrapper for a mutable nd-structure */ private class MutableStructure2DWrapper(val structure: MutableStructureND) : MutableStructure2D { + + override val type: SafeType get() = structure.type + override val shape: ShapeND get() = structure.shape override val rowNum: Int get() = shape[0] diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt index bc86f62ae..d0d078609 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt @@ -12,6 +12,7 @@ import space.kscience.attributes.SafeType import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.linear.LinearSpace import space.kscience.kmath.operations.Ring +import space.kscience.kmath.operations.WithType import space.kscience.kmath.operations.invoke import space.kscience.kmath.structures.Buffer import kotlin.jvm.JvmName @@ -28,9 +29,9 @@ public interface StructureAttribute : Attribute * * @param T the type of items. */ -public interface StructureND : AttributeContainer, WithShape { +public interface StructureND : AttributeContainer, WithShape, WithType { /** - * The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions of + * The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions for * this structure. */ override val shape: ShapeND @@ -117,57 +118,61 @@ public interface StructureND : AttributeContainer, WithShape { return "$className(shape=${structure.shape}, buffer=$bufferRepr)" } - /** - * Creates a NDStructure with explicit buffer factory. - * - * Strides should be reused if possible. - */ - public fun buffered( - strides: Strides, - initializer: (IntArray) -> T, - ): BufferND = BufferND(strides, Buffer.boxing(strides.linearSize) { i -> initializer(strides.index(i)) }) - - - public fun buffered( - shape: ShapeND, - initializer: (IntArray) -> T, - ): BufferND = buffered(ColumnStrides(shape), initializer) - - /** - * Inline create NDStructure with non-boxing buffer implementation if it is possible - */ - public inline fun auto( - strides: Strides, - crossinline initializer: (IntArray) -> T, - ): BufferND = BufferND(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) }) - - public inline fun auto( - type: SafeType, - strides: Strides, - crossinline initializer: (IntArray) -> T, - ): BufferND = BufferND(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) }) - - - public inline fun auto( - shape: ShapeND, - crossinline initializer: (IntArray) -> T, - ): BufferND = auto(ColumnStrides(shape), initializer) - - @JvmName("autoVarArg") - public inline fun auto( - vararg shape: Int, - crossinline initializer: (IntArray) -> T, - ): BufferND = - auto(ColumnStrides(ShapeND(shape)), initializer) - - public inline fun auto( - type: SafeType, - vararg shape: Int, - crossinline initializer: (IntArray) -> T, - ): BufferND = auto(type, ColumnStrides(ShapeND(shape)), initializer) } } + +/** + * Creates a NDStructure with explicit buffer factory. + * + * Strides should be reused if possible. + */ +public fun BufferND( + type: SafeType, + strides: Strides, + initializer: (IntArray) -> T, +): BufferND = BufferND(strides, Buffer(type, strides.linearSize) { i -> initializer(strides.index(i)) }) + + +public fun BufferND( + type: SafeType, + shape: ShapeND, + initializer: (IntArray) -> T, +): BufferND = BufferND(type, ColumnStrides(shape), initializer) + +/** + * Inline create NDStructure with non-boxing buffer implementation if it is possible + */ +public inline fun BufferND( + strides: Strides, + crossinline initializer: (IntArray) -> T, +): BufferND = BufferND(strides, Buffer(strides.linearSize) { i -> initializer(strides.index(i)) }) + +public inline fun BufferND( + type: SafeType, + strides: Strides, + crossinline initializer: (IntArray) -> T, +): BufferND = BufferND(strides, Buffer(type, strides.linearSize) { i -> initializer(strides.index(i)) }) + + +public inline fun BufferND( + shape: ShapeND, + crossinline initializer: (IntArray) -> T, +): BufferND = BufferND(ColumnStrides(shape), initializer) + +@JvmName("autoVarArg") +public inline fun BufferND( + vararg shape: Int, + crossinline initializer: (IntArray) -> T, +): BufferND = BufferND(ColumnStrides(ShapeND(shape)), initializer) + +public inline fun BufferND( + type: SafeType, + vararg shape: Int, + crossinline initializer: (IntArray) -> T, +): BufferND = BufferND(type, ColumnStrides(ShapeND(shape)), initializer) + + /** * Indicates whether some [StructureND] is equal to another one. */ diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/VirtualStructureND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/VirtualStructureND.kt index 606b9a631..824d82b06 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/VirtualStructureND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/VirtualStructureND.kt @@ -5,10 +5,13 @@ package space.kscience.kmath.nd +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.UnstableKMathAPI public open class VirtualStructureND( + override val type: SafeType, override val shape: ShapeND, public val producer: (IntArray) -> T, ) : StructureND { @@ -24,10 +27,10 @@ public open class VirtualStructureND( public class VirtualDoubleStructureND( shape: ShapeND, producer: (IntArray) -> Double, -) : VirtualStructureND(shape, producer) +) : VirtualStructureND(safeTypeOf(), shape, producer) @UnstableKMathAPI public class VirtualIntStructureND( shape: ShapeND, producer: (IntArray) -> Int, -) : VirtualStructureND(shape, producer) \ No newline at end of file +) : VirtualStructureND(safeTypeOf(), shape, producer) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/operationsND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/operationsND.kt index 40db5187f..0376da116 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/operationsND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/operationsND.kt @@ -10,7 +10,7 @@ import space.kscience.kmath.PerformancePitfall @OptIn(PerformancePitfall::class) public fun StructureND.roll(axis: Int, step: Int = 1): StructureND { require(axis in shape.indices) { "Axis $axis is outside of shape dimensions: [0, ${shape.size})" } - return VirtualStructureND(shape) { index -> + return VirtualStructureND(type, shape) { index -> val newIndex: IntArray = IntArray(index.size) { indexAxis -> if (indexAxis == axis) { (index[indexAxis] + step).mod(shape[indexAxis]) @@ -26,7 +26,7 @@ public fun StructureND.roll(axis: Int, step: Int = 1): StructureND { public fun StructureND.roll(pair: Pair, vararg others: Pair): StructureND { val axisMap: Map = mapOf(pair, *others) require(axisMap.keys.all { it in shape.indices }) { "Some of axes ${axisMap.keys} is outside of shape dimensions: [0, ${shape.size})" } - return VirtualStructureND(shape) { index -> + return VirtualStructureND(type, shape) { index -> val newIndex: IntArray = IntArray(index.size) { indexAxis -> val offset = axisMap[indexAxis] ?: 0 (index[indexAxis] + offset).mod(shape[indexAxis]) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt index ca8a4f59d..44e9a0b17 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt @@ -5,22 +5,32 @@ package space.kscience.kmath.operations +import space.kscience.attributes.SafeType import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.operations.Ring.Companion.optimizedPower import space.kscience.kmath.structures.MutableBufferFactory -import kotlin.reflect.KType + +/** + * An interface containing [type] for dynamic type checking. + */ +public interface WithType { + public val type: SafeType +} + /** * Represents an algebraic structure. * * @param T the type of element which Algebra operates on. */ -public interface Algebra { +public interface Algebra : WithType { /** * Provide a factory for buffers, associated with this [Algebra] */ - public val bufferFactory: MutableBufferFactory get() = MutableBufferFactory.boxing() + public val bufferFactory: MutableBufferFactory + + override val type: SafeType get() = bufferFactory.type /** * Wraps a raw string to [T] object. This method is designed for three purposes: @@ -265,7 +275,7 @@ public interface Ring : Group, RingOps { */ public fun power(arg: T, pow: UInt): T = optimizedPower(arg, pow) - public companion object{ + public companion object { /** * Raises [arg] to the non-negative integer power [exponent]. * @@ -348,7 +358,7 @@ public interface Field : Ring, FieldOps, ScaleOperations, NumericAlg public fun power(arg: T, pow: Int): T = optimizedPower(arg, pow) - public companion object{ + public companion object { /** * Raises [arg] to the integer power [exponent]. * @@ -361,7 +371,11 @@ public interface Field : Ring, FieldOps, ScaleOperations, NumericAlg * @author Iaroslav Postovalov, Evgeniy Zhelenskiy */ private fun Field.optimizedPower(arg: T, exponent: Int): T = when { - exponent < 0 -> one / (this as Ring).optimizedPower(arg, if (exponent == Int.MIN_VALUE) Int.MAX_VALUE.toUInt().inc() else (-exponent).toUInt()) + exponent < 0 -> one / (this as Ring).optimizedPower( + arg, + if (exponent == Int.MIN_VALUE) Int.MAX_VALUE.toUInt().inc() else (-exponent).toUInt() + ) + else -> (this as Ring).optimizedPower(arg, exponent.toUInt()) } } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt index 34a6d4a80..25554477c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.operations +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.nd.BufferedRingOpsND import space.kscience.kmath.operations.BigInt.Companion.BASE @@ -26,6 +28,9 @@ private typealias TBase = ULong */ @OptIn(UnstableKMathAPI::class) public object BigIntField : Field, NumbersAddOps, ScaleOperations { + + override val type: SafeType = safeTypeOf() + override val zero: BigInt = BigInt.ZERO override val one: BigInt = BigInt.ONE @@ -528,10 +533,10 @@ public fun String.parseBigInteger(): BigInt? { public val BigInt.algebra: BigIntField get() = BigIntField public inline fun BigInt.Companion.buffer(size: Int, initializer: (Int) -> BigInt): Buffer = - Buffer.boxing(size, initializer) + Buffer(size, initializer) public inline fun BigInt.Companion.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer = - Buffer.boxing(size, initializer) + Buffer(size, initializer) public val BigIntField.nd: BufferedRingOpsND get() = BufferedRingOpsND(BufferRingOps(BigIntField)) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt index 526d024ca..5a8b26571 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt @@ -5,10 +5,11 @@ package space.kscience.kmath.operations +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.MutableBuffer import space.kscience.kmath.structures.MutableBufferFactory -import kotlin.reflect.KType public interface WithSize { public val size: Int @@ -20,6 +21,8 @@ public interface WithSize { public interface BufferAlgebra> : Algebra> { public val elementAlgebra: A + override val type: SafeType> get() = safeTypeOf>() + public val elementBufferFactory: MutableBufferFactory get() = elementAlgebra.bufferFactory public fun buffer(size: Int, vararg elements: T): Buffer { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/LogicAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/LogicAlgebra.kt index d8bf0fb57..312fd1396 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/LogicAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/LogicAlgebra.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.operations +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.expressions.Symbol @@ -72,6 +74,8 @@ public interface LogicAlgebra : Algebra { @Suppress("EXTENSION_SHADOWED_BY_MEMBER") public object BooleanAlgebra : LogicAlgebra { + override val type: SafeType get() = safeTypeOf() + override fun const(boolean: Boolean): Boolean = boolean override fun Boolean.not(): Boolean = !this diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/integerFields.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/integerFields.kt index 72e33c523..a3a819a9b 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/integerFields.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/integerFields.kt @@ -8,7 +8,10 @@ package space.kscience.kmath.operations import space.kscience.kmath.operations.Int16Field.div import space.kscience.kmath.operations.Int32Field.div import space.kscience.kmath.operations.Int64Field.div -import space.kscience.kmath.structures.* +import space.kscience.kmath.structures.Int16 +import space.kscience.kmath.structures.Int32 +import space.kscience.kmath.structures.Int64 +import space.kscience.kmath.structures.MutableBufferFactory import kotlin.math.roundToInt import kotlin.math.roundToLong @@ -22,7 +25,7 @@ import kotlin.math.roundToLong */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER") public object Int16Field : Field, Norm, NumericAlgebra { - override val bufferFactory: MutableBufferFactory = MutableBufferFactory(::Int16Buffer) + override val bufferFactory: MutableBufferFactory = MutableBufferFactory() override val zero: Int16 get() = 0 override val one: Int16 get() = 1 @@ -45,7 +48,8 @@ public object Int16Field : Field, Norm, NumericAlgebra, Norm, NumericAlgebra { - override val bufferFactory: MutableBufferFactory = MutableBufferFactory(::Int32Buffer) + override val bufferFactory: MutableBufferFactory = MutableBufferFactory() + override val zero: Int get() = 0 override val one: Int get() = 1 @@ -68,7 +72,7 @@ public object Int32Field : Field, Norm, NumericAlgebra, Norm, NumericAlgebra { - override val bufferFactory: MutableBufferFactory = MutableBufferFactory(::Int64Buffer) + override val bufferFactory: MutableBufferFactory = MutableBufferFactory() override val zero: Int64 get() = 0L override val one: Int64 get() = 1L diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt index 0b7bef852..7d8e12139 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt @@ -67,7 +67,7 @@ public interface ExtendedField : ExtendedFieldOps, Field, NumericAlgebr */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER") public object Float64Field : ExtendedField, Norm, ScaleOperations { - override val bufferFactory: MutableBufferFactory = MutableBufferFactory(::Float64Buffer) + override val bufferFactory: MutableBufferFactory = MutableBufferFactory() override val zero: Double get() = 0.0 override val one: Double get() = 1.0 diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt index 91baac2bf..c3dbfbfe2 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt @@ -8,8 +8,8 @@ package space.kscience.kmath.structures import space.kscience.attributes.SafeType import space.kscience.attributes.safeTypeOf import space.kscience.kmath.operations.WithSize +import space.kscience.kmath.operations.WithType import space.kscience.kmath.operations.asSequence -import kotlin.jvm.JvmInline import kotlin.reflect.typeOf /** @@ -17,35 +17,46 @@ import kotlin.reflect.typeOf * * @param T the type of buffer. */ -public fun interface BufferFactory { +public interface BufferFactory : WithType { + public operator fun invoke(size: Int, builder: (Int) -> T): Buffer - - public companion object{ - public inline fun auto(): BufferFactory = - BufferFactory(Buffer.Companion::auto) - - public fun boxing(): BufferFactory = - BufferFactory(Buffer.Companion::boxing) - } } +/** + * Create a [BufferFactory] for given [type], using primitive storage if possible + */ +public fun BufferFactory(type: SafeType): BufferFactory = object : BufferFactory { + override val type: SafeType = type + override fun invoke(size: Int, builder: (Int) -> T): Buffer = Buffer(type, size, builder) +} + +/** + * Create [BufferFactory] using the reified type + */ +public inline fun BufferFactory(): BufferFactory = BufferFactory(safeTypeOf()) + /** * Function that produces [MutableBuffer] from its size and function that supplies values. * * @param T the type of buffer. */ -public fun interface MutableBufferFactory : BufferFactory { +public interface MutableBufferFactory : BufferFactory { override fun invoke(size: Int, builder: (Int) -> T): MutableBuffer - - public companion object { - public inline fun auto(): MutableBufferFactory = - MutableBufferFactory(MutableBuffer.Companion::auto) - - public fun boxing(): MutableBufferFactory = - MutableBufferFactory(MutableBuffer.Companion::boxing) - } } +/** + * Create a [MutableBufferFactory] for given [type], using primitive storage if possible + */ +public fun MutableBufferFactory(type: SafeType): MutableBufferFactory = object : MutableBufferFactory { + override val type: SafeType = type + override fun invoke(size: Int, builder: (Int) -> T): MutableBuffer = MutableBuffer(type, size, builder) +} + +/** + * Create [BufferFactory] using the reified type + */ +public inline fun MutableBufferFactory(): MutableBufferFactory = MutableBufferFactory(safeTypeOf()) + /** * A generic read-only random-access structure for both primitives and objects. * @@ -53,14 +64,14 @@ public fun interface MutableBufferFactory : BufferFactory { * * @param T the type of elements contained in the buffer. */ -public interface Buffer : WithSize { +public interface Buffer : WithSize, WithType { /** * The size of this buffer. */ override val size: Int /** - * Gets element at given index. + * Gets an element at given index. */ public operator fun get(index: Int): T @@ -87,41 +98,48 @@ public interface Buffer : WithSize { return true } - /** - * Creates a [ListBuffer] of given type [T] with given [size]. Each element is calculated by calling the - * specified [initializer] function. - */ - public inline fun boxing(size: Int, initializer: (Int) -> T): Buffer = - List(size, initializer).asBuffer() - - /** - * Creates a [Buffer] of given [type]. If the type is primitive, specialized buffers are used ([Int32Buffer], - * [Float64Buffer], etc.), [ListBuffer] is returned otherwise. - * - * The [size] is specified, and each element is calculated by calling the specified [initializer] function. - */ - @Suppress("UNCHECKED_CAST") - public inline fun auto(type: SafeType, size: Int, initializer: (Int) -> T): Buffer = - when (type.kType) { - typeOf() -> MutableBuffer.double(size) { initializer(it) as Double } as Buffer - typeOf() -> MutableBuffer.short(size) { initializer(it) as Short } as Buffer - typeOf() -> MutableBuffer.int(size) { initializer(it) as Int } as Buffer - typeOf() -> MutableBuffer.long(size) { initializer(it) as Long } as Buffer - typeOf() -> MutableBuffer.float(size) { initializer(it) as Float } as Buffer - else -> boxing(size, initializer) - } - - /** - * Creates a [Buffer] of given type [T]. If the type is primitive, specialized buffers are used ([Int32Buffer], - * [Float64Buffer], etc.), [ListBuffer] is returned otherwise. - * - * The [size] is specified, and each element is calculated by calling the specified [initializer] function. - */ - public inline fun auto(size: Int, initializer: (Int) -> T): Buffer = - auto(safeTypeOf(), size, initializer) } } +/** + * Creates a [Buffer] of given type [T]. If the type is primitive, specialized buffers are used ([Int32Buffer], + * [Float64Buffer], etc.), [ListBuffer] is returned otherwise. + * + * The [size] is specified, and each element is calculated by calling the specified [initializer] function. + */ +@Suppress("UNCHECKED_CAST") +public inline fun Buffer(size: Int, initializer: (Int) -> T): Buffer { + val type = safeTypeOf() + return when (type.kType) { + typeOf() -> MutableBuffer.double(size) { initializer(it) as Double } as Buffer + typeOf() -> MutableBuffer.short(size) { initializer(it) as Short } as Buffer + typeOf() -> MutableBuffer.int(size) { initializer(it) as Int } as Buffer + typeOf() -> MutableBuffer.long(size) { initializer(it) as Long } as Buffer + typeOf() -> MutableBuffer.float(size) { initializer(it) as Float } as Buffer + else -> List(size, initializer).asBuffer(type) + } +} + +/** + * Creates a [Buffer] of given [type]. If the type is primitive, specialized buffers are used ([Int32Buffer], + * [Float64Buffer], etc.), [ListBuffer] is returned otherwise. + * + * The [size] is specified, and each element is calculated by calling the specified [initializer] function. + */ +@Suppress("UNCHECKED_CAST") +public fun Buffer( + type: SafeType, + size: Int, + initializer: (Int) -> T, +): Buffer = when (type.kType) { + typeOf() -> MutableBuffer.double(size) { initializer(it) as Double } as Buffer + typeOf() -> MutableBuffer.short(size) { initializer(it) as Short } as Buffer + typeOf() -> MutableBuffer.int(size) { initializer(it) as Int } as Buffer + typeOf() -> MutableBuffer.long(size) { initializer(it) as Long } as Buffer + typeOf() -> MutableBuffer.float(size) { initializer(it) as Float } as Buffer + else -> List(size, initializer).asBuffer(type) +} + /** * Returns an [IntRange] of the valid indices for this [Buffer]. */ @@ -144,28 +162,17 @@ public fun Buffer.last(): T { return get(size - 1) } -/** - * Immutable wrapper for [MutableBuffer]. - * - * @param T the type of elements contained in the buffer. - * @property buffer The underlying buffer. - */ -@JvmInline -public value class ReadOnlyBuffer(public val buffer: MutableBuffer) : Buffer { - override val size: Int get() = buffer.size - - override operator fun get(index: Int): T = buffer[index] - - override operator fun iterator(): Iterator = buffer.iterator() -} - /** * A buffer with content calculated on-demand. The calculated content is not stored, so it is recalculated on each call. - * Useful when one needs single element from the buffer. + * Useful when one needs a single element from the buffer. * * @param T the type of elements provided by the buffer. */ -public class VirtualBuffer(override val size: Int, private val generator: (Int) -> T) : Buffer { +public class VirtualBuffer( + override val type: SafeType, + override val size: Int, + private val generator: (Int) -> T, +) : Buffer { override operator fun get(index: Int): T { if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index") return generator(index) @@ -177,6 +184,9 @@ public class VirtualBuffer(override val size: Int, private val generator: } /** - * Convert this buffer to read-only buffer. + * Inline builder for [VirtualBuffer] */ -public fun Buffer.asReadOnly(): Buffer = if (this is MutableBuffer) ReadOnlyBuffer(this) else this \ No newline at end of file +public inline fun VirtualBuffer( + size: Int, + noinline generator: (Int) -> T, +): VirtualBuffer = VirtualBuffer(safeTypeOf(), size, generator) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/BufferAccessor2D.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/BufferAccessor2D.kt index 3edd7d3b1..a2f01137c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/BufferAccessor2D.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/BufferAccessor2D.kt @@ -5,7 +5,13 @@ package space.kscience.kmath.structures -import space.kscience.kmath.nd.* +import space.kscience.attributes.SafeType +import space.kscience.kmath.nd.BufferND +import space.kscience.kmath.nd.ShapeND +import space.kscience.kmath.nd.Structure2D +import space.kscience.kmath.nd.as2D +import kotlin.collections.component1 +import kotlin.collections.component2 /** * A context that allows to operate on a [MutableBuffer] as on 2d array @@ -27,11 +33,13 @@ internal class BufferAccessor2D( fun create(mat: Structure2D): MutableBuffer = create { i, j -> mat[i, j] } //TODO optimize wrapper - fun MutableBuffer.toStructure2D(): Structure2D = StructureND.buffered( - ColumnStrides(ShapeND(rowNum, colNum)) + fun MutableBuffer.toStructure2D(): Structure2D = BufferND( + type, ShapeND(rowNum, colNum) ) { (i, j) -> get(i, j) }.as2D() inner class Row(val buffer: MutableBuffer, val rowIndex: Int) : MutableBuffer { + override val type: SafeType get() = buffer.type + override val size: Int get() = colNum override operator fun get(index: Int): T = buffer[rowIndex, index] diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/FlaggedBuffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/FlaggedBuffer.kt index d99e02996..10419a074 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/FlaggedBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/FlaggedBuffer.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.structures +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import kotlin.experimental.and /** @@ -57,6 +59,9 @@ public class FlaggedDoubleBuffer( public val values: DoubleArray, public val flags: ByteArray ) : FlaggedBuffer, Buffer { + + override val type: SafeType = safeTypeOf() + init { require(values.size == flags.size) { "Values and flags must have the same dimensions" } } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Float32Buffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Float32Buffer.kt index 44ef4dcf0..b8fbe3436 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Float32Buffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Float32Buffer.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.structures +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import kotlin.jvm.JvmInline /** @@ -15,6 +17,9 @@ import kotlin.jvm.JvmInline */ @JvmInline public value class Float32Buffer(public val array: FloatArray) : PrimitiveBuffer { + + override val type: SafeType get() = safeTypeOf() + override val size: Int get() = array.size override operator fun get(index: Int): Float = array[index] diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Float64Buffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Float64Buffer.kt index 0542c1bf4..5837adf83 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Float64Buffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Float64Buffer.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.structures +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import space.kscience.kmath.operations.BufferTransform import kotlin.jvm.JvmInline @@ -15,6 +17,9 @@ import kotlin.jvm.JvmInline */ @JvmInline public value class Float64Buffer(public val array: DoubleArray) : PrimitiveBuffer { + + override val type: SafeType get() = safeTypeOf() + override val size: Int get() = array.size override operator fun get(index: Int): Double = array[index] diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int16Buffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int16Buffer.kt index 1ba40c934..cf3af61a7 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int16Buffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int16Buffer.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.structures +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import kotlin.jvm.JvmInline /** @@ -14,6 +16,8 @@ import kotlin.jvm.JvmInline */ @JvmInline public value class Int16Buffer(public val array: ShortArray) : MutableBuffer { + + override val type: SafeType get() = safeTypeOf() override val size: Int get() = array.size override operator fun get(index: Int): Short = array[index] diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int32Buffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int32Buffer.kt index afd94e9cd..365ec8688 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int32Buffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int32Buffer.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.structures +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import kotlin.jvm.JvmInline /** @@ -14,6 +16,9 @@ import kotlin.jvm.JvmInline */ @JvmInline public value class Int32Buffer(public val array: IntArray) : PrimitiveBuffer { + + override val type: SafeType get() = safeTypeOf() + override val size: Int get() = array.size override operator fun get(index: Int): Int = array[index] diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int64Buffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int64Buffer.kt index c67d109aa..5f831ed3e 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int64Buffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int64Buffer.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.structures +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import kotlin.jvm.JvmInline /** @@ -14,6 +16,9 @@ import kotlin.jvm.JvmInline */ @JvmInline public value class Int64Buffer(public val array: LongArray) : PrimitiveBuffer { + + override val type: SafeType get() = safeTypeOf() + override val size: Int get() = array.size override operator fun get(index: Int): Long = array[index] diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int8Buffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int8Buffer.kt index 923fbec06..e61056979 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int8Buffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Int8Buffer.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.structures +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import kotlin.jvm.JvmInline /** @@ -14,6 +16,9 @@ import kotlin.jvm.JvmInline */ @JvmInline public value class Int8Buffer(public val array: ByteArray) : MutableBuffer { + + override val type: SafeType get() = safeTypeOf() + override val size: Int get() = array.size override operator fun get(index: Int): Byte = array[index] diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/ListBuffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/ListBuffer.kt index fbc9a489b..ad065a0c3 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/ListBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/ListBuffer.kt @@ -5,7 +5,8 @@ package space.kscience.kmath.structures -import kotlin.jvm.JvmInline +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf /** * [Buffer] implementation over [List]. @@ -13,9 +14,7 @@ import kotlin.jvm.JvmInline * @param T the type of elements contained in the buffer. * @property list The underlying list. */ -public class ListBuffer(public val list: List) : Buffer { - - public constructor(size: Int, initializer: (Int) -> T) : this(List(size, initializer)) +public class ListBuffer(override val type: SafeType, public val list: List) : Buffer { override val size: Int get() = list.size @@ -29,7 +28,12 @@ public class ListBuffer(public val list: List) : Buffer { /** * Returns an [ListBuffer] that wraps the original list. */ -public fun List.asBuffer(): ListBuffer = ListBuffer(this) +public fun List.asBuffer(type: SafeType): ListBuffer = ListBuffer(type, this) + +/** + * Returns an [ListBuffer] that wraps the original list. + */ +public inline fun List.asBuffer(): ListBuffer = asBuffer(safeTypeOf()) /** * [MutableBuffer] implementation over [MutableList]. @@ -37,10 +41,7 @@ public fun List.asBuffer(): ListBuffer = ListBuffer(this) * @param T the type of elements contained in the buffer. * @property list The underlying list. */ -@JvmInline -public value class MutableListBuffer(public val list: MutableList) : MutableBuffer { - - public constructor(size: Int, initializer: (Int) -> T) : this(MutableList(size, initializer)) +public class MutableListBuffer(override val type: SafeType, public val list: MutableList) : MutableBuffer { override val size: Int get() = list.size @@ -51,10 +52,24 @@ public value class MutableListBuffer(public val list: MutableList) : Mutab } override operator fun iterator(): Iterator = list.iterator() - override fun copy(): MutableBuffer = MutableListBuffer(ArrayList(list)) + override fun copy(): MutableBuffer = MutableListBuffer(type, ArrayList(list)) + + override fun toString(): String = Buffer.toString(this) } + /** * Returns an [MutableListBuffer] that wraps the original list. */ -public fun MutableList.asMutableBuffer(): MutableListBuffer = MutableListBuffer(this) +public fun MutableList.asMutableBuffer(type: SafeType): MutableListBuffer = MutableListBuffer( + type, + this +) + +/** + * Returns an [MutableListBuffer] that wraps the original list. + */ +public inline fun MutableList.asMutableBuffer(): MutableListBuffer = MutableListBuffer( + safeTypeOf(), + this +) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/MutableBuffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/MutableBuffer.kt index eb0074734..21d3f9c6f 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/MutableBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/MutableBuffer.kt @@ -61,41 +61,35 @@ public interface MutableBuffer : Buffer { */ public inline fun float(size: Int, initializer: (Int) -> Float): Float32Buffer = Float32Buffer(size, initializer) - - - /** - * Create a boxing mutable buffer of given type - */ - public inline fun boxing(size: Int, initializer: (Int) -> T): MutableBuffer = - MutableListBuffer(MutableList(size, initializer)) - - /** - * Creates a [MutableBuffer] of given [type]. If the type is primitive, specialized buffers are used - * ([Int32Buffer], [Float64Buffer], etc.), [ListBuffer] is returned otherwise. - * - * The [size] is specified, and each element is calculated by calling the specified [initializer] function. - */ - @Suppress("UNCHECKED_CAST") - public inline fun auto(type: SafeType, size: Int, initializer: (Int) -> T): MutableBuffer = - when (type.kType) { - typeOf() -> double(size) { initializer(it) as Double } as MutableBuffer - typeOf() -> short(size) { initializer(it) as Short } as MutableBuffer - typeOf() -> int(size) { initializer(it) as Int } as MutableBuffer - typeOf() -> float(size) { initializer(it) as Float } as MutableBuffer - typeOf() -> long(size) { initializer(it) as Long } as MutableBuffer - else -> boxing(size, initializer) - } - - /** - * Creates a [MutableBuffer] of given type [T]. If the type is primitive, specialized buffers are used - * ([Int32Buffer], [Float64Buffer], etc.), [ListBuffer] is returned otherwise. - * - * The [size] is specified, and each element is calculated by calling the specified [initializer] function. - */ - public inline fun auto(size: Int, initializer: (Int) -> T): MutableBuffer = - auto(safeTypeOf(), size, initializer) } } +/** + * Creates a [MutableBuffer] of given [type]. If the type is primitive, specialized buffers are used + * ([Int32Buffer], [Float64Buffer], etc.), [ListBuffer] is returned otherwise. + * + * The [size] is specified, and each element is calculated by calling the specified [initializer] function. + */ +@Suppress("UNCHECKED_CAST") +public inline fun MutableBuffer(type: SafeType, size: Int, initializer: (Int) -> T): MutableBuffer = + when (type.kType) { + typeOf() -> MutableBuffer.double(size) { initializer(it) as Double } as MutableBuffer + typeOf() -> MutableBuffer.short(size) { initializer(it) as Short } as MutableBuffer + typeOf() -> MutableBuffer.int(size) { initializer(it) as Int } as MutableBuffer + typeOf() -> MutableBuffer.float(size) { initializer(it) as Float } as MutableBuffer + typeOf() -> MutableBuffer.long(size) { initializer(it) as Long } as MutableBuffer + else -> MutableListBuffer(type, MutableList(size, initializer)) + } + +/** + * Creates a [MutableBuffer] of given type [T]. If the type is primitive, specialized buffers are used + * ([Int32Buffer], [Float64Buffer], etc.), [ListBuffer] is returned otherwise. + * + * The [size] is specified, and each element is calculated by calling the specified [initializer] function. + */ +public inline fun MutableBuffer(size: Int, initializer: (Int) -> T): MutableBuffer = + MutableBuffer(safeTypeOf(), size, initializer) + + public sealed interface PrimitiveBuffer : MutableBuffer \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/DoubleLUSolverTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/DoubleLUSolverTest.kt index 4d05f9043..17b6d69b8 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/DoubleLUSolverTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/DoubleLUSolverTest.kt @@ -40,7 +40,7 @@ class DoubleLUSolverTest { //Check determinant assertEquals(7.0, lup.determinant) - assertMatrixEquals(lup.p dot matrix, lup.l dot lup.u) + assertMatrixEquals(lup.pivotMatrix dot matrix, lup.l dot lup.u) } @Test diff --git a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt index 0cef96f49..881649d01 100644 --- a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt +++ b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt @@ -19,19 +19,13 @@ import org.ejml.sparse.csc.factory.DecompositionFactory_DSCC import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC +import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.linear.* import space.kscience.kmath.linear.Matrix -import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.nd.StructureFeature -import space.kscience.kmath.structures.Float64 -import space.kscience.kmath.structures.Float32 -import space.kscience.kmath.operations.Float64Field import space.kscience.kmath.operations.Float32Field -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.FloatField +import space.kscience.kmath.operations.Float64Field import space.kscience.kmath.operations.invoke -import space.kscience.kmath.structures.Float64Buffer -import space.kscience.kmath.structures.Float32Buffer import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.FloatBuffer import kotlin.reflect.KClass diff --git a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogramGroupND.kt b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogramGroupND.kt index f7094f7a9..f9e2e14f2 100644 --- a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogramGroupND.kt +++ b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogramGroupND.kt @@ -147,12 +147,12 @@ public fun > Histogram.Companion.uniformNDFromRanges( bufferFactory: BufferFactory = valueAlgebraND.elementAlgebra.bufferFactory, ): UniformHistogramGroupND = UniformHistogramGroupND( valueAlgebraND, - ListBuffer( + DoubleBuffer( ranges .map(Pair, Int>::first) .map(ClosedFloatingPointRange::start) ), - ListBuffer( + DoubleBuffer( ranges .map(Pair, Int>::first) .map(ClosedFloatingPointRange::endInclusive)