Name refactoring for tensors

This commit is contained in:
Alexander Nozik 2021-10-20 16:11:36 +03:00
parent 40c02f4bd7
commit 69e6849a12
7 changed files with 24 additions and 23 deletions

View File

@ -12,7 +12,7 @@ import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.BufferFactory
public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> { public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
public val indexerBuilder: (IntArray) -> ShapeIndex public val indexerBuilder: (IntArray) -> ShapeIndexer
public val bufferAlgebra: BufferAlgebra<T, A> public val bufferAlgebra: BufferAlgebra<T, A>
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
@ -43,7 +43,7 @@ public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
zipInline(left.toBufferND(), right.toBufferND(), transform) zipInline(left.toBufferND(), right.toBufferND(), transform)
public companion object { public companion object {
public val defaultIndexerBuilder: (IntArray) -> ShapeIndex = DefaultStrides.Companion::invoke public val defaultIndexerBuilder: (IntArray) -> ShapeIndexer = DefaultStrides.Companion::invoke
} }
} }
@ -80,25 +80,25 @@ internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.zipInline(
public open class BufferedGroupNDOps<T, out A : Group<T>>( public open class BufferedGroupNDOps<T, out A : Group<T>>(
override val bufferAlgebra: BufferAlgebra<T, A>, override val bufferAlgebra: BufferAlgebra<T, A>,
override val indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder override val indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
) : GroupOpsND<T, A>, BufferAlgebraND<T, A> { ) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it } override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
} }
public open class BufferedRingOpsND<T, out A : Ring<T>>( public open class BufferedRingOpsND<T, out A : Ring<T>>(
bufferAlgebra: BufferAlgebra<T, A>, bufferAlgebra: BufferAlgebra<T, A>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
) : BufferedGroupNDOps<T, A>(bufferAlgebra, indexerBuilder), RingOpsND<T, A> ) : BufferedGroupNDOps<T, A>(bufferAlgebra, indexerBuilder), RingOpsND<T, A>
public open class BufferedFieldOpsND<T, out A : Field<T>>( public open class BufferedFieldOpsND<T, out A : Field<T>>(
bufferAlgebra: BufferAlgebra<T, A>, bufferAlgebra: BufferAlgebra<T, A>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
) : BufferedRingOpsND<T, A>(bufferAlgebra, indexerBuilder), FieldOpsND<T, A> { ) : BufferedRingOpsND<T, A>(bufferAlgebra, indexerBuilder), FieldOpsND<T, A> {
public constructor( public constructor(
elementAlgebra: A, elementAlgebra: A,
bufferFactory: BufferFactory<T>, bufferFactory: BufferFactory<T>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder) ) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder)
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value } override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value }

View File

@ -19,7 +19,7 @@ import space.kscience.kmath.structures.MutableBufferFactory
* @param buffer The underlying buffer. * @param buffer The underlying buffer.
*/ */
public open class BufferND<out T>( public open class BufferND<out T>(
public val indexes: ShapeIndex, public val indexes: ShapeIndexer,
public open val buffer: Buffer<T>, public open val buffer: Buffer<T>,
) : StructureND<T> { ) : StructureND<T> {
@ -58,7 +58,7 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
* @param buffer The underlying buffer. * @param buffer The underlying buffer.
*/ */
public class MutableBufferND<T>( public class MutableBufferND<T>(
strides: ShapeIndex, strides: ShapeIndexer,
override val buffer: MutableBuffer<T>, override val buffer: MutableBuffer<T>,
) : MutableStructureND<T>, BufferND<T>(strides, buffer) { ) : MutableStructureND<T>, BufferND<T>(strides, buffer) {
override fun set(index: IntArray, value: T) { override fun set(index: IntArray, value: T) {

View File

@ -13,7 +13,7 @@ import kotlin.contracts.contract
import kotlin.math.pow import kotlin.math.pow
public class DoubleBufferND( public class DoubleBufferND(
indexes: ShapeIndex, indexes: ShapeIndexer,
override val buffer: DoubleBuffer, override val buffer: DoubleBuffer,
) : BufferND<Double>(indexes, buffer) ) : BufferND<Double>(indexes, buffer)

View File

@ -10,7 +10,7 @@ import kotlin.native.concurrent.ThreadLocal
/** /**
* A converter from linear index to multivariate index * A converter from linear index to multivariate index
*/ */
public interface ShapeIndex{ public interface ShapeIndexer{
public val shape: Shape public val shape: Shape
/** /**
@ -42,7 +42,7 @@ public interface ShapeIndex{
/** /**
* Linear transformation of indexes * Linear transformation of indexes
*/ */
public abstract class Strides: ShapeIndex { public abstract class Strides: ShapeIndexer {
/** /**
* Array strides * Array strides
*/ */

View File

@ -100,7 +100,7 @@ public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
*/ */
@JvmInline @JvmInline
private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : Structure2D<T> { private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : Structure2D<T> {
override val shape: IntArray get() = structure.shape override val shape: Shape get() = structure.shape
override val rowNum: Int get() = shape[0] override val rowNum: Int get() = shape[0]
override val colNum: Int get() = shape[1] override val colNum: Int get() = shape[1]
@ -116,9 +116,8 @@ private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : S
/** /**
* A 2D wrapper for a mutable nd-structure * A 2D wrapper for a mutable nd-structure
*/ */
private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>): MutableStructure2D<T> private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure2D<T> {
{ override val shape: Shape get() = structure.shape
override val shape: IntArray get() = structure.shape
override val rowNum: Int get() = shape[0] override val rowNum: Int get() = shape[0]
override val colNum: Int get() = shape[1] override val colNum: Int get() = shape[1]
@ -152,7 +151,8 @@ public fun <T> StructureND<T>.as2D(): Structure2D<T> = this as? Structure2D<T> ?
/** /**
* Represents a [StructureND] as [Structure2D]. Throws runtime error in case of dimension mismatch. * Represents a [StructureND] as [Structure2D]. Throws runtime error in case of dimension mismatch.
*/ */
public fun <T> MutableStructureND<T>.as2D(): MutableStructure2D<T> = this as? MutableStructure2D<T> ?: when (shape.size) { public fun <T> MutableStructureND<T>.as2D(): MutableStructure2D<T> =
this as? MutableStructure2D<T> ?: when (shape.size) {
2 -> MutableStructure2DWrapper(this) 2 -> MutableStructure2DWrapper(this)
else -> error("Can't create 2d-structure from ${shape.size}d-structure") else -> error("Can't create 2d-structure from ${shape.size}d-structure")
} }

View File

@ -33,7 +33,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
* The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions of * The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions of
* this structure. * this structure.
*/ */
public val shape: IntArray public val shape: Shape
/** /**
* The count of dimensions in this structure. It should be equal to size of [shape]. * The count of dimensions in this structure. It should be equal to size of [shape].

View File

@ -15,6 +15,7 @@ import org.jetbrains.kotlinx.multik.api.zeros
import org.jetbrains.kotlinx.multik.ndarray.data.* import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.* import org.jetbrains.kotlinx.multik.ndarray.operations.*
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.mapInPlace import space.kscience.kmath.nd.mapInPlace
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
@ -22,7 +23,7 @@ import space.kscience.kmath.tensors.api.TensorAlgebra
@JvmInline @JvmInline
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> { public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
override val shape: IntArray get() = array.shape override val shape: Shape get() = array.shape
override fun get(index: IntArray): T = array[index] override fun get(index: IntArray): T = array[index]