Add Buffer.asList()

This commit is contained in:
Alexander Nozik 2023-11-22 14:32:56 +03:00
parent 5c82a5e1fa
commit 24b934eab7
4 changed files with 107 additions and 37 deletions

View File

@ -8,6 +8,7 @@
- Float32 geometries. - Float32 geometries.
- New Attributes-kt module that could be used as stand-alone. It declares. type-safe attributes containers. - New Attributes-kt module that could be used as stand-alone. It declares. type-safe attributes containers.
- Explicit `mutableStructureND` builders for mutable structures. - Explicit `mutableStructureND` builders for mutable structures.
- `Buffer.asList()` zero-copy transformation.
### Changed ### Changed
- Default naming for algebra and buffers now uses IntXX/FloatXX notation instead of Java types. - Default naming for algebra and buffers now uses IntXX/FloatXX notation instead of Java types.

View File

@ -101,32 +101,13 @@ public interface Buffer<out T> : WithSize, WithType<T> {
} }
} }
/**
* Creates a [Buffer] of given type [T]. If the type is primitive, specialized buffers are used ([Int32Buffer],
* [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
*
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
*/
@Suppress("UNCHECKED_CAST")
public inline fun <reified T> Buffer(size: Int, initializer: (Int) -> T): Buffer<T> {
val type = safeTypeOf<T>()
return when (type.kType) {
typeOf<Double>() -> MutableBuffer.double(size) { initializer(it) as Double } as Buffer<T>
typeOf<Short>() -> MutableBuffer.short(size) { initializer(it) as Short } as Buffer<T>
typeOf<Int>() -> MutableBuffer.int(size) { initializer(it) as Int } as Buffer<T>
typeOf<Long>() -> MutableBuffer.long(size) { initializer(it) as Long } as Buffer<T>
typeOf<Float>() -> MutableBuffer.float(size) { initializer(it) as Float } as Buffer<T>
else -> List(size, initializer).asBuffer(type)
}
}
/** /**
* Creates a [Buffer] of given [type]. If the type is primitive, specialized buffers are used ([Int32Buffer], * Creates a [Buffer] of given [type]. If the type is primitive, specialized buffers are used ([Int32Buffer],
* [Float64Buffer], etc.), [ListBuffer] is returned otherwise. * [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
* *
* The [size] is specified, and each element is calculated by calling the specified [initializer] function. * The [size] is specified, and each element is calculated by calling the specified [initializer] function.
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST", "DuplicatedCode")
public fun <T> Buffer( public fun <T> Buffer(
type: SafeType<T>, type: SafeType<T>,
size: Int, size: Int,
@ -140,6 +121,26 @@ public fun <T> Buffer(
else -> List(size, initializer).asBuffer(type) else -> List(size, initializer).asBuffer(type)
} }
/**
* Creates a [Buffer] of given type [T]. If the type is primitive, specialized buffers are used ([Int32Buffer],
* [Float64Buffer], etc.), [ListBuffer] is returned otherwise.
*
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
*/
@Suppress("UNCHECKED_CAST", "DuplicatedCode")
public inline fun <reified T> Buffer(size: Int, initializer: (Int) -> T): Buffer<T> {
//code duplication here because we want to inline initializers
val type = safeTypeOf<T>()
return when (type.kType) {
typeOf<Double>() -> MutableBuffer.double(size) { initializer(it) as Double } as Buffer<T>
typeOf<Short>() -> MutableBuffer.short(size) { initializer(it) as Short } as Buffer<T>
typeOf<Int>() -> MutableBuffer.int(size) { initializer(it) as Int } as Buffer<T>
typeOf<Long>() -> MutableBuffer.long(size) { initializer(it) as Long } as Buffer<T>
typeOf<Float>() -> MutableBuffer.float(size) { initializer(it) as Float } as Buffer<T>
else -> List(size, initializer).asBuffer(type)
}
}
/** /**
* Returns an [IntRange] of the valid indices for this [Buffer]. * Returns an [IntRange] of the valid indices for this [Buffer].
*/ */

View File

@ -0,0 +1,81 @@
/*
* Copyright 2018-2023 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.structures
import kotlin.jvm.JvmInline
@JvmInline
private value class BufferList<T>(val buffer: Buffer<T>) : List<T> {
override val size: Int get() = buffer.size
override fun get(index: Int): T = buffer[index]
override fun isEmpty(): Boolean = buffer.size == 0
override fun iterator(): Iterator<T> = buffer.iterator()
override fun listIterator(index: Int): ListIterator<T> = object : ListIterator<T> {
var currentIndex = index
override fun hasNext(): Boolean = currentIndex < buffer.size - 1
override fun hasPrevious(): Boolean = currentIndex > 0
override fun next(): T {
if (!hasNext()) throw NoSuchElementException()
return get(currentIndex++)
}
override fun nextIndex(): Int = currentIndex
override fun previous(): T {
if (!hasPrevious()) throw NoSuchElementException()
return get(--currentIndex)
}
override fun previousIndex(): Int = currentIndex - 1
}
override fun listIterator(): ListIterator<T> = listIterator(0)
override fun subList(fromIndex: Int, toIndex: Int): List<T> =
buffer.slice(fromIndex..toIndex).asList()
override fun lastIndexOf(element: T): Int {
for (i in buffer.indices.reversed()) {
if (buffer[i] == element) return i
}
return -1
}
override fun indexOf(element: T): Int {
for (i in buffer.indices) {
if (buffer[i] == element) return i
}
return -1
}
override fun containsAll(elements: Collection<T>): Boolean {
val remainingElements = HashSet(elements)
for (e in buffer) {
if (e in remainingElements) {
remainingElements.remove(e)
}
if (remainingElements.isEmpty()) {
return true
}
}
return false
}
override fun contains(element: T): Boolean = indexOf(element) >= 0
}
/**
* Returns a zero-copy list that reflects the content of the buffer.
*/
public fun <T> Buffer<T>.asList(): List<T> = BufferList(this)

View File

@ -24,8 +24,6 @@ import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.linear.* import space.kscience.kmath.linear.*
import space.kscience.kmath.linear.Matrix import space.kscience.kmath.linear.Matrix
import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.nd.StructureAttribute
import space.kscience.kmath.nd.StructureFeature import space.kscience.kmath.nd.StructureFeature
import space.kscience.kmath.operations.Float32Field import space.kscience.kmath.operations.Float32Field
import space.kscience.kmath.operations.Float64Field import space.kscience.kmath.operations.Float64Field
@ -80,6 +78,7 @@ public class EjmlFloatMatrix<out M : FMatrix>(override val origin: M) : EjmlMatr
} }
/** /**
* [EjmlLinearSpace] implementation based on [CommonOps_DDRM], [DecompositionFactory_DDRM] operations and * [EjmlLinearSpace] implementation based on [CommonOps_DDRM], [DecompositionFactory_DDRM] operations and
* [DMatrixRMaj] matrices. * [DMatrixRMaj] matrices.
@ -218,18 +217,6 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, Float64Field, DMatri
override fun Double.times(v: Point<Double>): EjmlDoubleVector<DMatrixRMaj> = v * this override fun Double.times(v: Point<Double>): EjmlDoubleVector<DMatrixRMaj> = v * this
override fun <V, A : StructureAttribute<V>> computeAttribute(structure: Structure2D<Double>, attribute: A): V? {
val origin = structure.toEjml().origin
return when(attribute){
Inverted -> {
val res = origin.copy()
CommonOps_DDRM.invert(res)
res.wrapMatrix()
}
else->
}
}
@UnstableKMathAPI @UnstableKMathAPI
override fun <F : StructureFeature> computeFeature(structure: Matrix<Double>, type: KClass<out F>): F? { override fun <F : StructureFeature> computeFeature(structure: Matrix<Double>, type: KClass<out F>): F? {
structure.getFeature(type)?.let { return it } structure.getFeature(type)?.let { return it }
@ -330,7 +317,7 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, Float64Field, DMatri
} }
} }
import org.checkerframework.checker.guieffect.qual.SafeType
/** /**
* [EjmlLinearSpace] implementation based on [CommonOps_FDRM], [DecompositionFactory_FDRM] operations and * [EjmlLinearSpace] implementation based on [CommonOps_FDRM], [DecompositionFactory_FDRM] operations and
@ -570,7 +557,7 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace<Float, Float32Field, FMatrix
} }
} }
import org.checkerframework.checker.guieffect.qual.SafeType
/** /**
* [EjmlLinearSpace] implementation based on [CommonOps_DSCC], [DecompositionFactory_DSCC] operations and * [EjmlLinearSpace] implementation based on [CommonOps_DSCC], [DecompositionFactory_DSCC] operations and
@ -805,7 +792,7 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace<Double, Float64Field, DMatri
} }
} }
import org.checkerframework.checker.guieffect.qual.SafeType
/** /**
* [EjmlLinearSpace] implementation based on [CommonOps_FSCC], [DecompositionFactory_FSCC] operations and * [EjmlLinearSpace] implementation based on [CommonOps_FSCC], [DecompositionFactory_FSCC] operations and