Name refactoring for tensors
This commit is contained in:
parent
40c02f4bd7
commit
69e6849a12
@ -12,7 +12,7 @@ import space.kscience.kmath.operations.*
|
|||||||
import space.kscience.kmath.structures.BufferFactory
|
import space.kscience.kmath.structures.BufferFactory
|
||||||
|
|
||||||
public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
|
public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
|
||||||
public val indexerBuilder: (IntArray) -> ShapeIndex
|
public val indexerBuilder: (IntArray) -> ShapeIndexer
|
||||||
public val bufferAlgebra: BufferAlgebra<T, A>
|
public val bufferAlgebra: BufferAlgebra<T, A>
|
||||||
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
|
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
|
||||||
|
|
||||||
@ -43,7 +43,7 @@ public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
|
|||||||
zipInline(left.toBufferND(), right.toBufferND(), transform)
|
zipInline(left.toBufferND(), right.toBufferND(), transform)
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public val defaultIndexerBuilder: (IntArray) -> ShapeIndex = DefaultStrides.Companion::invoke
|
public val defaultIndexerBuilder: (IntArray) -> ShapeIndexer = DefaultStrides.Companion::invoke
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,25 +80,25 @@ internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.zipInline(
|
|||||||
|
|
||||||
public open class BufferedGroupNDOps<T, out A : Group<T>>(
|
public open class BufferedGroupNDOps<T, out A : Group<T>>(
|
||||||
override val bufferAlgebra: BufferAlgebra<T, A>,
|
override val bufferAlgebra: BufferAlgebra<T, A>,
|
||||||
override val indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
override val indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
|
||||||
) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
|
) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
|
||||||
override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
|
override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
|
||||||
}
|
}
|
||||||
|
|
||||||
public open class BufferedRingOpsND<T, out A : Ring<T>>(
|
public open class BufferedRingOpsND<T, out A : Ring<T>>(
|
||||||
bufferAlgebra: BufferAlgebra<T, A>,
|
bufferAlgebra: BufferAlgebra<T, A>,
|
||||||
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
|
||||||
) : BufferedGroupNDOps<T, A>(bufferAlgebra, indexerBuilder), RingOpsND<T, A>
|
) : BufferedGroupNDOps<T, A>(bufferAlgebra, indexerBuilder), RingOpsND<T, A>
|
||||||
|
|
||||||
public open class BufferedFieldOpsND<T, out A : Field<T>>(
|
public open class BufferedFieldOpsND<T, out A : Field<T>>(
|
||||||
bufferAlgebra: BufferAlgebra<T, A>,
|
bufferAlgebra: BufferAlgebra<T, A>,
|
||||||
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
|
||||||
) : BufferedRingOpsND<T, A>(bufferAlgebra, indexerBuilder), FieldOpsND<T, A> {
|
) : BufferedRingOpsND<T, A>(bufferAlgebra, indexerBuilder), FieldOpsND<T, A> {
|
||||||
|
|
||||||
public constructor(
|
public constructor(
|
||||||
elementAlgebra: A,
|
elementAlgebra: A,
|
||||||
bufferFactory: BufferFactory<T>,
|
bufferFactory: BufferFactory<T>,
|
||||||
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
|
||||||
) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder)
|
) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder)
|
||||||
|
|
||||||
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value }
|
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value }
|
||||||
|
@ -19,7 +19,7 @@ import space.kscience.kmath.structures.MutableBufferFactory
|
|||||||
* @param buffer The underlying buffer.
|
* @param buffer The underlying buffer.
|
||||||
*/
|
*/
|
||||||
public open class BufferND<out T>(
|
public open class BufferND<out T>(
|
||||||
public val indexes: ShapeIndex,
|
public val indexes: ShapeIndexer,
|
||||||
public open val buffer: Buffer<T>,
|
public open val buffer: Buffer<T>,
|
||||||
) : StructureND<T> {
|
) : StructureND<T> {
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
|
|||||||
* @param buffer The underlying buffer.
|
* @param buffer The underlying buffer.
|
||||||
*/
|
*/
|
||||||
public class MutableBufferND<T>(
|
public class MutableBufferND<T>(
|
||||||
strides: ShapeIndex,
|
strides: ShapeIndexer,
|
||||||
override val buffer: MutableBuffer<T>,
|
override val buffer: MutableBuffer<T>,
|
||||||
) : MutableStructureND<T>, BufferND<T>(strides, buffer) {
|
) : MutableStructureND<T>, BufferND<T>(strides, buffer) {
|
||||||
override fun set(index: IntArray, value: T) {
|
override fun set(index: IntArray, value: T) {
|
||||||
|
@ -13,7 +13,7 @@ import kotlin.contracts.contract
|
|||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
public class DoubleBufferND(
|
public class DoubleBufferND(
|
||||||
indexes: ShapeIndex,
|
indexes: ShapeIndexer,
|
||||||
override val buffer: DoubleBuffer,
|
override val buffer: DoubleBuffer,
|
||||||
) : BufferND<Double>(indexes, buffer)
|
) : BufferND<Double>(indexes, buffer)
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ import kotlin.native.concurrent.ThreadLocal
|
|||||||
/**
|
/**
|
||||||
* A converter from linear index to multivariate index
|
* A converter from linear index to multivariate index
|
||||||
*/
|
*/
|
||||||
public interface ShapeIndex{
|
public interface ShapeIndexer{
|
||||||
public val shape: Shape
|
public val shape: Shape
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -42,7 +42,7 @@ public interface ShapeIndex{
|
|||||||
/**
|
/**
|
||||||
* Linear transformation of indexes
|
* Linear transformation of indexes
|
||||||
*/
|
*/
|
||||||
public abstract class Strides: ShapeIndex {
|
public abstract class Strides: ShapeIndexer {
|
||||||
/**
|
/**
|
||||||
* Array strides
|
* Array strides
|
||||||
*/
|
*/
|
@ -100,7 +100,7 @@ public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
|
|||||||
*/
|
*/
|
||||||
@JvmInline
|
@JvmInline
|
||||||
private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : Structure2D<T> {
|
private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : Structure2D<T> {
|
||||||
override val shape: IntArray get() = structure.shape
|
override val shape: Shape get() = structure.shape
|
||||||
|
|
||||||
override val rowNum: Int get() = shape[0]
|
override val rowNum: Int get() = shape[0]
|
||||||
override val colNum: Int get() = shape[1]
|
override val colNum: Int get() = shape[1]
|
||||||
@ -116,9 +116,8 @@ private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : S
|
|||||||
/**
|
/**
|
||||||
* A 2D wrapper for a mutable nd-structure
|
* A 2D wrapper for a mutable nd-structure
|
||||||
*/
|
*/
|
||||||
private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>): MutableStructure2D<T>
|
private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure2D<T> {
|
||||||
{
|
override val shape: Shape get() = structure.shape
|
||||||
override val shape: IntArray get() = structure.shape
|
|
||||||
|
|
||||||
override val rowNum: Int get() = shape[0]
|
override val rowNum: Int get() = shape[0]
|
||||||
override val colNum: Int get() = shape[1]
|
override val colNum: Int get() = shape[1]
|
||||||
@ -152,7 +151,8 @@ public fun <T> StructureND<T>.as2D(): Structure2D<T> = this as? Structure2D<T> ?
|
|||||||
/**
|
/**
|
||||||
* Represents a [StructureND] as [Structure2D]. Throws runtime error in case of dimension mismatch.
|
* Represents a [StructureND] as [Structure2D]. Throws runtime error in case of dimension mismatch.
|
||||||
*/
|
*/
|
||||||
public fun <T> MutableStructureND<T>.as2D(): MutableStructure2D<T> = this as? MutableStructure2D<T> ?: when (shape.size) {
|
public fun <T> MutableStructureND<T>.as2D(): MutableStructure2D<T> =
|
||||||
|
this as? MutableStructure2D<T> ?: when (shape.size) {
|
||||||
2 -> MutableStructure2DWrapper(this)
|
2 -> MutableStructure2DWrapper(this)
|
||||||
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
|
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
|
|||||||
* The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions of
|
* The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions of
|
||||||
* this structure.
|
* this structure.
|
||||||
*/
|
*/
|
||||||
public val shape: IntArray
|
public val shape: Shape
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The count of dimensions in this structure. It should be equal to size of [shape].
|
* The count of dimensions in this structure. It should be equal to size of [shape].
|
||||||
|
@ -15,6 +15,7 @@ import org.jetbrains.kotlinx.multik.api.zeros
|
|||||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
|
import space.kscience.kmath.nd.Shape
|
||||||
import space.kscience.kmath.nd.mapInPlace
|
import space.kscience.kmath.nd.mapInPlace
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.tensors.api.Tensor
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
@ -22,7 +23,7 @@ import space.kscience.kmath.tensors.api.TensorAlgebra
|
|||||||
|
|
||||||
@JvmInline
|
@JvmInline
|
||||||
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
|
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
|
||||||
override val shape: IntArray get() = array.shape
|
override val shape: Shape get() = array.shape
|
||||||
|
|
||||||
override fun get(index: IntArray): T = array[index]
|
override fun get(index: IntArray): T = array[index]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user