Name refactoring for tensors

This commit is contained in:
Alexander Nozik 2021-10-20 16:11:36 +03:00
parent 40c02f4bd7
commit 69e6849a12
7 changed files with 24 additions and 23 deletions

View File

@ -12,7 +12,7 @@ import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.BufferFactory
public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
public val indexerBuilder: (IntArray) -> ShapeIndex
public val indexerBuilder: (IntArray) -> ShapeIndexer
public val bufferAlgebra: BufferAlgebra<T, A>
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
@ -43,7 +43,7 @@ public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
zipInline(left.toBufferND(), right.toBufferND(), transform)
public companion object {
public val defaultIndexerBuilder: (IntArray) -> ShapeIndex = DefaultStrides.Companion::invoke
public val defaultIndexerBuilder: (IntArray) -> ShapeIndexer = DefaultStrides.Companion::invoke
}
}
@ -80,25 +80,25 @@ internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.zipInline(
public open class BufferedGroupNDOps<T, out A : Group<T>>(
override val bufferAlgebra: BufferAlgebra<T, A>,
override val indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
override val indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
}
public open class BufferedRingOpsND<T, out A : Ring<T>>(
bufferAlgebra: BufferAlgebra<T, A>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
) : BufferedGroupNDOps<T, A>(bufferAlgebra, indexerBuilder), RingOpsND<T, A>
public open class BufferedFieldOpsND<T, out A : Field<T>>(
bufferAlgebra: BufferAlgebra<T, A>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
) : BufferedRingOpsND<T, A>(bufferAlgebra, indexerBuilder), FieldOpsND<T, A> {
public constructor(
elementAlgebra: A,
bufferFactory: BufferFactory<T>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder)
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value }

View File

@ -19,7 +19,7 @@ import space.kscience.kmath.structures.MutableBufferFactory
* @param buffer The underlying buffer.
*/
public open class BufferND<out T>(
public val indexes: ShapeIndex,
public val indexes: ShapeIndexer,
public open val buffer: Buffer<T>,
) : StructureND<T> {
@ -58,7 +58,7 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
* @param buffer The underlying buffer.
*/
public class MutableBufferND<T>(
strides: ShapeIndex,
strides: ShapeIndexer,
override val buffer: MutableBuffer<T>,
) : MutableStructureND<T>, BufferND<T>(strides, buffer) {
override fun set(index: IntArray, value: T) {

View File

@ -13,7 +13,7 @@ import kotlin.contracts.contract
import kotlin.math.pow
public class DoubleBufferND(
indexes: ShapeIndex,
indexes: ShapeIndexer,
override val buffer: DoubleBuffer,
) : BufferND<Double>(indexes, buffer)

View File

@ -10,7 +10,7 @@ import kotlin.native.concurrent.ThreadLocal
/**
* A converter from linear index to multivariate index
*/
public interface ShapeIndex{
public interface ShapeIndexer{
public val shape: Shape
/**
@ -42,7 +42,7 @@ public interface ShapeIndex{
/**
* Linear transformation of indexes
*/
public abstract class Strides: ShapeIndex {
public abstract class Strides: ShapeIndexer {
/**
* Array strides
*/

View File

@ -100,7 +100,7 @@ public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
*/
@JvmInline
private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : Structure2D<T> {
override val shape: IntArray get() = structure.shape
override val shape: Shape get() = structure.shape
override val rowNum: Int get() = shape[0]
override val colNum: Int get() = shape[1]
@ -116,9 +116,8 @@ private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : S
/**
* A 2D wrapper for a mutable nd-structure
*/
private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>): MutableStructure2D<T>
{
override val shape: IntArray get() = structure.shape
private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure2D<T> {
override val shape: Shape get() = structure.shape
override val rowNum: Int get() = shape[0]
override val colNum: Int get() = shape[1]
@ -152,7 +151,8 @@ public fun <T> StructureND<T>.as2D(): Structure2D<T> = this as? Structure2D<T> ?
/**
* Represents a [StructureND] as [Structure2D]. Throws runtime error in case of dimension mismatch.
*/
public fun <T> MutableStructureND<T>.as2D(): MutableStructure2D<T> = this as? MutableStructure2D<T> ?: when (shape.size) {
public fun <T> MutableStructureND<T>.as2D(): MutableStructure2D<T> =
this as? MutableStructure2D<T> ?: when (shape.size) {
2 -> MutableStructure2DWrapper(this)
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
}

View File

@ -33,7 +33,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
* The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions of
* this structure.
*/
public val shape: IntArray
public val shape: Shape
/**
* The count of dimensions in this structure. It should be equal to size of [shape].

View File

@ -15,6 +15,7 @@ import org.jetbrains.kotlinx.multik.api.zeros
import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.*
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.mapInPlace
import space.kscience.kmath.operations.*
import space.kscience.kmath.tensors.api.Tensor
@ -22,7 +23,7 @@ import space.kscience.kmath.tensors.api.TensorAlgebra
@JvmInline
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
override val shape: IntArray get() = array.shape
override val shape: Shape get() = array.shape
override fun get(index: IntArray): T = array[index]