forked from kscience/kmath
Name refactoring for tensors
This commit is contained in:
parent
40c02f4bd7
commit
69e6849a12
@ -12,7 +12,7 @@ import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.structures.BufferFactory
|
||||
|
||||
public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
|
||||
public val indexerBuilder: (IntArray) -> ShapeIndex
|
||||
public val indexerBuilder: (IntArray) -> ShapeIndexer
|
||||
public val bufferAlgebra: BufferAlgebra<T, A>
|
||||
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
|
||||
|
||||
@ -43,7 +43,7 @@ public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
|
||||
zipInline(left.toBufferND(), right.toBufferND(), transform)
|
||||
|
||||
public companion object {
|
||||
public val defaultIndexerBuilder: (IntArray) -> ShapeIndex = DefaultStrides.Companion::invoke
|
||||
public val defaultIndexerBuilder: (IntArray) -> ShapeIndexer = DefaultStrides.Companion::invoke
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,25 +80,25 @@ internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.zipInline(
|
||||
|
||||
public open class BufferedGroupNDOps<T, out A : Group<T>>(
|
||||
override val bufferAlgebra: BufferAlgebra<T, A>,
|
||||
override val indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
||||
override val indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
|
||||
) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
|
||||
override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
|
||||
}
|
||||
|
||||
public open class BufferedRingOpsND<T, out A : Ring<T>>(
|
||||
bufferAlgebra: BufferAlgebra<T, A>,
|
||||
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
||||
indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
|
||||
) : BufferedGroupNDOps<T, A>(bufferAlgebra, indexerBuilder), RingOpsND<T, A>
|
||||
|
||||
public open class BufferedFieldOpsND<T, out A : Field<T>>(
|
||||
bufferAlgebra: BufferAlgebra<T, A>,
|
||||
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
||||
indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
|
||||
) : BufferedRingOpsND<T, A>(bufferAlgebra, indexerBuilder), FieldOpsND<T, A> {
|
||||
|
||||
public constructor(
|
||||
elementAlgebra: A,
|
||||
bufferFactory: BufferFactory<T>,
|
||||
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
||||
indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
|
||||
) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder)
|
||||
|
||||
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value }
|
||||
|
@ -19,7 +19,7 @@ import space.kscience.kmath.structures.MutableBufferFactory
|
||||
* @param buffer The underlying buffer.
|
||||
*/
|
||||
public open class BufferND<out T>(
|
||||
public val indexes: ShapeIndex,
|
||||
public val indexes: ShapeIndexer,
|
||||
public open val buffer: Buffer<T>,
|
||||
) : StructureND<T> {
|
||||
|
||||
@ -58,7 +58,7 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
|
||||
* @param buffer The underlying buffer.
|
||||
*/
|
||||
public class MutableBufferND<T>(
|
||||
strides: ShapeIndex,
|
||||
strides: ShapeIndexer,
|
||||
override val buffer: MutableBuffer<T>,
|
||||
) : MutableStructureND<T>, BufferND<T>(strides, buffer) {
|
||||
override fun set(index: IntArray, value: T) {
|
||||
|
@ -13,7 +13,7 @@ import kotlin.contracts.contract
|
||||
import kotlin.math.pow
|
||||
|
||||
public class DoubleBufferND(
|
||||
indexes: ShapeIndex,
|
||||
indexes: ShapeIndexer,
|
||||
override val buffer: DoubleBuffer,
|
||||
) : BufferND<Double>(indexes, buffer)
|
||||
|
||||
|
@ -10,7 +10,7 @@ import kotlin.native.concurrent.ThreadLocal
|
||||
/**
|
||||
* A converter from linear index to multivariate index
|
||||
*/
|
||||
public interface ShapeIndex{
|
||||
public interface ShapeIndexer{
|
||||
public val shape: Shape
|
||||
|
||||
/**
|
||||
@ -42,7 +42,7 @@ public interface ShapeIndex{
|
||||
/**
|
||||
* Linear transformation of indexes
|
||||
*/
|
||||
public abstract class Strides: ShapeIndex {
|
||||
public abstract class Strides: ShapeIndexer {
|
||||
/**
|
||||
* Array strides
|
||||
*/
|
@ -85,7 +85,7 @@ public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
|
||||
*/
|
||||
@PerformancePitfall
|
||||
override val rows: List<MutableStructure1D<T>>
|
||||
get() = List(rowNum) { i -> MutableBuffer1DWrapper(MutableListBuffer(colNum) { j -> get(i, j) })}
|
||||
get() = List(rowNum) { i -> MutableBuffer1DWrapper(MutableListBuffer(colNum) { j -> get(i, j) }) }
|
||||
|
||||
/**
|
||||
* The buffer of columns of this structure. It gets elements from the structure dynamically.
|
||||
@ -100,7 +100,7 @@ public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
|
||||
*/
|
||||
@JvmInline
|
||||
private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : Structure2D<T> {
|
||||
override val shape: IntArray get() = structure.shape
|
||||
override val shape: Shape get() = structure.shape
|
||||
|
||||
override val rowNum: Int get() = shape[0]
|
||||
override val colNum: Int get() = shape[1]
|
||||
@ -116,9 +116,8 @@ private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : S
|
||||
/**
|
||||
* A 2D wrapper for a mutable nd-structure
|
||||
*/
|
||||
private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>): MutableStructure2D<T>
|
||||
{
|
||||
override val shape: IntArray get() = structure.shape
|
||||
private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure2D<T> {
|
||||
override val shape: Shape get() = structure.shape
|
||||
|
||||
override val rowNum: Int get() = shape[0]
|
||||
override val colNum: Int get() = shape[1]
|
||||
@ -129,7 +128,7 @@ private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>)
|
||||
structure[index] = value
|
||||
}
|
||||
|
||||
override operator fun set(i: Int, j: Int, value: T){
|
||||
override operator fun set(i: Int, j: Int, value: T) {
|
||||
structure[intArrayOf(i, j)] = value
|
||||
}
|
||||
|
||||
@ -152,10 +151,11 @@ public fun <T> StructureND<T>.as2D(): Structure2D<T> = this as? Structure2D<T> ?
|
||||
/**
|
||||
* Represents a [StructureND] as [Structure2D]. Throws runtime error in case of dimension mismatch.
|
||||
*/
|
||||
public fun <T> MutableStructureND<T>.as2D(): MutableStructure2D<T> = this as? MutableStructure2D<T> ?: when (shape.size) {
|
||||
2 -> MutableStructure2DWrapper(this)
|
||||
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
public fun <T> MutableStructureND<T>.as2D(): MutableStructure2D<T> =
|
||||
this as? MutableStructure2D<T> ?: when (shape.size) {
|
||||
2 -> MutableStructure2DWrapper(this)
|
||||
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
|
||||
/**
|
||||
* Expose inner [StructureND] if possible
|
||||
|
@ -33,7 +33,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
|
||||
* The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions of
|
||||
* this structure.
|
||||
*/
|
||||
public val shape: IntArray
|
||||
public val shape: Shape
|
||||
|
||||
/**
|
||||
* The count of dimensions in this structure. It should be equal to size of [shape].
|
||||
|
@ -15,6 +15,7 @@ import org.jetbrains.kotlinx.multik.api.zeros
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.nd.Shape
|
||||
import space.kscience.kmath.nd.mapInPlace
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
@ -22,7 +23,7 @@ import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||
|
||||
@JvmInline
|
||||
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
|
||||
override val shape: IntArray get() = array.shape
|
||||
override val shape: Shape get() = array.shape
|
||||
|
||||
override fun get(index: IntArray): T = array[index]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user