forked from kscience/kmath
0.4 WIP
This commit is contained in:
parent
ea887b8c72
commit
46eacbb750
@ -3,10 +3,11 @@
|
|||||||
## Unreleased
|
## Unreleased
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
- Integer division algebras
|
- Explicit `SafeType` for algebras and buffers.
|
||||||
- Float32 geometries
|
- Integer division algebras.
|
||||||
- New Attributes-kt module that could be used as stand-alone. It declares type-safe attributes containers.
|
- Float32 geometries.
|
||||||
- Explicit `mutableStructureND` builders for mutable structures
|
- New Attributes-kt module that could be used as stand-alone. It declares. type-safe attributes containers.
|
||||||
|
- Explicit `mutableStructureND` builders for mutable structures.
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
- Default naming for algebra and buffers now uses IntXX/FloatXX notation instead of Java types.
|
- Default naming for algebra and buffers now uses IntXX/FloatXX notation instead of Java types.
|
||||||
|
@ -16,7 +16,7 @@ import kotlin.reflect.typeOf
|
|||||||
* @param kType raw [KType]
|
* @param kType raw [KType]
|
||||||
*/
|
*/
|
||||||
@JvmInline
|
@JvmInline
|
||||||
public value class SafeType<T> @PublishedApi internal constructor(public val kType: KType)
|
public value class SafeType<out T> @PublishedApi internal constructor(public val kType: KType)
|
||||||
|
|
||||||
public inline fun <reified T> safeTypeOf(): SafeType<T> = SafeType(typeOf<T>())
|
public inline fun <reified T> safeTypeOf(): SafeType<T> = SafeType(typeOf<T>())
|
||||||
|
|
||||||
|
@ -52,10 +52,10 @@ public interface XYColumnarData<out T, out X : T, out Y : T> : ColumnarData<T> {
|
|||||||
* Create two-column data from a list of row-objects (zero-copy)
|
* Create two-column data from a list of row-objects (zero-copy)
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun <I, T, X : T, Y : T> ofList(
|
public inline fun <I, T, reified X : T, reified Y : T> ofList(
|
||||||
list: List<I>,
|
list: List<I>,
|
||||||
xConverter: (I) -> X,
|
noinline xConverter: (I) -> X,
|
||||||
yConverter: (I) -> Y,
|
noinline yConverter: (I) -> Y,
|
||||||
): XYColumnarData<T, X, Y> = object : XYColumnarData<T, X, Y> {
|
): XYColumnarData<T, X, Y> = object : XYColumnarData<T, X, Y> {
|
||||||
override val size: Int get() = list.size
|
override val size: Int get() = list.size
|
||||||
|
|
||||||
|
@ -13,8 +13,6 @@ import space.kscience.kmath.structures.MutableBufferFactory
|
|||||||
import space.kscience.kmath.structures.asBuffer
|
import space.kscience.kmath.structures.asBuffer
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
import kotlin.reflect.KType
|
|
||||||
import kotlin.reflect.typeOf
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Class representing both the value and the differentials of a function.
|
* Class representing both the value and the differentials of a function.
|
||||||
@ -84,6 +82,8 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
|||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer {
|
) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer {
|
||||||
|
|
||||||
|
override val bufferFactory: MutableBufferFactory<DS<T, A>> = MutableBufferFactory()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the compiler for number of free parameters and order.
|
* Get the compiler for number of free parameters and order.
|
||||||
*
|
*
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
package space.kscience.kmath.expressions
|
package space.kscience.kmath.expressions
|
||||||
|
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
@ -17,6 +18,8 @@ import kotlin.contracts.contract
|
|||||||
public abstract class FunctionalExpressionAlgebra<T, out A : Algebra<T>>(
|
public abstract class FunctionalExpressionAlgebra<T, out A : Algebra<T>>(
|
||||||
public val algebra: A,
|
public val algebra: A,
|
||||||
) : ExpressionAlgebra<T, Expression<T>> {
|
) : ExpressionAlgebra<T, Expression<T>> {
|
||||||
|
override val bufferFactory: MutableBufferFactory<Expression<T>> = MutableBufferFactory<Expression<T>>()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of constant expression that does not depend on arguments.
|
* Builds an Expression of constant expression that does not depend on arguments.
|
||||||
*/
|
*/
|
||||||
@ -49,6 +52,7 @@ public abstract class FunctionalExpressionAlgebra<T, out A : Algebra<T>>(
|
|||||||
public open class FunctionalExpressionGroup<T, out A : Group<T>>(
|
public open class FunctionalExpressionGroup<T, out A : Group<T>>(
|
||||||
algebra: A,
|
algebra: A,
|
||||||
) : FunctionalExpressionAlgebra<T, A>(algebra), Group<Expression<T>> {
|
) : FunctionalExpressionAlgebra<T, A>(algebra), Group<Expression<T>> {
|
||||||
|
|
||||||
override val zero: Expression<T> get() = const(algebra.zero)
|
override val zero: Expression<T> get() = const(algebra.zero)
|
||||||
|
|
||||||
override fun Expression<T>.unaryMinus(): Expression<T> =
|
override fun Expression<T>.unaryMinus(): Expression<T> =
|
||||||
|
@ -7,11 +7,14 @@ package space.kscience.kmath.expressions
|
|||||||
|
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [Algebra] over [MST] nodes.
|
* [Algebra] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstNumericAlgebra : NumericAlgebra<MST> {
|
public object MstNumericAlgebra : NumericAlgebra<MST> {
|
||||||
|
override val bufferFactory: MutableBufferFactory<MST> = MutableBufferFactory()
|
||||||
|
|
||||||
override fun number(value: Number): MST.Numeric = MST.Numeric(value)
|
override fun number(value: Number): MST.Numeric = MST.Numeric(value)
|
||||||
override fun bindSymbolOrNull(value: String): Symbol = StringSymbol(value)
|
override fun bindSymbolOrNull(value: String): Symbol = StringSymbol(value)
|
||||||
override fun bindSymbol(value: String): Symbol = bindSymbolOrNull(value)
|
override fun bindSymbol(value: String): Symbol = bindSymbolOrNull(value)
|
||||||
@ -27,6 +30,9 @@ public object MstNumericAlgebra : NumericAlgebra<MST> {
|
|||||||
* [Group] over [MST] nodes.
|
* [Group] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
|
public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
|
||||||
|
|
||||||
|
override val bufferFactory: MutableBufferFactory<MST> = MutableBufferFactory()
|
||||||
|
|
||||||
override val zero: MST.Numeric = number(0.0)
|
override val zero: MST.Numeric = number(0.0)
|
||||||
|
|
||||||
override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value)
|
override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value)
|
||||||
@ -57,7 +63,11 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
|
|||||||
@Suppress("OVERRIDE_BY_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public object MstRing : Ring<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
|
public object MstRing : Ring<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
|
||||||
|
|
||||||
|
override val bufferFactory: MutableBufferFactory<MST> = MutableBufferFactory()
|
||||||
|
|
||||||
override inline val zero: MST.Numeric get() = MstGroup.zero
|
override inline val zero: MST.Numeric get() = MstGroup.zero
|
||||||
|
|
||||||
override val one: MST.Numeric = number(1.0)
|
override val one: MST.Numeric = number(1.0)
|
||||||
|
|
||||||
override fun number(value: Number): MST.Numeric = MstGroup.number(value)
|
override fun number(value: Number): MST.Numeric = MstGroup.number(value)
|
||||||
@ -87,7 +97,11 @@ public object MstRing : Ring<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
|
|||||||
@Suppress("OVERRIDE_BY_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public object MstField : Field<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
|
public object MstField : Field<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
|
||||||
|
|
||||||
|
override val bufferFactory: MutableBufferFactory<MST> = MutableBufferFactory()
|
||||||
|
|
||||||
override inline val zero: MST.Numeric get() = MstRing.zero
|
override inline val zero: MST.Numeric get() = MstRing.zero
|
||||||
|
|
||||||
override inline val one: MST.Numeric get() = MstRing.one
|
override inline val one: MST.Numeric get() = MstRing.one
|
||||||
|
|
||||||
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
||||||
@ -117,7 +131,11 @@ public object MstField : Field<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
|
|||||||
*/
|
*/
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
|
public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
|
||||||
|
|
||||||
|
override val bufferFactory: MutableBufferFactory<MST> = MutableBufferFactory()
|
||||||
|
|
||||||
override inline val zero: MST.Numeric get() = MstField.zero
|
override inline val zero: MST.Numeric get() = MstField.zero
|
||||||
|
|
||||||
override inline val one: MST.Numeric get() = MstField.one
|
override inline val one: MST.Numeric get() = MstField.one
|
||||||
|
|
||||||
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
||||||
@ -164,6 +182,9 @@ public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public object MstLogicAlgebra : LogicAlgebra<MST> {
|
public object MstLogicAlgebra : LogicAlgebra<MST> {
|
||||||
|
|
||||||
|
override val bufferFactory: MutableBufferFactory<MST> = MutableBufferFactory()
|
||||||
|
|
||||||
override fun bindSymbolOrNull(value: String): MST = super.bindSymbolOrNull(value) ?: StringSymbol(value)
|
override fun bindSymbolOrNull(value: String): MST = super.bindSymbolOrNull(value) ?: StringSymbol(value)
|
||||||
|
|
||||||
override fun const(boolean: Boolean): Symbol = if (boolean) {
|
override fun const(boolean: Boolean): Symbol = if (boolean) {
|
||||||
@ -176,7 +197,7 @@ public object MstLogicAlgebra : LogicAlgebra<MST> {
|
|||||||
|
|
||||||
override fun MST.and(other: MST): MST = MST.Binary(Boolean::and.name, this, other)
|
override fun MST.and(other: MST): MST = MST.Binary(Boolean::and.name, this, other)
|
||||||
|
|
||||||
override fun MST.or(other: MST): MST = MST.Binary(Boolean::or.name, this, other)
|
override fun MST.or(other: MST): MST = MST.Binary(Boolean::or.name, this, other)
|
||||||
|
|
||||||
override fun MST.xor(other: MST): MST = MST.Binary(Boolean::xor.name, this, other)
|
override fun MST.xor(other: MST): MST = MST.Binary(Boolean::xor.name, this, other)
|
||||||
}
|
}
|
||||||
|
@ -5,9 +5,11 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.expressions
|
package space.kscience.kmath.expressions
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.linear.Point
|
import space.kscience.kmath.linear.Point
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
import space.kscience.kmath.structures.asBuffer
|
import space.kscience.kmath.structures.asBuffer
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
@ -30,9 +32,10 @@ public open class AutoDiffValue<out T>(public val value: T)
|
|||||||
*/
|
*/
|
||||||
public class DerivationResult<T : Any>(
|
public class DerivationResult<T : Any>(
|
||||||
public val value: T,
|
public val value: T,
|
||||||
|
override val type: SafeType<T>,
|
||||||
private val derivativeValues: Map<String, T>,
|
private val derivativeValues: Map<String, T>,
|
||||||
public val context: Field<T>,
|
public val context: Field<T>,
|
||||||
) {
|
) : WithType<T> {
|
||||||
/**
|
/**
|
||||||
* Returns derivative of [variable] or returns [Ring.zero] in [context].
|
* Returns derivative of [variable] or returns [Ring.zero] in [context].
|
||||||
*/
|
*/
|
||||||
@ -49,19 +52,23 @@ public class DerivationResult<T : Any>(
|
|||||||
*/
|
*/
|
||||||
public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T> {
|
public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T> {
|
||||||
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
|
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
|
||||||
return variables.map(::derivative).asBuffer()
|
return variables.map(::derivative).asBuffer(type)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents field in context of which functions can be derived.
|
* Represents field. Function derivatives could be computed in this field
|
||||||
*/
|
*/
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||||
public val context: F,
|
public val algebra: F,
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, NumbersAddOps<AutoDiffValue<T>> {
|
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, NumbersAddOps<AutoDiffValue<T>> {
|
||||||
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
|
||||||
override val one: AutoDiffValue<T> get() = const(context.one)
|
override val bufferFactory: MutableBufferFactory<AutoDiffValue<T>> = MutableBufferFactory<AutoDiffValue<T>>()
|
||||||
|
|
||||||
|
override val zero: AutoDiffValue<T> get() = const(algebra.zero)
|
||||||
|
|
||||||
|
override val one: AutoDiffValue<T> get() = const(algebra.one)
|
||||||
|
|
||||||
// this stack contains pairs of blocks and values to apply them to
|
// this stack contains pairs of blocks and values to apply them to
|
||||||
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||||
@ -69,7 +76,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
||||||
|
|
||||||
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
|
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
|
||||||
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
|
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, algebra.zero)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -92,7 +99,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
override fun bindSymbolOrNull(value: String): AutoDiffValue<T>? = bindings[value]
|
override fun bindSymbolOrNull(value: String): AutoDiffValue<T>? = bindings[value]
|
||||||
|
|
||||||
private fun getDerivative(variable: AutoDiffValue<T>): T =
|
private fun getDerivative(variable: AutoDiffValue<T>): T =
|
||||||
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
|
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: algebra.zero
|
||||||
|
|
||||||
private fun setDerivative(variable: AutoDiffValue<T>, value: T) {
|
private fun setDerivative(variable: AutoDiffValue<T>, value: T) {
|
||||||
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
|
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
|
||||||
@ -103,7 +110,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
while (sp > 0) {
|
while (sp > 0) {
|
||||||
val value = stack[--sp]
|
val value = stack[--sp]
|
||||||
val block = stack[--sp] as F.(Any?) -> Unit
|
val block = stack[--sp] as F.(Any?) -> Unit
|
||||||
context.block(value)
|
algebra.block(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -130,7 +137,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
* }
|
* }
|
||||||
* ```
|
* ```
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
public fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
public fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
||||||
// save block to stack for backward pass
|
// save block to stack for backward pass
|
||||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||||
@ -142,9 +148,9 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
|
|
||||||
internal fun differentiate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
internal fun differentiate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
||||||
val result = function()
|
val result = function()
|
||||||
result.d = context.one // computing derivative w.r.t result
|
result.d = algebra.one // computing derivative w.r.t result
|
||||||
runBackwardPass()
|
runBackwardPass()
|
||||||
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
|
return DerivationResult(result.value, algebra.type, bindings.mapValues { it.value.d }, algebra)
|
||||||
}
|
}
|
||||||
|
|
||||||
// // Overloads for Double constants
|
// // Overloads for Double constants
|
||||||
@ -194,7 +200,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
|
|
||||||
public inline fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.const(block: F.() -> T): AutoDiffValue<T> {
|
public inline fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.const(block: F.() -> T): AutoDiffValue<T> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return const(context.block())
|
return const(algebra.block())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,10 +10,9 @@ import space.kscience.kmath.UnstableKMathAPI
|
|||||||
import space.kscience.kmath.nd.*
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.operations.BufferRingOps
|
import space.kscience.kmath.operations.BufferRingOps
|
||||||
import space.kscience.kmath.operations.Ring
|
import space.kscience.kmath.operations.Ring
|
||||||
|
import space.kscience.kmath.operations.WithType
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import kotlin.reflect.KType
|
|
||||||
import kotlin.reflect.typeOf
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Alias for [Structure2D] with more familiar name.
|
* Alias for [Structure2D] with more familiar name.
|
||||||
@ -34,7 +33,7 @@ public typealias Point<T> = Buffer<T>
|
|||||||
* A marker interface for algebras that operate on matrices
|
* A marker interface for algebras that operate on matrices
|
||||||
* @param T type of matrix element
|
* @param T type of matrix element
|
||||||
*/
|
*/
|
||||||
public interface MatrixOperations<T>
|
public interface MatrixOperations<T> : WithType<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Basic operations on matrices and vectors.
|
* Basic operations on matrices and vectors.
|
||||||
@ -45,6 +44,8 @@ public interface MatrixOperations<T>
|
|||||||
public interface LinearSpace<T, out A : Ring<T>> : MatrixOperations<T> {
|
public interface LinearSpace<T, out A : Ring<T>> : MatrixOperations<T> {
|
||||||
public val elementAlgebra: A
|
public val elementAlgebra: A
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = elementAlgebra.type
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Produces a matrix with this context and given dimensions.
|
* Produces a matrix with this context and given dimensions.
|
||||||
*/
|
*/
|
||||||
@ -212,4 +213,4 @@ public fun <T : Any> Matrix<T>.asVector(): Point<T> =
|
|||||||
* @receiver a buffer.
|
* @receiver a buffer.
|
||||||
* @return the new matrix.
|
* @return the new matrix.
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) }
|
public fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(type, size, 1) { i, _ -> get(i) }
|
@ -22,11 +22,27 @@ import space.kscience.kmath.structures.*
|
|||||||
* @param l The lower triangular matrix in this decomposition. It may have [LowerTriangular].
|
* @param l The lower triangular matrix in this decomposition. It may have [LowerTriangular].
|
||||||
* @param u The upper triangular matrix in this decomposition. It may have [UpperTriangular].
|
* @param u The upper triangular matrix in this decomposition. It may have [UpperTriangular].
|
||||||
*/
|
*/
|
||||||
public data class LupDecomposition<T>(
|
public class LupDecomposition<T>(
|
||||||
|
public val linearSpace: LinearSpace<T, Ring<T>>,
|
||||||
public val l: Matrix<T>,
|
public val l: Matrix<T>,
|
||||||
public val u: Matrix<T>,
|
public val u: Matrix<T>,
|
||||||
public val pivot: IntBuffer,
|
public val pivot: IntBuffer,
|
||||||
)
|
) {
|
||||||
|
public val elementAlgebra: Ring<T> get() = linearSpace.elementAlgebra
|
||||||
|
|
||||||
|
public val pivotMatrix: VirtualMatrix<T>
|
||||||
|
get() = VirtualMatrix(linearSpace.type, l.rowNum, l.colNum) { row, column ->
|
||||||
|
if (column == pivot[row]) elementAlgebra.one else elementAlgebra.zero
|
||||||
|
}
|
||||||
|
|
||||||
|
public val <T> LupDecomposition<T>.determinant by lazy {
|
||||||
|
elementAlgebra { (0 until l.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] } }
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public class LupDecompositionAttribute<T>(type: SafeType<LupDecomposition<T>>) :
|
public class LupDecompositionAttribute<T>(type: SafeType<LupDecomposition<T>>) :
|
||||||
@ -168,7 +184,7 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
|
|||||||
for (row in col + 1 until m) lu[row, col] /= luDiag
|
for (row in col + 1 until m) lu[row, col] /= luDiag
|
||||||
}
|
}
|
||||||
|
|
||||||
val l: MatrixWrapper<T> = VirtualMatrix(rowNum, colNum) { i, j ->
|
val l: MatrixWrapper<T> = VirtualMatrix(type, rowNum, colNum) { i, j ->
|
||||||
when {
|
when {
|
||||||
j < i -> lu[i, j]
|
j < i -> lu[i, j]
|
||||||
j == i -> one
|
j == i -> one
|
||||||
@ -176,7 +192,7 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
|
|||||||
}
|
}
|
||||||
}.withAttribute(LowerTriangular)
|
}.withAttribute(LowerTriangular)
|
||||||
|
|
||||||
val u = VirtualMatrix(rowNum, colNum) { i, j ->
|
val u = VirtualMatrix(type, rowNum, colNum) { i, j ->
|
||||||
if (j >= i) lu[i, j] else zero
|
if (j >= i) lu[i, j] else zero
|
||||||
}.withAttribute(UpperTriangular)
|
}.withAttribute(UpperTriangular)
|
||||||
//
|
//
|
||||||
@ -184,7 +200,7 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
|
|||||||
// if (j == pivot[i]) one else zero
|
// if (j == pivot[i]) one else zero
|
||||||
// }.withAttribute(Determinant, if (even) one else -one)
|
// }.withAttribute(Determinant, if (even) one else -one)
|
||||||
|
|
||||||
return LupDecomposition(l, u, pivot.asBuffer())
|
return LupDecomposition(this@lup, l, u, pivot.asBuffer())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,16 +6,21 @@
|
|||||||
package space.kscience.kmath.linear
|
package space.kscience.kmath.linear
|
||||||
|
|
||||||
import space.kscience.attributes.FlagAttribute
|
import space.kscience.attributes.FlagAttribute
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.Ring
|
import space.kscience.kmath.operations.Ring
|
||||||
|
import space.kscience.kmath.operations.WithType
|
||||||
import space.kscience.kmath.structures.BufferAccessor2D
|
import space.kscience.kmath.structures.BufferAccessor2D
|
||||||
import space.kscience.kmath.structures.MutableBuffer
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
|
|
||||||
public class MatrixBuilder<T : Any, out A : Ring<T>>(
|
public class MatrixBuilder<T : Any, out A : Ring<T>>(
|
||||||
public val linearSpace: LinearSpace<T, A>,
|
public val linearSpace: LinearSpace<T, A>,
|
||||||
public val rows: Int,
|
public val rows: Int,
|
||||||
public val columns: Int,
|
public val columns: Int,
|
||||||
) {
|
) : WithType<T> {
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = linearSpace.type
|
||||||
|
|
||||||
public operator fun invoke(vararg elements: T): Matrix<T> {
|
public operator fun invoke(vararg elements: T): Matrix<T> {
|
||||||
require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" }
|
require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" }
|
||||||
return linearSpace.buildMatrix(rows, columns) { i, j -> elements[i * columns + j] }
|
return linearSpace.buildMatrix(rows, columns) { i, j -> elements[i * columns + j] }
|
||||||
@ -61,7 +66,7 @@ public fun <T : Any, A : Ring<T>> MatrixBuilder<T, A>.symmetric(
|
|||||||
builder: (i: Int, j: Int) -> T,
|
builder: (i: Int, j: Int) -> T,
|
||||||
): Matrix<T> {
|
): Matrix<T> {
|
||||||
require(columns == rows) { "In order to build symmetric matrix, number of rows $rows should be equal to number of columns $columns" }
|
require(columns == rows) { "In order to build symmetric matrix, number of rows $rows should be equal to number of columns $columns" }
|
||||||
return with(BufferAccessor2D<T?>(rows, rows, MutableBuffer.Companion::boxing)) {
|
return with(BufferAccessor2D<T?>(rows, rows, MutableBufferFactory(type))) {
|
||||||
val cache = factory(rows * rows) { null }
|
val cache = factory(rows * rows) { null }
|
||||||
linearSpace.buildMatrix(rows, rows) { i, j ->
|
linearSpace.buildMatrix(rows, rows) { i, j ->
|
||||||
val cached = cache[i, j]
|
val cached = cache[i, j]
|
||||||
|
@ -39,7 +39,7 @@ public fun <T : Any, A : Attribute<T>> Matrix<T>.withAttribute(
|
|||||||
attribute: A,
|
attribute: A,
|
||||||
attrValue: T,
|
attrValue: T,
|
||||||
): MatrixWrapper<T> = if (this is MatrixWrapper) {
|
): MatrixWrapper<T> = if (this is MatrixWrapper) {
|
||||||
MatrixWrapper(origin, attributes.withAttribute(attribute,attrValue))
|
MatrixWrapper(origin, attributes.withAttribute(attribute, attrValue))
|
||||||
} else {
|
} else {
|
||||||
MatrixWrapper(this, Attributes(attribute, attrValue))
|
MatrixWrapper(this, Attributes(attribute, attrValue))
|
||||||
}
|
}
|
||||||
@ -68,7 +68,7 @@ public fun <T : Any> Matrix<T>.modifyAttributes(modifier: (Attributes) -> Attrib
|
|||||||
public fun <T : Any> LinearSpace<T, Ring<T>>.one(
|
public fun <T : Any> LinearSpace<T, Ring<T>>.one(
|
||||||
rows: Int,
|
rows: Int,
|
||||||
columns: Int,
|
columns: Int,
|
||||||
): MatrixWrapper<T> = VirtualMatrix(rows, columns) { i, j ->
|
): MatrixWrapper<T> = VirtualMatrix(type, rows, columns) { i, j ->
|
||||||
if (i == j) elementAlgebra.one else elementAlgebra.zero
|
if (i == j) elementAlgebra.one else elementAlgebra.zero
|
||||||
}.withAttribute(IsUnit)
|
}.withAttribute(IsUnit)
|
||||||
|
|
||||||
@ -79,6 +79,6 @@ public fun <T : Any> LinearSpace<T, Ring<T>>.one(
|
|||||||
public fun <T : Any> LinearSpace<T, Ring<T>>.zero(
|
public fun <T : Any> LinearSpace<T, Ring<T>>.zero(
|
||||||
rows: Int,
|
rows: Int,
|
||||||
columns: Int,
|
columns: Int,
|
||||||
): MatrixWrapper<T> = VirtualMatrix(rows, columns) { _, _ ->
|
): MatrixWrapper<T> = VirtualMatrix(type, rows, columns) { _, _ ->
|
||||||
elementAlgebra.zero
|
elementAlgebra.zero
|
||||||
}.withAttribute(IsZero)
|
}.withAttribute(IsZero)
|
||||||
|
@ -6,10 +6,14 @@
|
|||||||
package space.kscience.kmath.linear
|
package space.kscience.kmath.linear
|
||||||
|
|
||||||
import space.kscience.attributes.Attributes
|
import space.kscience.attributes.Attributes
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
|
||||||
|
|
||||||
public class TransposedMatrix<T>(public val origin: Matrix<T>) : Matrix<T> {
|
public class TransposedMatrix<T>(public val origin: Matrix<T>) : Matrix<T> {
|
||||||
|
override val type: SafeType<T> get() = origin.type
|
||||||
|
|
||||||
override val rowNum: Int get() = origin.colNum
|
override val rowNum: Int get() = origin.colNum
|
||||||
|
|
||||||
override val colNum: Int get() = origin.rowNum
|
override val colNum: Int get() = origin.rowNum
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): T = origin[j, i]
|
override fun get(i: Int, j: Int): T = origin[j, i]
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
package space.kscience.kmath.linear
|
package space.kscience.kmath.linear
|
||||||
|
|
||||||
import space.kscience.attributes.Attributes
|
import space.kscience.attributes.Attributes
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.kmath.nd.ShapeND
|
import space.kscience.kmath.nd.ShapeND
|
||||||
|
|
||||||
|
|
||||||
@ -14,7 +15,8 @@ import space.kscience.kmath.nd.ShapeND
|
|||||||
*
|
*
|
||||||
* @property generator the function that provides elements.
|
* @property generator the function that provides elements.
|
||||||
*/
|
*/
|
||||||
public class VirtualMatrix<out T : Any>(
|
public class VirtualMatrix<out T>(
|
||||||
|
override val type: SafeType<T>,
|
||||||
override val rowNum: Int,
|
override val rowNum: Int,
|
||||||
override val colNum: Int,
|
override val colNum: Int,
|
||||||
override val attributes: Attributes = Attributes.EMPTY,
|
override val attributes: Attributes = Attributes.EMPTY,
|
||||||
@ -29,4 +31,4 @@ public class VirtualMatrix<out T : Any>(
|
|||||||
public fun <T : Any> MatrixBuilder<T, *>.virtual(
|
public fun <T : Any> MatrixBuilder<T, *>.virtual(
|
||||||
attributes: Attributes = Attributes.EMPTY,
|
attributes: Attributes = Attributes.EMPTY,
|
||||||
generator: (i: Int, j: Int) -> T,
|
generator: (i: Int, j: Int) -> T,
|
||||||
): VirtualMatrix<T> = VirtualMatrix(rows, columns, attributes, generator)
|
): VirtualMatrix<T> = VirtualMatrix(type, rows, columns, attributes, generator)
|
||||||
|
@ -3,13 +3,11 @@
|
|||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
@file:OptIn(UnstableKMathAPI::class)
|
|
||||||
@file:Suppress("UnusedReceiverParameter")
|
@file:Suppress("UnusedReceiverParameter")
|
||||||
|
|
||||||
package space.kscience.kmath.linear
|
package space.kscience.kmath.linear
|
||||||
|
|
||||||
import space.kscience.attributes.*
|
import space.kscience.attributes.*
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.nd.StructureAttribute
|
import space.kscience.kmath.nd.StructureAttribute
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -51,9 +49,11 @@ public val <T> MatrixOperations<T>.Inverted: Inverted<T> get() = Inverted(safeTy
|
|||||||
*
|
*
|
||||||
* @param T the type of matrices' items.
|
* @param T the type of matrices' items.
|
||||||
*/
|
*/
|
||||||
public class Determinant<T> : MatrixAttribute<T>
|
public class Determinant<T>(type: SafeType<T>) :
|
||||||
|
PolymorphicAttribute<T>(type),
|
||||||
|
MatrixAttribute<T>
|
||||||
|
|
||||||
public val <T> MatrixOperations<T>.Determinant: Determinant<T> get() = Determinant()
|
public val <T> MatrixOperations<T>.Determinant: Determinant<T> get() = Determinant(type)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Matrices with this feature are lower triangular ones.
|
* Matrices with this feature are lower triangular ones.
|
||||||
|
@ -4,6 +4,8 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
@file:OptIn(UnstableKMathAPI::class)
|
||||||
|
|
||||||
package space.kscience.kmath.misc
|
package space.kscience.kmath.misc
|
||||||
|
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
@ -22,10 +24,9 @@ public fun <V : Comparable<V>> Buffer<V>.indicesSorted(): IntArray = permSortInd
|
|||||||
/**
|
/**
|
||||||
* Create a zero-copy virtual buffer that contains the same elements but in ascending order
|
* Create a zero-copy virtual buffer that contains the same elements but in ascending order
|
||||||
*/
|
*/
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
public fun <V : Comparable<V>> Buffer<V>.sorted(): Buffer<V> {
|
public fun <V : Comparable<V>> Buffer<V>.sorted(): Buffer<V> {
|
||||||
val permutations = indicesSorted()
|
val permutations = indicesSorted()
|
||||||
return VirtualBuffer(size) { this[permutations[it]] }
|
return VirtualBuffer(type, size) { this[permutations[it]] }
|
||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
@ -35,30 +36,27 @@ public fun <V : Comparable<V>> Buffer<V>.indicesSortedDescending(): IntArray =
|
|||||||
/**
|
/**
|
||||||
* Create a zero-copy virtual buffer that contains the same elements but in descending order
|
* Create a zero-copy virtual buffer that contains the same elements but in descending order
|
||||||
*/
|
*/
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
public fun <V : Comparable<V>> Buffer<V>.sortedDescending(): Buffer<V> {
|
public fun <V : Comparable<V>> Buffer<V>.sortedDescending(): Buffer<V> {
|
||||||
val permutations = indicesSortedDescending()
|
val permutations = indicesSortedDescending()
|
||||||
return VirtualBuffer(size) { this[permutations[it]] }
|
return VirtualBuffer(type, size) { this[permutations[it]] }
|
||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedBy(selector: (V) -> C): IntArray =
|
public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedBy(selector: (V) -> C): IntArray =
|
||||||
permSortIndicesWith(compareBy { selector(get(it)) })
|
permSortIndicesWith(compareBy { selector(get(it)) })
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
public fun <V, C : Comparable<C>> Buffer<V>.sortedBy(selector: (V) -> C): Buffer<V> {
|
public fun <V, C : Comparable<C>> Buffer<V>.sortedBy(selector: (V) -> C): Buffer<V> {
|
||||||
val permutations = indicesSortedBy(selector)
|
val permutations = indicesSortedBy(selector)
|
||||||
return VirtualBuffer(size) { this[permutations[it]] }
|
return VirtualBuffer(type, size) { this[permutations[it]] }
|
||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedByDescending(selector: (V) -> C): IntArray =
|
public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedByDescending(selector: (V) -> C): IntArray =
|
||||||
permSortIndicesWith(compareByDescending { selector(get(it)) })
|
permSortIndicesWith(compareByDescending { selector(get(it)) })
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
public fun <V, C : Comparable<C>> Buffer<V>.sortedByDescending(selector: (V) -> C): Buffer<V> {
|
public fun <V, C : Comparable<C>> Buffer<V>.sortedByDescending(selector: (V) -> C): Buffer<V> {
|
||||||
val permutations = indicesSortedByDescending(selector)
|
val permutations = indicesSortedByDescending(selector)
|
||||||
return VirtualBuffer(size) { this[permutations[it]] }
|
return VirtualBuffer(type, size) { this[permutations[it]] }
|
||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
@ -68,10 +66,9 @@ public fun <V> Buffer<V>.indicesSortedWith(comparator: Comparator<V>): IntArray
|
|||||||
/**
|
/**
|
||||||
* Create virtual zero-copy buffer with elements sorted by [comparator]
|
* Create virtual zero-copy buffer with elements sorted by [comparator]
|
||||||
*/
|
*/
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
public fun <V> Buffer<V>.sortedWith(comparator: Comparator<V>): Buffer<V> {
|
public fun <V> Buffer<V>.sortedWith(comparator: Comparator<V>): Buffer<V> {
|
||||||
val permutations = indicesSortedWith(comparator)
|
val permutations = indicesSortedWith(comparator)
|
||||||
return VirtualBuffer(size) { this[permutations[it]] }
|
return VirtualBuffer(type,size) { this[permutations[it]] }
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun <V> Buffer<V>.permSortIndicesWith(comparator: Comparator<Int>): IntArray {
|
private fun <V> Buffer<V>.permSortIndicesWith(comparator: Comparator<Int>): IntArray {
|
||||||
|
@ -8,7 +8,7 @@ package space.kscience.kmath.nd
|
|||||||
import space.kscience.kmath.PerformancePitfall
|
import space.kscience.kmath.PerformancePitfall
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
import kotlin.reflect.KClass
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The base interface for all ND-algebra implementations.
|
* The base interface for all ND-algebra implementations.
|
||||||
@ -91,6 +91,8 @@ public interface AlgebraND<T, out C : Algebra<T>> : Algebra<StructureND<T>> {
|
|||||||
* @param A the type of group over structure elements.
|
* @param A the type of group over structure elements.
|
||||||
*/
|
*/
|
||||||
public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>, AlgebraND<T, A> {
|
public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>, AlgebraND<T, A> {
|
||||||
|
override val bufferFactory: MutableBufferFactory<StructureND<T>> get() = MutableBufferFactory<StructureND<T>>()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Element-wise addition.
|
* Element-wise addition.
|
||||||
*
|
*
|
||||||
|
@ -7,6 +7,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.nd
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import space.kscience.kmath.PerformancePitfall
|
import space.kscience.kmath.PerformancePitfall
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
@ -105,6 +107,9 @@ public open class BufferedGroupNDOps<T, out A : Group<T>>(
|
|||||||
override val bufferAlgebra: BufferAlgebra<T, A>,
|
override val bufferAlgebra: BufferAlgebra<T, A>,
|
||||||
override val indexerBuilder: (ShapeND) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder,
|
override val indexerBuilder: (ShapeND) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder,
|
||||||
) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
|
) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
|
||||||
|
|
||||||
|
override val type: SafeType<StructureND<T>> get() = safeTypeOf<StructureND<T>>()
|
||||||
|
|
||||||
override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
|
override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.nd
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.kmath.PerformancePitfall
|
import space.kscience.kmath.PerformancePitfall
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.BufferFactory
|
import space.kscience.kmath.structures.BufferFactory
|
||||||
@ -23,6 +24,8 @@ public open class BufferND<out T>(
|
|||||||
public open val buffer: Buffer<T>,
|
public open val buffer: Buffer<T>,
|
||||||
) : StructureND<T> {
|
) : StructureND<T> {
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = buffer.type
|
||||||
|
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
override operator fun get(index: IntArray): T = buffer[indices.offset(index)]
|
override operator fun get(index: IntArray): T = buffer[indices.offset(index)]
|
||||||
|
|
||||||
@ -36,7 +39,7 @@ public open class BufferND<out T>(
|
|||||||
*/
|
*/
|
||||||
public inline fun <reified T> BufferND(
|
public inline fun <reified T> BufferND(
|
||||||
shape: ShapeND,
|
shape: ShapeND,
|
||||||
bufferFactory: BufferFactory<T> = BufferFactory.auto<T>(),
|
bufferFactory: BufferFactory<T> = BufferFactory<T>(),
|
||||||
crossinline initializer: (IntArray) -> T,
|
crossinline initializer: (IntArray) -> T,
|
||||||
): BufferND<T> {
|
): BufferND<T> {
|
||||||
val strides = Strides(shape)
|
val strides = Strides(shape)
|
||||||
@ -84,10 +87,10 @@ public open class MutableBufferND<T>(
|
|||||||
/**
|
/**
|
||||||
* Create a generic [BufferND] using provided [initializer]
|
* Create a generic [BufferND] using provided [initializer]
|
||||||
*/
|
*/
|
||||||
public fun <T> MutableBufferND(
|
public inline fun <reified T> MutableBufferND(
|
||||||
shape: ShapeND,
|
shape: ShapeND,
|
||||||
bufferFactory: MutableBufferFactory<T> = MutableBufferFactory.boxing(),
|
bufferFactory: MutableBufferFactory<T> = MutableBufferFactory(),
|
||||||
initializer: (IntArray) -> T,
|
crossinline initializer: (IntArray) -> T,
|
||||||
): MutableBufferND<T> {
|
): MutableBufferND<T> {
|
||||||
val strides = Strides(shape)
|
val strides = Strides(shape)
|
||||||
return MutableBufferND(strides, bufferFactory(strides.linearSize) { initializer(strides.index(it)) })
|
return MutableBufferND(strides, bufferFactory(strides.linearSize) { initializer(strides.index(it)) })
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.nd
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.kmath.PerformancePitfall
|
import space.kscience.kmath.PerformancePitfall
|
||||||
|
|
||||||
|
|
||||||
@ -13,6 +14,8 @@ public class PermutedStructureND<T>(
|
|||||||
public val permutation: (IntArray) -> IntArray,
|
public val permutation: (IntArray) -> IntArray,
|
||||||
) : StructureND<T> {
|
) : StructureND<T> {
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = origin.type
|
||||||
|
|
||||||
override val shape: ShapeND
|
override val shape: ShapeND
|
||||||
get() = origin.shape
|
get() = origin.shape
|
||||||
|
|
||||||
@ -32,6 +35,7 @@ public class PermutedMutableStructureND<T>(
|
|||||||
public val permutation: (IntArray) -> IntArray,
|
public val permutation: (IntArray) -> IntArray,
|
||||||
) : MutableStructureND<T> {
|
) : MutableStructureND<T> {
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = origin.type
|
||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
override fun set(index: IntArray, value: T) {
|
override fun set(index: IntArray, value: T) {
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.nd
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.kmath.PerformancePitfall
|
import space.kscience.kmath.PerformancePitfall
|
||||||
import space.kscience.kmath.operations.asSequence
|
import space.kscience.kmath.operations.asSequence
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
@ -46,6 +47,9 @@ public interface MutableStructure1D<T> : Structure1D<T>, MutableStructureND<T>,
|
|||||||
*/
|
*/
|
||||||
@JvmInline
|
@JvmInline
|
||||||
private value class Structure1DWrapper<out T>(val structure: StructureND<T>) : Structure1D<T> {
|
private value class Structure1DWrapper<out T>(val structure: StructureND<T>) : Structure1D<T> {
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = structure.type
|
||||||
|
|
||||||
override val shape: ShapeND get() = structure.shape
|
override val shape: ShapeND get() = structure.shape
|
||||||
override val size: Int get() = structure.shape[0]
|
override val size: Int get() = structure.shape[0]
|
||||||
|
|
||||||
@ -60,6 +64,9 @@ private value class Structure1DWrapper<out T>(val structure: StructureND<T>) : S
|
|||||||
* A 1D wrapper for a mutable nd-structure
|
* A 1D wrapper for a mutable nd-structure
|
||||||
*/
|
*/
|
||||||
private class MutableStructure1DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure1D<T> {
|
private class MutableStructure1DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure1D<T> {
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = structure.type
|
||||||
|
|
||||||
override val shape: ShapeND get() = structure.shape
|
override val shape: ShapeND get() = structure.shape
|
||||||
override val size: Int get() = structure.shape[0]
|
override val size: Int get() = structure.shape[0]
|
||||||
|
|
||||||
@ -79,7 +86,7 @@ private class MutableStructure1DWrapper<T>(val structure: MutableStructureND<T>)
|
|||||||
.elements()
|
.elements()
|
||||||
.map(Pair<IntArray, T>::second)
|
.map(Pair<IntArray, T>::second)
|
||||||
.toMutableList()
|
.toMutableList()
|
||||||
.asMutableBuffer()
|
.asMutableBuffer(type)
|
||||||
|
|
||||||
override fun toString(): String = Buffer.toString(this)
|
override fun toString(): String = Buffer.toString(this)
|
||||||
}
|
}
|
||||||
@ -90,6 +97,9 @@ private class MutableStructure1DWrapper<T>(val structure: MutableStructureND<T>)
|
|||||||
*/
|
*/
|
||||||
@JvmInline
|
@JvmInline
|
||||||
private value class Buffer1DWrapper<out T>(val buffer: Buffer<T>) : Structure1D<T> {
|
private value class Buffer1DWrapper<out T>(val buffer: Buffer<T>) : Structure1D<T> {
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = buffer.type
|
||||||
|
|
||||||
override val shape: ShapeND get() = ShapeND(buffer.size)
|
override val shape: ShapeND get() = ShapeND(buffer.size)
|
||||||
override val size: Int get() = buffer.size
|
override val size: Int get() = buffer.size
|
||||||
|
|
||||||
@ -102,6 +112,9 @@ private value class Buffer1DWrapper<out T>(val buffer: Buffer<T>) : Structure1D<
|
|||||||
}
|
}
|
||||||
|
|
||||||
internal class MutableBuffer1DWrapper<T>(val buffer: MutableBuffer<T>) : MutableStructure1D<T> {
|
internal class MutableBuffer1DWrapper<T>(val buffer: MutableBuffer<T>) : MutableStructure1D<T> {
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = buffer.type
|
||||||
|
|
||||||
override val shape: ShapeND get() = ShapeND(buffer.size)
|
override val shape: ShapeND get() = ShapeND(buffer.size)
|
||||||
override val size: Int get() = buffer.size
|
override val size: Int get() = buffer.size
|
||||||
|
|
||||||
|
@ -6,13 +6,12 @@
|
|||||||
package space.kscience.kmath.nd
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
import space.kscience.attributes.Attributes
|
import space.kscience.attributes.Attributes
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.kmath.PerformancePitfall
|
import space.kscience.kmath.PerformancePitfall
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.MutableBuffer
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
import space.kscience.kmath.structures.MutableListBuffer
|
|
||||||
import space.kscience.kmath.structures.VirtualBuffer
|
import space.kscience.kmath.structures.VirtualBuffer
|
||||||
import kotlin.jvm.JvmInline
|
import kotlin.jvm.JvmInline
|
||||||
import kotlin.reflect.KClass
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A structure that is guaranteed to be two-dimensional.
|
* A structure that is guaranteed to be two-dimensional.
|
||||||
@ -33,18 +32,18 @@ public interface Structure2D<out T> : StructureND<T> {
|
|||||||
override val shape: ShapeND get() = ShapeND(rowNum, colNum)
|
override val shape: ShapeND get() = ShapeND(rowNum, colNum)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The buffer of rows of this structure. It gets elements from the structure dynamically.
|
* The buffer of rows for this structure. It gets elements from the structure dynamically.
|
||||||
*/
|
*/
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
public val rows: List<Buffer<T>>
|
public val rows: List<Buffer<T>>
|
||||||
get() = List(rowNum) { i -> VirtualBuffer(colNum) { j -> get(i, j) } }
|
get() = List(rowNum) { i -> VirtualBuffer(type, colNum) { j -> get(i, j) } }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The buffer of columns of this structure. It gets elements from the structure dynamically.
|
* The buffer of columns for this structure. It gets elements from the structure dynamically.
|
||||||
*/
|
*/
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
public val columns: List<Buffer<T>>
|
public val columns: List<Buffer<T>>
|
||||||
get() = List(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } }
|
get() = List(colNum) { j -> VirtualBuffer(type, rowNum) { i -> get(i, j) } }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Retrieves an element from the structure by two indices.
|
* Retrieves an element from the structure by two indices.
|
||||||
@ -88,14 +87,14 @@ public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
|
|||||||
*/
|
*/
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
override val rows: List<MutableBuffer<T>>
|
override val rows: List<MutableBuffer<T>>
|
||||||
get() = List(rowNum) { i -> MutableListBuffer(colNum) { j -> get(i, j) } }
|
get() = List(rowNum) { i -> MutableBuffer(type, colNum) { j -> get(i, j) } }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The buffer of columns of this structure. It gets elements from the structure dynamically.
|
* The buffer of columns for this structure. It gets elements from the structure dynamically.
|
||||||
*/
|
*/
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
override val columns: List<MutableBuffer<T>>
|
override val columns: List<MutableBuffer<T>>
|
||||||
get() = List(colNum) { j -> MutableListBuffer(rowNum) { i -> get(i, j) } }
|
get() = List(colNum) { j -> MutableBuffer(type, rowNum) { i -> get(i, j) } }
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -103,6 +102,9 @@ public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
|
|||||||
*/
|
*/
|
||||||
@JvmInline
|
@JvmInline
|
||||||
private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : Structure2D<T> {
|
private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : Structure2D<T> {
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = structure.type
|
||||||
|
|
||||||
override val shape: ShapeND get() = structure.shape
|
override val shape: ShapeND get() = structure.shape
|
||||||
|
|
||||||
override val rowNum: Int get() = shape[0]
|
override val rowNum: Int get() = shape[0]
|
||||||
@ -122,6 +124,9 @@ private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : S
|
|||||||
* A 2D wrapper for a mutable nd-structure
|
* A 2D wrapper for a mutable nd-structure
|
||||||
*/
|
*/
|
||||||
private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure2D<T> {
|
private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure2D<T> {
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = structure.type
|
||||||
|
|
||||||
override val shape: ShapeND get() = structure.shape
|
override val shape: ShapeND get() = structure.shape
|
||||||
|
|
||||||
override val rowNum: Int get() = shape[0]
|
override val rowNum: Int get() = shape[0]
|
||||||
|
@ -12,6 +12,7 @@ import space.kscience.attributes.SafeType
|
|||||||
import space.kscience.kmath.PerformancePitfall
|
import space.kscience.kmath.PerformancePitfall
|
||||||
import space.kscience.kmath.linear.LinearSpace
|
import space.kscience.kmath.linear.LinearSpace
|
||||||
import space.kscience.kmath.operations.Ring
|
import space.kscience.kmath.operations.Ring
|
||||||
|
import space.kscience.kmath.operations.WithType
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import kotlin.jvm.JvmName
|
import kotlin.jvm.JvmName
|
||||||
@ -28,9 +29,9 @@ public interface StructureAttribute<T> : Attribute<T>
|
|||||||
*
|
*
|
||||||
* @param T the type of items.
|
* @param T the type of items.
|
||||||
*/
|
*/
|
||||||
public interface StructureND<out T> : AttributeContainer, WithShape {
|
public interface StructureND<out T> : AttributeContainer, WithShape, WithType<T> {
|
||||||
/**
|
/**
|
||||||
* The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions of
|
* The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions for
|
||||||
* this structure.
|
* this structure.
|
||||||
*/
|
*/
|
||||||
override val shape: ShapeND
|
override val shape: ShapeND
|
||||||
@ -117,57 +118,61 @@ public interface StructureND<out T> : AttributeContainer, WithShape {
|
|||||||
return "$className(shape=${structure.shape}, buffer=$bufferRepr)"
|
return "$className(shape=${structure.shape}, buffer=$bufferRepr)"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a NDStructure with explicit buffer factory.
|
|
||||||
*
|
|
||||||
* Strides should be reused if possible.
|
|
||||||
*/
|
|
||||||
public fun <T> buffered(
|
|
||||||
strides: Strides,
|
|
||||||
initializer: (IntArray) -> T,
|
|
||||||
): BufferND<T> = BufferND(strides, Buffer.boxing(strides.linearSize) { i -> initializer(strides.index(i)) })
|
|
||||||
|
|
||||||
|
|
||||||
public fun <T> buffered(
|
|
||||||
shape: ShapeND,
|
|
||||||
initializer: (IntArray) -> T,
|
|
||||||
): BufferND<T> = buffered(ColumnStrides(shape), initializer)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Inline create NDStructure with non-boxing buffer implementation if it is possible
|
|
||||||
*/
|
|
||||||
public inline fun <reified T : Any> auto(
|
|
||||||
strides: Strides,
|
|
||||||
crossinline initializer: (IntArray) -> T,
|
|
||||||
): BufferND<T> = BufferND(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
|
||||||
|
|
||||||
public inline fun <T : Any> auto(
|
|
||||||
type: SafeType<T>,
|
|
||||||
strides: Strides,
|
|
||||||
crossinline initializer: (IntArray) -> T,
|
|
||||||
): BufferND<T> = BufferND(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
|
||||||
|
|
||||||
|
|
||||||
public inline fun <reified T : Any> auto(
|
|
||||||
shape: ShapeND,
|
|
||||||
crossinline initializer: (IntArray) -> T,
|
|
||||||
): BufferND<T> = auto(ColumnStrides(shape), initializer)
|
|
||||||
|
|
||||||
@JvmName("autoVarArg")
|
|
||||||
public inline fun <reified T : Any> auto(
|
|
||||||
vararg shape: Int,
|
|
||||||
crossinline initializer: (IntArray) -> T,
|
|
||||||
): BufferND<T> =
|
|
||||||
auto(ColumnStrides(ShapeND(shape)), initializer)
|
|
||||||
|
|
||||||
public inline fun <T : Any> auto(
|
|
||||||
type: SafeType<T>,
|
|
||||||
vararg shape: Int,
|
|
||||||
crossinline initializer: (IntArray) -> T,
|
|
||||||
): BufferND<T> = auto(type, ColumnStrides(ShapeND(shape)), initializer)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a NDStructure with explicit buffer factory.
|
||||||
|
*
|
||||||
|
* Strides should be reused if possible.
|
||||||
|
*/
|
||||||
|
public fun <T> BufferND(
|
||||||
|
type: SafeType<T>,
|
||||||
|
strides: Strides,
|
||||||
|
initializer: (IntArray) -> T,
|
||||||
|
): BufferND<T> = BufferND(strides, Buffer(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
|
|
||||||
|
public fun <T> BufferND(
|
||||||
|
type: SafeType<T>,
|
||||||
|
shape: ShapeND,
|
||||||
|
initializer: (IntArray) -> T,
|
||||||
|
): BufferND<T> = BufferND(type, ColumnStrides(shape), initializer)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Inline create NDStructure with non-boxing buffer implementation if it is possible
|
||||||
|
*/
|
||||||
|
public inline fun <reified T : Any> BufferND(
|
||||||
|
strides: Strides,
|
||||||
|
crossinline initializer: (IntArray) -> T,
|
||||||
|
): BufferND<T> = BufferND(strides, Buffer(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
|
public inline fun <T : Any> BufferND(
|
||||||
|
type: SafeType<T>,
|
||||||
|
strides: Strides,
|
||||||
|
crossinline initializer: (IntArray) -> T,
|
||||||
|
): BufferND<T> = BufferND(strides, Buffer(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
|
|
||||||
|
public inline fun <reified T : Any> BufferND(
|
||||||
|
shape: ShapeND,
|
||||||
|
crossinline initializer: (IntArray) -> T,
|
||||||
|
): BufferND<T> = BufferND(ColumnStrides(shape), initializer)
|
||||||
|
|
||||||
|
@JvmName("autoVarArg")
|
||||||
|
public inline fun <reified T : Any> BufferND(
|
||||||
|
vararg shape: Int,
|
||||||
|
crossinline initializer: (IntArray) -> T,
|
||||||
|
): BufferND<T> = BufferND(ColumnStrides(ShapeND(shape)), initializer)
|
||||||
|
|
||||||
|
public inline fun <T : Any> BufferND(
|
||||||
|
type: SafeType<T>,
|
||||||
|
vararg shape: Int,
|
||||||
|
crossinline initializer: (IntArray) -> T,
|
||||||
|
): BufferND<T> = BufferND(type, ColumnStrides(ShapeND(shape)), initializer)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Indicates whether some [StructureND] is equal to another one.
|
* Indicates whether some [StructureND] is equal to another one.
|
||||||
*/
|
*/
|
||||||
|
@ -5,10 +5,13 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.nd
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import space.kscience.kmath.PerformancePitfall
|
import space.kscience.kmath.PerformancePitfall
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
|
|
||||||
public open class VirtualStructureND<T>(
|
public open class VirtualStructureND<T>(
|
||||||
|
override val type: SafeType<T>,
|
||||||
override val shape: ShapeND,
|
override val shape: ShapeND,
|
||||||
public val producer: (IntArray) -> T,
|
public val producer: (IntArray) -> T,
|
||||||
) : StructureND<T> {
|
) : StructureND<T> {
|
||||||
@ -24,10 +27,10 @@ public open class VirtualStructureND<T>(
|
|||||||
public class VirtualDoubleStructureND(
|
public class VirtualDoubleStructureND(
|
||||||
shape: ShapeND,
|
shape: ShapeND,
|
||||||
producer: (IntArray) -> Double,
|
producer: (IntArray) -> Double,
|
||||||
) : VirtualStructureND<Double>(shape, producer)
|
) : VirtualStructureND<Double>(safeTypeOf(), shape, producer)
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public class VirtualIntStructureND(
|
public class VirtualIntStructureND(
|
||||||
shape: ShapeND,
|
shape: ShapeND,
|
||||||
producer: (IntArray) -> Int,
|
producer: (IntArray) -> Int,
|
||||||
) : VirtualStructureND<Int>(shape, producer)
|
) : VirtualStructureND<Int>(safeTypeOf(), shape, producer)
|
@ -10,7 +10,7 @@ import space.kscience.kmath.PerformancePitfall
|
|||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
public fun <T> StructureND<T>.roll(axis: Int, step: Int = 1): StructureND<T> {
|
public fun <T> StructureND<T>.roll(axis: Int, step: Int = 1): StructureND<T> {
|
||||||
require(axis in shape.indices) { "Axis $axis is outside of shape dimensions: [0, ${shape.size})" }
|
require(axis in shape.indices) { "Axis $axis is outside of shape dimensions: [0, ${shape.size})" }
|
||||||
return VirtualStructureND(shape) { index ->
|
return VirtualStructureND(type, shape) { index ->
|
||||||
val newIndex: IntArray = IntArray(index.size) { indexAxis ->
|
val newIndex: IntArray = IntArray(index.size) { indexAxis ->
|
||||||
if (indexAxis == axis) {
|
if (indexAxis == axis) {
|
||||||
(index[indexAxis] + step).mod(shape[indexAxis])
|
(index[indexAxis] + step).mod(shape[indexAxis])
|
||||||
@ -26,7 +26,7 @@ public fun <T> StructureND<T>.roll(axis: Int, step: Int = 1): StructureND<T> {
|
|||||||
public fun <T> StructureND<T>.roll(pair: Pair<Int, Int>, vararg others: Pair<Int, Int>): StructureND<T> {
|
public fun <T> StructureND<T>.roll(pair: Pair<Int, Int>, vararg others: Pair<Int, Int>): StructureND<T> {
|
||||||
val axisMap: Map<Int, Int> = mapOf(pair, *others)
|
val axisMap: Map<Int, Int> = mapOf(pair, *others)
|
||||||
require(axisMap.keys.all { it in shape.indices }) { "Some of axes ${axisMap.keys} is outside of shape dimensions: [0, ${shape.size})" }
|
require(axisMap.keys.all { it in shape.indices }) { "Some of axes ${axisMap.keys} is outside of shape dimensions: [0, ${shape.size})" }
|
||||||
return VirtualStructureND(shape) { index ->
|
return VirtualStructureND(type, shape) { index ->
|
||||||
val newIndex: IntArray = IntArray(index.size) { indexAxis ->
|
val newIndex: IntArray = IntArray(index.size) { indexAxis ->
|
||||||
val offset = axisMap[indexAxis] ?: 0
|
val offset = axisMap[indexAxis] ?: 0
|
||||||
(index[indexAxis] + offset).mod(shape[indexAxis])
|
(index[indexAxis] + offset).mod(shape[indexAxis])
|
||||||
|
@ -5,22 +5,32 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.operations
|
package space.kscience.kmath.operations
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
import space.kscience.kmath.operations.Ring.Companion.optimizedPower
|
import space.kscience.kmath.operations.Ring.Companion.optimizedPower
|
||||||
import space.kscience.kmath.structures.MutableBufferFactory
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
import kotlin.reflect.KType
|
|
||||||
|
/**
|
||||||
|
* An interface containing [type] for dynamic type checking.
|
||||||
|
*/
|
||||||
|
public interface WithType<out T> {
|
||||||
|
public val type: SafeType<T>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents an algebraic structure.
|
* Represents an algebraic structure.
|
||||||
*
|
*
|
||||||
* @param T the type of element which Algebra operates on.
|
* @param T the type of element which Algebra operates on.
|
||||||
*/
|
*/
|
||||||
public interface Algebra<T> {
|
public interface Algebra<T> : WithType<T> {
|
||||||
/**
|
/**
|
||||||
* Provide a factory for buffers, associated with this [Algebra]
|
* Provide a factory for buffers, associated with this [Algebra]
|
||||||
*/
|
*/
|
||||||
public val bufferFactory: MutableBufferFactory<T> get() = MutableBufferFactory.boxing()
|
public val bufferFactory: MutableBufferFactory<T>
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = bufferFactory.type
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Wraps a raw string to [T] object. This method is designed for three purposes:
|
* Wraps a raw string to [T] object. This method is designed for three purposes:
|
||||||
@ -265,7 +275,7 @@ public interface Ring<T> : Group<T>, RingOps<T> {
|
|||||||
*/
|
*/
|
||||||
public fun power(arg: T, pow: UInt): T = optimizedPower(arg, pow)
|
public fun power(arg: T, pow: UInt): T = optimizedPower(arg, pow)
|
||||||
|
|
||||||
public companion object{
|
public companion object {
|
||||||
/**
|
/**
|
||||||
* Raises [arg] to the non-negative integer power [exponent].
|
* Raises [arg] to the non-negative integer power [exponent].
|
||||||
*
|
*
|
||||||
@ -348,7 +358,7 @@ public interface Field<T> : Ring<T>, FieldOps<T>, ScaleOperations<T>, NumericAlg
|
|||||||
|
|
||||||
public fun power(arg: T, pow: Int): T = optimizedPower(arg, pow)
|
public fun power(arg: T, pow: Int): T = optimizedPower(arg, pow)
|
||||||
|
|
||||||
public companion object{
|
public companion object {
|
||||||
/**
|
/**
|
||||||
* Raises [arg] to the integer power [exponent].
|
* Raises [arg] to the integer power [exponent].
|
||||||
*
|
*
|
||||||
@ -361,7 +371,11 @@ public interface Field<T> : Ring<T>, FieldOps<T>, ScaleOperations<T>, NumericAlg
|
|||||||
* @author Iaroslav Postovalov, Evgeniy Zhelenskiy
|
* @author Iaroslav Postovalov, Evgeniy Zhelenskiy
|
||||||
*/
|
*/
|
||||||
private fun <T> Field<T>.optimizedPower(arg: T, exponent: Int): T = when {
|
private fun <T> Field<T>.optimizedPower(arg: T, exponent: Int): T = when {
|
||||||
exponent < 0 -> one / (this as Ring<T>).optimizedPower(arg, if (exponent == Int.MIN_VALUE) Int.MAX_VALUE.toUInt().inc() else (-exponent).toUInt())
|
exponent < 0 -> one / (this as Ring<T>).optimizedPower(
|
||||||
|
arg,
|
||||||
|
if (exponent == Int.MIN_VALUE) Int.MAX_VALUE.toUInt().inc() else (-exponent).toUInt()
|
||||||
|
)
|
||||||
|
|
||||||
else -> (this as Ring<T>).optimizedPower(arg, exponent.toUInt())
|
else -> (this as Ring<T>).optimizedPower(arg, exponent.toUInt())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.operations
|
package space.kscience.kmath.operations
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.nd.BufferedRingOpsND
|
import space.kscience.kmath.nd.BufferedRingOpsND
|
||||||
import space.kscience.kmath.operations.BigInt.Companion.BASE
|
import space.kscience.kmath.operations.BigInt.Companion.BASE
|
||||||
@ -26,6 +28,9 @@ private typealias TBase = ULong
|
|||||||
*/
|
*/
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public object BigIntField : Field<BigInt>, NumbersAddOps<BigInt>, ScaleOperations<BigInt> {
|
public object BigIntField : Field<BigInt>, NumbersAddOps<BigInt>, ScaleOperations<BigInt> {
|
||||||
|
|
||||||
|
override val type: SafeType<BigInt> = safeTypeOf()
|
||||||
|
|
||||||
override val zero: BigInt = BigInt.ZERO
|
override val zero: BigInt = BigInt.ZERO
|
||||||
override val one: BigInt = BigInt.ONE
|
override val one: BigInt = BigInt.ONE
|
||||||
|
|
||||||
@ -528,10 +533,10 @@ public fun String.parseBigInteger(): BigInt? {
|
|||||||
public val BigInt.algebra: BigIntField get() = BigIntField
|
public val BigInt.algebra: BigIntField get() = BigIntField
|
||||||
|
|
||||||
public inline fun BigInt.Companion.buffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
|
public inline fun BigInt.Companion.buffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
|
||||||
Buffer.boxing(size, initializer)
|
Buffer(size, initializer)
|
||||||
|
|
||||||
public inline fun BigInt.Companion.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
|
public inline fun BigInt.Companion.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
|
||||||
Buffer.boxing(size, initializer)
|
Buffer(size, initializer)
|
||||||
|
|
||||||
public val BigIntField.nd: BufferedRingOpsND<BigInt, BigIntField>
|
public val BigIntField.nd: BufferedRingOpsND<BigInt, BigIntField>
|
||||||
get() = BufferedRingOpsND(BufferRingOps(BigIntField))
|
get() = BufferedRingOpsND(BufferRingOps(BigIntField))
|
||||||
|
@ -5,10 +5,11 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.operations
|
package space.kscience.kmath.operations
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.MutableBuffer
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
import space.kscience.kmath.structures.MutableBufferFactory
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
import kotlin.reflect.KType
|
|
||||||
|
|
||||||
public interface WithSize {
|
public interface WithSize {
|
||||||
public val size: Int
|
public val size: Int
|
||||||
@ -20,6 +21,8 @@ public interface WithSize {
|
|||||||
public interface BufferAlgebra<T, out A : Algebra<T>> : Algebra<Buffer<T>> {
|
public interface BufferAlgebra<T, out A : Algebra<T>> : Algebra<Buffer<T>> {
|
||||||
public val elementAlgebra: A
|
public val elementAlgebra: A
|
||||||
|
|
||||||
|
override val type: SafeType<Buffer<T>> get() = safeTypeOf<Buffer<T>>()
|
||||||
|
|
||||||
public val elementBufferFactory: MutableBufferFactory<T> get() = elementAlgebra.bufferFactory
|
public val elementBufferFactory: MutableBufferFactory<T> get() = elementAlgebra.bufferFactory
|
||||||
|
|
||||||
public fun buffer(size: Int, vararg elements: T): Buffer<T> {
|
public fun buffer(size: Int, vararg elements: T): Buffer<T> {
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.operations
|
package space.kscience.kmath.operations
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
|
|
||||||
@ -72,6 +74,8 @@ public interface LogicAlgebra<T : Any> : Algebra<T> {
|
|||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
||||||
public object BooleanAlgebra : LogicAlgebra<Boolean> {
|
public object BooleanAlgebra : LogicAlgebra<Boolean> {
|
||||||
|
|
||||||
|
override val type: SafeType<Boolean> get() = safeTypeOf()
|
||||||
|
|
||||||
override fun const(boolean: Boolean): Boolean = boolean
|
override fun const(boolean: Boolean): Boolean = boolean
|
||||||
|
|
||||||
override fun Boolean.not(): Boolean = !this
|
override fun Boolean.not(): Boolean = !this
|
||||||
|
@ -8,7 +8,10 @@ package space.kscience.kmath.operations
|
|||||||
import space.kscience.kmath.operations.Int16Field.div
|
import space.kscience.kmath.operations.Int16Field.div
|
||||||
import space.kscience.kmath.operations.Int32Field.div
|
import space.kscience.kmath.operations.Int32Field.div
|
||||||
import space.kscience.kmath.operations.Int64Field.div
|
import space.kscience.kmath.operations.Int64Field.div
|
||||||
import space.kscience.kmath.structures.*
|
import space.kscience.kmath.structures.Int16
|
||||||
|
import space.kscience.kmath.structures.Int32
|
||||||
|
import space.kscience.kmath.structures.Int64
|
||||||
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
import kotlin.math.roundToInt
|
import kotlin.math.roundToInt
|
||||||
import kotlin.math.roundToLong
|
import kotlin.math.roundToLong
|
||||||
|
|
||||||
@ -22,7 +25,7 @@ import kotlin.math.roundToLong
|
|||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
||||||
public object Int16Field : Field<Int16>, Norm<Int16, Int16>, NumericAlgebra<Int16> {
|
public object Int16Field : Field<Int16>, Norm<Int16, Int16>, NumericAlgebra<Int16> {
|
||||||
override val bufferFactory: MutableBufferFactory<Int16> = MutableBufferFactory(::Int16Buffer)
|
override val bufferFactory: MutableBufferFactory<Int16> = MutableBufferFactory<Int16>()
|
||||||
override val zero: Int16 get() = 0
|
override val zero: Int16 get() = 0
|
||||||
override val one: Int16 get() = 1
|
override val one: Int16 get() = 1
|
||||||
|
|
||||||
@ -45,7 +48,8 @@ public object Int16Field : Field<Int16>, Norm<Int16, Int16>, NumericAlgebra<Int1
|
|||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
||||||
public object Int32Field : Field<Int32>, Norm<Int32, Int32>, NumericAlgebra<Int32> {
|
public object Int32Field : Field<Int32>, Norm<Int32, Int32>, NumericAlgebra<Int32> {
|
||||||
override val bufferFactory: MutableBufferFactory<Int> = MutableBufferFactory(::Int32Buffer)
|
override val bufferFactory: MutableBufferFactory<Int> = MutableBufferFactory()
|
||||||
|
|
||||||
override val zero: Int get() = 0
|
override val zero: Int get() = 0
|
||||||
override val one: Int get() = 1
|
override val one: Int get() = 1
|
||||||
|
|
||||||
@ -68,7 +72,7 @@ public object Int32Field : Field<Int32>, Norm<Int32, Int32>, NumericAlgebra<Int3
|
|||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
||||||
public object Int64Field : Field<Int64>, Norm<Int64, Int64>, NumericAlgebra<Int64> {
|
public object Int64Field : Field<Int64>, Norm<Int64, Int64>, NumericAlgebra<Int64> {
|
||||||
override val bufferFactory: MutableBufferFactory<Int64> = MutableBufferFactory(::Int64Buffer)
|
override val bufferFactory: MutableBufferFactory<Int64> = MutableBufferFactory()
|
||||||
override val zero: Int64 get() = 0L
|
override val zero: Int64 get() = 0L
|
||||||
override val one: Int64 get() = 1L
|
override val one: Int64 get() = 1L
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ public interface ExtendedField<T> : ExtendedFieldOps<T>, Field<T>, NumericAlgebr
|
|||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
||||||
public object Float64Field : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> {
|
public object Float64Field : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> {
|
||||||
override val bufferFactory: MutableBufferFactory<Double> = MutableBufferFactory(::Float64Buffer)
|
override val bufferFactory: MutableBufferFactory<Double> = MutableBufferFactory()
|
||||||
|
|
||||||
override val zero: Double get() = 0.0
|
override val zero: Double get() = 0.0
|
||||||
override val one: Double get() = 1.0
|
override val one: Double get() = 1.0
|
||||||
|
@ -8,8 +8,8 @@ package space.kscience.kmath.structures
|
|||||||
import space.kscience.attributes.SafeType
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.attributes.safeTypeOf
|
import space.kscience.attributes.safeTypeOf
|
||||||
import space.kscience.kmath.operations.WithSize
|
import space.kscience.kmath.operations.WithSize
|
||||||
|
import space.kscience.kmath.operations.WithType
|
||||||
import space.kscience.kmath.operations.asSequence
|
import space.kscience.kmath.operations.asSequence
|
||||||
import kotlin.jvm.JvmInline
|
|
||||||
import kotlin.reflect.typeOf
|
import kotlin.reflect.typeOf
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -17,35 +17,46 @@ import kotlin.reflect.typeOf
|
|||||||
*
|
*
|
||||||
* @param T the type of buffer.
|
* @param T the type of buffer.
|
||||||
*/
|
*/
|
||||||
public fun interface BufferFactory<T> {
|
public interface BufferFactory<T> : WithType<T> {
|
||||||
|
|
||||||
public operator fun invoke(size: Int, builder: (Int) -> T): Buffer<T>
|
public operator fun invoke(size: Int, builder: (Int) -> T): Buffer<T>
|
||||||
|
|
||||||
public companion object{
|
|
||||||
public inline fun <reified T> auto(): BufferFactory<T> =
|
|
||||||
BufferFactory(Buffer.Companion::auto)
|
|
||||||
|
|
||||||
public fun <T> boxing(): BufferFactory<T> =
|
|
||||||
BufferFactory(Buffer.Companion::boxing)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a [BufferFactory] for given [type], using primitive storage if possible
|
||||||
|
*/
|
||||||
|
public fun <T> BufferFactory(type: SafeType<T>): BufferFactory<T> = object : BufferFactory<T> {
|
||||||
|
override val type: SafeType<T> = type
|
||||||
|
override fun invoke(size: Int, builder: (Int) -> T): Buffer<T> = Buffer(type, size, builder)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create [BufferFactory] using the reified type
|
||||||
|
*/
|
||||||
|
public inline fun <reified T> BufferFactory(): BufferFactory<T> = BufferFactory(safeTypeOf())
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Function that produces [MutableBuffer] from its size and function that supplies values.
|
* Function that produces [MutableBuffer] from its size and function that supplies values.
|
||||||
*
|
*
|
||||||
* @param T the type of buffer.
|
* @param T the type of buffer.
|
||||||
*/
|
*/
|
||||||
public fun interface MutableBufferFactory<T> : BufferFactory<T> {
|
public interface MutableBufferFactory<T> : BufferFactory<T> {
|
||||||
override fun invoke(size: Int, builder: (Int) -> T): MutableBuffer<T>
|
override fun invoke(size: Int, builder: (Int) -> T): MutableBuffer<T>
|
||||||
|
|
||||||
public companion object {
|
|
||||||
public inline fun <reified T : Any> auto(): MutableBufferFactory<T> =
|
|
||||||
MutableBufferFactory(MutableBuffer.Companion::auto)
|
|
||||||
|
|
||||||
public fun <T> boxing(): MutableBufferFactory<T> =
|
|
||||||
MutableBufferFactory(MutableBuffer.Companion::boxing)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a [MutableBufferFactory] for given [type], using primitive storage if possible
|
||||||
|
*/
|
||||||
|
public fun <T> MutableBufferFactory(type: SafeType<T>): MutableBufferFactory<T> = object : MutableBufferFactory<T> {
|
||||||
|
override val type: SafeType<T> = type
|
||||||
|
override fun invoke(size: Int, builder: (Int) -> T): MutableBuffer<T> = MutableBuffer(type, size, builder)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create [BufferFactory] using the reified type
|
||||||
|
*/
|
||||||
|
public inline fun <reified T> MutableBufferFactory(): MutableBufferFactory<T> = MutableBufferFactory(safeTypeOf())
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A generic read-only random-access structure for both primitives and objects.
|
* A generic read-only random-access structure for both primitives and objects.
|
||||||
*
|
*
|
||||||
@ -53,14 +64,14 @@ public fun interface MutableBufferFactory<T> : BufferFactory<T> {
|
|||||||
*
|
*
|
||||||
* @param T the type of elements contained in the buffer.
|
* @param T the type of elements contained in the buffer.
|
||||||
*/
|
*/
|
||||||
public interface Buffer<out T> : WithSize {
|
public interface Buffer<out T> : WithSize, WithType<T> {
|
||||||
/**
|
/**
|
||||||
* The size of this buffer.
|
* The size of this buffer.
|
||||||
*/
|
*/
|
||||||
override val size: Int
|
override val size: Int
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets element at given index.
|
* Gets an element at given index.
|
||||||
*/
|
*/
|
||||||
public operator fun get(index: Int): T
|
public operator fun get(index: Int): T
|
||||||
|
|
||||||
@ -87,41 +98,48 @@ public interface Buffer<out T> : WithSize {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a [ListBuffer] of given type [T] with given [size]. Each element is calculated by calling the
|
|
||||||
* specified [initializer] function.
|
|
||||||
*/
|
|
||||||
public inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> =
|
|
||||||
List(size, initializer).asBuffer()
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a [Buffer] of given [type]. If the type is primitive, specialized buffers are used ([Int32Buffer],
|
|
||||||
* [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
|
|
||||||
*
|
|
||||||
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
|
|
||||||
*/
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
public inline fun <T> auto(type: SafeType<T>, size: Int, initializer: (Int) -> T): Buffer<T> =
|
|
||||||
when (type.kType) {
|
|
||||||
typeOf<Double>() -> MutableBuffer.double(size) { initializer(it) as Double } as Buffer<T>
|
|
||||||
typeOf<Short>() -> MutableBuffer.short(size) { initializer(it) as Short } as Buffer<T>
|
|
||||||
typeOf<Int>() -> MutableBuffer.int(size) { initializer(it) as Int } as Buffer<T>
|
|
||||||
typeOf<Long>() -> MutableBuffer.long(size) { initializer(it) as Long } as Buffer<T>
|
|
||||||
typeOf<Float>() -> MutableBuffer.float(size) { initializer(it) as Float } as Buffer<T>
|
|
||||||
else -> boxing(size, initializer)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a [Buffer] of given type [T]. If the type is primitive, specialized buffers are used ([Int32Buffer],
|
|
||||||
* [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
|
|
||||||
*
|
|
||||||
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
|
|
||||||
*/
|
|
||||||
public inline fun <reified T> auto(size: Int, initializer: (Int) -> T): Buffer<T> =
|
|
||||||
auto(safeTypeOf<T>(), size, initializer)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a [Buffer] of given type [T]. If the type is primitive, specialized buffers are used ([Int32Buffer],
|
||||||
|
* [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
|
||||||
|
*
|
||||||
|
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
|
||||||
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
public inline fun <reified T> Buffer(size: Int, initializer: (Int) -> T): Buffer<T> {
|
||||||
|
val type = safeTypeOf<T>()
|
||||||
|
return when (type.kType) {
|
||||||
|
typeOf<Double>() -> MutableBuffer.double(size) { initializer(it) as Double } as Buffer<T>
|
||||||
|
typeOf<Short>() -> MutableBuffer.short(size) { initializer(it) as Short } as Buffer<T>
|
||||||
|
typeOf<Int>() -> MutableBuffer.int(size) { initializer(it) as Int } as Buffer<T>
|
||||||
|
typeOf<Long>() -> MutableBuffer.long(size) { initializer(it) as Long } as Buffer<T>
|
||||||
|
typeOf<Float>() -> MutableBuffer.float(size) { initializer(it) as Float } as Buffer<T>
|
||||||
|
else -> List(size, initializer).asBuffer(type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a [Buffer] of given [type]. If the type is primitive, specialized buffers are used ([Int32Buffer],
|
||||||
|
* [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
|
||||||
|
*
|
||||||
|
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
|
||||||
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
public fun <T> Buffer(
|
||||||
|
type: SafeType<T>,
|
||||||
|
size: Int,
|
||||||
|
initializer: (Int) -> T,
|
||||||
|
): Buffer<T> = when (type.kType) {
|
||||||
|
typeOf<Double>() -> MutableBuffer.double(size) { initializer(it) as Double } as Buffer<T>
|
||||||
|
typeOf<Short>() -> MutableBuffer.short(size) { initializer(it) as Short } as Buffer<T>
|
||||||
|
typeOf<Int>() -> MutableBuffer.int(size) { initializer(it) as Int } as Buffer<T>
|
||||||
|
typeOf<Long>() -> MutableBuffer.long(size) { initializer(it) as Long } as Buffer<T>
|
||||||
|
typeOf<Float>() -> MutableBuffer.float(size) { initializer(it) as Float } as Buffer<T>
|
||||||
|
else -> List(size, initializer).asBuffer(type)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns an [IntRange] of the valid indices for this [Buffer].
|
* Returns an [IntRange] of the valid indices for this [Buffer].
|
||||||
*/
|
*/
|
||||||
@ -144,28 +162,17 @@ public fun <T> Buffer<T>.last(): T {
|
|||||||
return get(size - 1)
|
return get(size - 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Immutable wrapper for [MutableBuffer].
|
|
||||||
*
|
|
||||||
* @param T the type of elements contained in the buffer.
|
|
||||||
* @property buffer The underlying buffer.
|
|
||||||
*/
|
|
||||||
@JvmInline
|
|
||||||
public value class ReadOnlyBuffer<T>(public val buffer: MutableBuffer<T>) : Buffer<T> {
|
|
||||||
override val size: Int get() = buffer.size
|
|
||||||
|
|
||||||
override operator fun get(index: Int): T = buffer[index]
|
|
||||||
|
|
||||||
override operator fun iterator(): Iterator<T> = buffer.iterator()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A buffer with content calculated on-demand. The calculated content is not stored, so it is recalculated on each call.
|
* A buffer with content calculated on-demand. The calculated content is not stored, so it is recalculated on each call.
|
||||||
* Useful when one needs single element from the buffer.
|
* Useful when one needs a single element from the buffer.
|
||||||
*
|
*
|
||||||
* @param T the type of elements provided by the buffer.
|
* @param T the type of elements provided by the buffer.
|
||||||
*/
|
*/
|
||||||
public class VirtualBuffer<out T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> {
|
public class VirtualBuffer<out T>(
|
||||||
|
override val type: SafeType<T>,
|
||||||
|
override val size: Int,
|
||||||
|
private val generator: (Int) -> T,
|
||||||
|
) : Buffer<T> {
|
||||||
override operator fun get(index: Int): T {
|
override operator fun get(index: Int): T {
|
||||||
if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index")
|
if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index")
|
||||||
return generator(index)
|
return generator(index)
|
||||||
@ -177,6 +184,9 @@ public class VirtualBuffer<out T>(override val size: Int, private val generator:
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert this buffer to read-only buffer.
|
* Inline builder for [VirtualBuffer]
|
||||||
*/
|
*/
|
||||||
public fun <T> Buffer<T>.asReadOnly(): Buffer<T> = if (this is MutableBuffer) ReadOnlyBuffer(this) else this
|
public inline fun <reified T> VirtualBuffer(
|
||||||
|
size: Int,
|
||||||
|
noinline generator: (Int) -> T,
|
||||||
|
): VirtualBuffer<T> = VirtualBuffer(safeTypeOf(), size, generator)
|
||||||
|
@ -5,7 +5,13 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.structures
|
package space.kscience.kmath.structures
|
||||||
|
|
||||||
import space.kscience.kmath.nd.*
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.kmath.nd.BufferND
|
||||||
|
import space.kscience.kmath.nd.ShapeND
|
||||||
|
import space.kscience.kmath.nd.Structure2D
|
||||||
|
import space.kscience.kmath.nd.as2D
|
||||||
|
import kotlin.collections.component1
|
||||||
|
import kotlin.collections.component2
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context that allows to operate on a [MutableBuffer] as on 2d array
|
* A context that allows to operate on a [MutableBuffer] as on 2d array
|
||||||
@ -27,11 +33,13 @@ internal class BufferAccessor2D<T>(
|
|||||||
fun create(mat: Structure2D<T>): MutableBuffer<T> = create { i, j -> mat[i, j] }
|
fun create(mat: Structure2D<T>): MutableBuffer<T> = create { i, j -> mat[i, j] }
|
||||||
|
|
||||||
//TODO optimize wrapper
|
//TODO optimize wrapper
|
||||||
fun MutableBuffer<T>.toStructure2D(): Structure2D<T> = StructureND.buffered(
|
fun MutableBuffer<T>.toStructure2D(): Structure2D<T> = BufferND(
|
||||||
ColumnStrides(ShapeND(rowNum, colNum))
|
type, ShapeND(rowNum, colNum)
|
||||||
) { (i, j) -> get(i, j) }.as2D()
|
) { (i, j) -> get(i, j) }.as2D()
|
||||||
|
|
||||||
inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> {
|
inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> {
|
||||||
|
override val type: SafeType<T> get() = buffer.type
|
||||||
|
|
||||||
override val size: Int get() = colNum
|
override val size: Int get() = colNum
|
||||||
|
|
||||||
override operator fun get(index: Int): T = buffer[rowIndex, index]
|
override operator fun get(index: Int): T = buffer[rowIndex, index]
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.structures
|
package space.kscience.kmath.structures
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import kotlin.experimental.and
|
import kotlin.experimental.and
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -57,6 +59,9 @@ public class FlaggedDoubleBuffer(
|
|||||||
public val values: DoubleArray,
|
public val values: DoubleArray,
|
||||||
public val flags: ByteArray
|
public val flags: ByteArray
|
||||||
) : FlaggedBuffer<Double?>, Buffer<Double?> {
|
) : FlaggedBuffer<Double?>, Buffer<Double?> {
|
||||||
|
|
||||||
|
override val type: SafeType<Double?> = safeTypeOf()
|
||||||
|
|
||||||
init {
|
init {
|
||||||
require(values.size == flags.size) { "Values and flags must have the same dimensions" }
|
require(values.size == flags.size) { "Values and flags must have the same dimensions" }
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.structures
|
package space.kscience.kmath.structures
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import kotlin.jvm.JvmInline
|
import kotlin.jvm.JvmInline
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -15,6 +17,9 @@ import kotlin.jvm.JvmInline
|
|||||||
*/
|
*/
|
||||||
@JvmInline
|
@JvmInline
|
||||||
public value class Float32Buffer(public val array: FloatArray) : PrimitiveBuffer<Float> {
|
public value class Float32Buffer(public val array: FloatArray) : PrimitiveBuffer<Float> {
|
||||||
|
|
||||||
|
override val type: SafeType<Float32> get() = safeTypeOf<Float32>()
|
||||||
|
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override operator fun get(index: Int): Float = array[index]
|
override operator fun get(index: Int): Float = array[index]
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.structures
|
package space.kscience.kmath.structures
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import space.kscience.kmath.operations.BufferTransform
|
import space.kscience.kmath.operations.BufferTransform
|
||||||
import kotlin.jvm.JvmInline
|
import kotlin.jvm.JvmInline
|
||||||
|
|
||||||
@ -15,6 +17,9 @@ import kotlin.jvm.JvmInline
|
|||||||
*/
|
*/
|
||||||
@JvmInline
|
@JvmInline
|
||||||
public value class Float64Buffer(public val array: DoubleArray) : PrimitiveBuffer<Double> {
|
public value class Float64Buffer(public val array: DoubleArray) : PrimitiveBuffer<Double> {
|
||||||
|
|
||||||
|
override val type: SafeType<Double> get() = safeTypeOf()
|
||||||
|
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override operator fun get(index: Int): Double = array[index]
|
override operator fun get(index: Int): Double = array[index]
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.structures
|
package space.kscience.kmath.structures
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import kotlin.jvm.JvmInline
|
import kotlin.jvm.JvmInline
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -14,6 +16,8 @@ import kotlin.jvm.JvmInline
|
|||||||
*/
|
*/
|
||||||
@JvmInline
|
@JvmInline
|
||||||
public value class Int16Buffer(public val array: ShortArray) : MutableBuffer<Short> {
|
public value class Int16Buffer(public val array: ShortArray) : MutableBuffer<Short> {
|
||||||
|
|
||||||
|
override val type: SafeType<Short> get() = safeTypeOf()
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override operator fun get(index: Int): Short = array[index]
|
override operator fun get(index: Int): Short = array[index]
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.structures
|
package space.kscience.kmath.structures
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import kotlin.jvm.JvmInline
|
import kotlin.jvm.JvmInline
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -14,6 +16,9 @@ import kotlin.jvm.JvmInline
|
|||||||
*/
|
*/
|
||||||
@JvmInline
|
@JvmInline
|
||||||
public value class Int32Buffer(public val array: IntArray) : PrimitiveBuffer<Int> {
|
public value class Int32Buffer(public val array: IntArray) : PrimitiveBuffer<Int> {
|
||||||
|
|
||||||
|
override val type: SafeType<Int> get() = safeTypeOf()
|
||||||
|
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override operator fun get(index: Int): Int = array[index]
|
override operator fun get(index: Int): Int = array[index]
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.structures
|
package space.kscience.kmath.structures
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import kotlin.jvm.JvmInline
|
import kotlin.jvm.JvmInline
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -14,6 +16,9 @@ import kotlin.jvm.JvmInline
|
|||||||
*/
|
*/
|
||||||
@JvmInline
|
@JvmInline
|
||||||
public value class Int64Buffer(public val array: LongArray) : PrimitiveBuffer<Long> {
|
public value class Int64Buffer(public val array: LongArray) : PrimitiveBuffer<Long> {
|
||||||
|
|
||||||
|
override val type: SafeType<Long> get() = safeTypeOf()
|
||||||
|
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override operator fun get(index: Int): Long = array[index]
|
override operator fun get(index: Int): Long = array[index]
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.structures
|
package space.kscience.kmath.structures
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import kotlin.jvm.JvmInline
|
import kotlin.jvm.JvmInline
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -14,6 +16,9 @@ import kotlin.jvm.JvmInline
|
|||||||
*/
|
*/
|
||||||
@JvmInline
|
@JvmInline
|
||||||
public value class Int8Buffer(public val array: ByteArray) : MutableBuffer<Byte> {
|
public value class Int8Buffer(public val array: ByteArray) : MutableBuffer<Byte> {
|
||||||
|
|
||||||
|
override val type: SafeType<Byte> get() = safeTypeOf()
|
||||||
|
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override operator fun get(index: Int): Byte = array[index]
|
override operator fun get(index: Int): Byte = array[index]
|
||||||
|
@ -5,7 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.structures
|
package space.kscience.kmath.structures
|
||||||
|
|
||||||
import kotlin.jvm.JvmInline
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [Buffer] implementation over [List].
|
* [Buffer] implementation over [List].
|
||||||
@ -13,9 +14,7 @@ import kotlin.jvm.JvmInline
|
|||||||
* @param T the type of elements contained in the buffer.
|
* @param T the type of elements contained in the buffer.
|
||||||
* @property list The underlying list.
|
* @property list The underlying list.
|
||||||
*/
|
*/
|
||||||
public class ListBuffer<T>(public val list: List<T>) : Buffer<T> {
|
public class ListBuffer<T>(override val type: SafeType<T>, public val list: List<T>) : Buffer<T> {
|
||||||
|
|
||||||
public constructor(size: Int, initializer: (Int) -> T) : this(List(size, initializer))
|
|
||||||
|
|
||||||
override val size: Int get() = list.size
|
override val size: Int get() = list.size
|
||||||
|
|
||||||
@ -29,7 +28,12 @@ public class ListBuffer<T>(public val list: List<T>) : Buffer<T> {
|
|||||||
/**
|
/**
|
||||||
* Returns an [ListBuffer] that wraps the original list.
|
* Returns an [ListBuffer] that wraps the original list.
|
||||||
*/
|
*/
|
||||||
public fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
|
public fun <T> List<T>.asBuffer(type: SafeType<T>): ListBuffer<T> = ListBuffer(type, this)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an [ListBuffer] that wraps the original list.
|
||||||
|
*/
|
||||||
|
public inline fun <reified T> List<T>.asBuffer(): ListBuffer<T> = asBuffer(safeTypeOf())
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [MutableBuffer] implementation over [MutableList].
|
* [MutableBuffer] implementation over [MutableList].
|
||||||
@ -37,10 +41,7 @@ public fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
|
|||||||
* @param T the type of elements contained in the buffer.
|
* @param T the type of elements contained in the buffer.
|
||||||
* @property list The underlying list.
|
* @property list The underlying list.
|
||||||
*/
|
*/
|
||||||
@JvmInline
|
public class MutableListBuffer<T>(override val type: SafeType<T>, public val list: MutableList<T>) : MutableBuffer<T> {
|
||||||
public value class MutableListBuffer<T>(public val list: MutableList<T>) : MutableBuffer<T> {
|
|
||||||
|
|
||||||
public constructor(size: Int, initializer: (Int) -> T) : this(MutableList(size, initializer))
|
|
||||||
|
|
||||||
override val size: Int get() = list.size
|
override val size: Int get() = list.size
|
||||||
|
|
||||||
@ -51,10 +52,24 @@ public value class MutableListBuffer<T>(public val list: MutableList<T>) : Mutab
|
|||||||
}
|
}
|
||||||
|
|
||||||
override operator fun iterator(): Iterator<T> = list.iterator()
|
override operator fun iterator(): Iterator<T> = list.iterator()
|
||||||
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
|
override fun copy(): MutableBuffer<T> = MutableListBuffer(type, ArrayList(list))
|
||||||
|
|
||||||
|
override fun toString(): String = Buffer.toString(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns an [MutableListBuffer] that wraps the original list.
|
* Returns an [MutableListBuffer] that wraps the original list.
|
||||||
*/
|
*/
|
||||||
public fun <T> MutableList<T>.asMutableBuffer(): MutableListBuffer<T> = MutableListBuffer(this)
|
public fun <T> MutableList<T>.asMutableBuffer(type: SafeType<T>): MutableListBuffer<T> = MutableListBuffer(
|
||||||
|
type,
|
||||||
|
this
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an [MutableListBuffer] that wraps the original list.
|
||||||
|
*/
|
||||||
|
public inline fun <reified T> MutableList<T>.asMutableBuffer(): MutableListBuffer<T> = MutableListBuffer(
|
||||||
|
safeTypeOf(),
|
||||||
|
this
|
||||||
|
)
|
||||||
|
@ -61,41 +61,35 @@ public interface MutableBuffer<T> : Buffer<T> {
|
|||||||
*/
|
*/
|
||||||
public inline fun float(size: Int, initializer: (Int) -> Float): Float32Buffer =
|
public inline fun float(size: Int, initializer: (Int) -> Float): Float32Buffer =
|
||||||
Float32Buffer(size, initializer)
|
Float32Buffer(size, initializer)
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a boxing mutable buffer of given type
|
|
||||||
*/
|
|
||||||
public inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
|
||||||
MutableListBuffer(MutableList(size, initializer))
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a [MutableBuffer] of given [type]. If the type is primitive, specialized buffers are used
|
|
||||||
* ([Int32Buffer], [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
|
|
||||||
*
|
|
||||||
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
|
|
||||||
*/
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
public inline fun <T> auto(type: SafeType<T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
|
||||||
when (type.kType) {
|
|
||||||
typeOf<Double>() -> double(size) { initializer(it) as Double } as MutableBuffer<T>
|
|
||||||
typeOf<Short>() -> short(size) { initializer(it) as Short } as MutableBuffer<T>
|
|
||||||
typeOf<Int>() -> int(size) { initializer(it) as Int } as MutableBuffer<T>
|
|
||||||
typeOf<Float>() -> float(size) { initializer(it) as Float } as MutableBuffer<T>
|
|
||||||
typeOf<Long>() -> long(size) { initializer(it) as Long } as MutableBuffer<T>
|
|
||||||
else -> boxing(size, initializer)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a [MutableBuffer] of given type [T]. If the type is primitive, specialized buffers are used
|
|
||||||
* ([Int32Buffer], [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
|
|
||||||
*
|
|
||||||
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
|
|
||||||
*/
|
|
||||||
public inline fun <reified T> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
|
||||||
auto(safeTypeOf<T>(), size, initializer)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a [MutableBuffer] of given [type]. If the type is primitive, specialized buffers are used
|
||||||
|
* ([Int32Buffer], [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
|
||||||
|
*
|
||||||
|
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
|
||||||
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
public inline fun <T> MutableBuffer(type: SafeType<T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||||
|
when (type.kType) {
|
||||||
|
typeOf<Double>() -> MutableBuffer.double(size) { initializer(it) as Double } as MutableBuffer<T>
|
||||||
|
typeOf<Short>() -> MutableBuffer.short(size) { initializer(it) as Short } as MutableBuffer<T>
|
||||||
|
typeOf<Int>() -> MutableBuffer.int(size) { initializer(it) as Int } as MutableBuffer<T>
|
||||||
|
typeOf<Float>() -> MutableBuffer.float(size) { initializer(it) as Float } as MutableBuffer<T>
|
||||||
|
typeOf<Long>() -> MutableBuffer.long(size) { initializer(it) as Long } as MutableBuffer<T>
|
||||||
|
else -> MutableListBuffer(type, MutableList(size, initializer))
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a [MutableBuffer] of given type [T]. If the type is primitive, specialized buffers are used
|
||||||
|
* ([Int32Buffer], [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
|
||||||
|
*
|
||||||
|
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
|
||||||
|
*/
|
||||||
|
public inline fun <reified T> MutableBuffer(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||||
|
MutableBuffer(safeTypeOf<T>(), size, initializer)
|
||||||
|
|
||||||
|
|
||||||
public sealed interface PrimitiveBuffer<T> : MutableBuffer<T>
|
public sealed interface PrimitiveBuffer<T> : MutableBuffer<T>
|
@ -40,7 +40,7 @@ class DoubleLUSolverTest {
|
|||||||
//Check determinant
|
//Check determinant
|
||||||
assertEquals(7.0, lup.determinant)
|
assertEquals(7.0, lup.determinant)
|
||||||
|
|
||||||
assertMatrixEquals(lup.p dot matrix, lup.l dot lup.u)
|
assertMatrixEquals(lup.pivotMatrix dot matrix, lup.l dot lup.u)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -19,19 +19,13 @@ import org.ejml.sparse.csc.factory.DecompositionFactory_DSCC
|
|||||||
import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC
|
import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC
|
||||||
import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC
|
import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC
|
||||||
import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC
|
import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC
|
||||||
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.linear.*
|
import space.kscience.kmath.linear.*
|
||||||
import space.kscience.kmath.linear.Matrix
|
import space.kscience.kmath.linear.Matrix
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.nd.StructureFeature
|
import space.kscience.kmath.nd.StructureFeature
|
||||||
import space.kscience.kmath.structures.Float64
|
|
||||||
import space.kscience.kmath.structures.Float32
|
|
||||||
import space.kscience.kmath.operations.Float64Field
|
|
||||||
import space.kscience.kmath.operations.Float32Field
|
import space.kscience.kmath.operations.Float32Field
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.Float64Field
|
||||||
import space.kscience.kmath.operations.FloatField
|
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.structures.Float64Buffer
|
|
||||||
import space.kscience.kmath.structures.Float32Buffer
|
|
||||||
import space.kscience.kmath.structures.DoubleBuffer
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
import space.kscience.kmath.structures.FloatBuffer
|
import space.kscience.kmath.structures.FloatBuffer
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
@ -147,12 +147,12 @@ public fun <V : Any, A : Field<V>> Histogram.Companion.uniformNDFromRanges(
|
|||||||
bufferFactory: BufferFactory<V> = valueAlgebraND.elementAlgebra.bufferFactory,
|
bufferFactory: BufferFactory<V> = valueAlgebraND.elementAlgebra.bufferFactory,
|
||||||
): UniformHistogramGroupND<V, A> = UniformHistogramGroupND(
|
): UniformHistogramGroupND<V, A> = UniformHistogramGroupND(
|
||||||
valueAlgebraND,
|
valueAlgebraND,
|
||||||
ListBuffer(
|
DoubleBuffer(
|
||||||
ranges
|
ranges
|
||||||
.map(Pair<ClosedFloatingPointRange<Double>, Int>::first)
|
.map(Pair<ClosedFloatingPointRange<Double>, Int>::first)
|
||||||
.map(ClosedFloatingPointRange<Double>::start)
|
.map(ClosedFloatingPointRange<Double>::start)
|
||||||
),
|
),
|
||||||
ListBuffer(
|
DoubleBuffer(
|
||||||
ranges
|
ranges
|
||||||
.map(Pair<ClosedFloatingPointRange<Double>, Int>::first)
|
.map(Pair<ClosedFloatingPointRange<Double>, Int>::first)
|
||||||
.map(ClosedFloatingPointRange<Double>::endInclusive)
|
.map(ClosedFloatingPointRange<Double>::endInclusive)
|
||||||
|
Loading…
Reference in New Issue
Block a user