MutableStructure 2D & 1D

This commit is contained in:
Roland Grinis 2021-03-15 21:18:15 +00:00
parent b227a82a80
commit f8e0d4be17
7 changed files with 118 additions and 4 deletions

View File

@ -1,6 +1,8 @@
package space.kscience.kmath.nd
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.asMutableBuffer
import space.kscience.kmath.structures.asSequence
/**
@ -17,6 +19,16 @@ public interface Structure1D<T> : NDStructure<T>, Buffer<T> {
public override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(::get).iterator()
}
/**
* A mutable structure that is guaranteed to be one-dimensional
*/
public interface MutableStructure1D<T> : Structure1D<T>, MutableNDStructure<T>, MutableBuffer<T> {
public override operator fun set(index: IntArray, value: T) {
require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" }
set(index[0], value)
}
}
/**
* A 1D wrapper for nd-structure
*/
@ -28,6 +40,25 @@ private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Stru
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
}
/**
* A 1D wrapper for a mutable nd-structure
*/
private inline class MutableStructure1DWrapper<T>(val structure: MutableNDStructure<T>) : MutableStructure1D<T> {
override val shape: IntArray get() = structure.shape
override val size: Int get() = structure.shape[0]
override fun elements(): Sequence<Pair<IntArray, T>> {
TODO("Not yet implemented")
}
override fun get(index: Int): T = structure[index]
override fun set(index: Int, value: T) {
set(index, value)
}
override fun copy(): MutableBuffer<T> =
structure.elements().map { it.second }.toMutableList().asMutableBuffer()
}
/**
* A structure wrapper for buffer
@ -42,6 +73,21 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
override operator fun get(index: Int): T = buffer[index]
}
private inline class MutableBuffer1DWrapper<T>(val buffer: MutableBuffer<T>) : MutableStructure1D<T> {
override val shape: IntArray get() = intArrayOf(buffer.size)
override val size: Int get() = buffer.size
override fun elements(): Sequence<Pair<IntArray, T>> =
buffer.asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
override operator fun get(index: Int): T = buffer[index]
override fun set(index: Int, value: T) {
buffer[index] = value
}
override fun copy(): MutableBuffer<T> = buffer.copy()
}
/**
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
*/
@ -52,7 +98,17 @@ public fun <T> NDStructure<T>.as1D(): Structure1D<T> = this as? Structure1D<T> ?
}
} else error("Can't create 1d-structure from ${shape.size}d-structure")
public fun <T> MutableNDStructure<T>.as1D(): MutableStructure1D<T> =
this as? MutableStructure1D<T> ?: if (shape.size == 1) {
when (this) {
is MutableNDBuffer -> MutableBuffer1DWrapper(this.buffer)
else -> MutableStructure1DWrapper(this)
}
} else error("Can't create 1d-structure from ${shape.size}d-structure")
/**
* Represent this buffer as 1D structure
*/
public fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
public fun <T> MutableBuffer<T>.asND(): MutableStructure1D<T> = MutableBuffer1DWrapper(this)

View File

@ -57,6 +57,20 @@ public interface Structure2D<T> : NDStructure<T> {
public companion object
}
/**
* Represents mutable [Structure2D].
*/
public interface MutableStructure2D<T> : Structure2D<T>, MutableNDStructure<T> {
/**
* Inserts an item at the specified indices.
*
* @param i the first index.
* @param j the second index.
* @param value the value.
*/
public operator fun set(i: Int, j: Int, value: T)
}
/**
* A 2D wrapper for nd-structure
*/
@ -79,11 +93,46 @@ private class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D
}
/**
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
* A 2D wrapper for a mutable nd-structure
*/
private class MutableStructure2DWrapper<T>(val structure: MutableNDStructure<T>): MutableStructure2D<T>
{
override val shape: IntArray get() = structure.shape
override val rowNum: Int get() = shape[0]
override val colNum: Int get() = shape[1]
override operator fun get(i: Int, j: Int): T = structure[i, j]
override fun set(index: IntArray, value: T) {
structure[index] = value
}
override operator fun set(i: Int, j: Int, value: T){
structure[intArrayOf(i, j)] = value
}
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
override fun equals(other: Any?): Boolean = false
override fun hashCode(): Int = 0
}
/**
* Represent a [NDStructure] as [Structure2D]. Throw error in case of dimension mismatch
*/
public fun <T> NDStructure<T>.as2D(): Structure2D<T> = this as? Structure2D<T> ?: when (shape.size) {
2 -> Structure2DWrapper(this)
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
}
internal fun <T> Structure2D<T>.unwrap(): NDStructure<T> = if (this is Structure2DWrapper) structure else this
internal fun <T> Structure2D<T>.unwrap(): NDStructure<T> = if (this is Structure2DWrapper) structure else this
public fun <T> MutableNDStructure<T>.as2D(): MutableStructure2D<T> = this as? MutableStructure2D<T> ?: when (shape.size) {
2 -> MutableStructure2DWrapper(this)
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
}
internal fun <T> MutableStructure2D<T>.unwrap(): MutableNDStructure<T> =
if (this is MutableStructure2DWrapper) structure else this

View File

@ -236,6 +236,11 @@ public inline class MutableListBuffer<T>(public val list: MutableList<T>) : Muta
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
}
/**
* Returns an [MutableListBuffer] that wraps the original list.
*/
public fun <T> MutableList<T>.asMutableBuffer(): MutableListBuffer<T> = MutableListBuffer(this)
/**
* [MutableBuffer] implementation over [Array].
*

View File

@ -25,7 +25,6 @@ public open class BufferedTensor<T>(
this[intArrayOf(i, j)] = value
}
}
//todo make generator mb nextMatrixIndex?

View File

@ -10,6 +10,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
return this.buffer.unsafeToDoubleArray()[0]
}
override fun DoubleTensor.get(i: Int): DoubleTensor {
TODO("Not yet implemented")
}
override fun zeros(shape: IntArray): DoubleTensor {
TODO("Not yet implemented")
}

View File

@ -41,6 +41,8 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public operator fun TensorType.timesAssign(other: TensorType): Unit
public operator fun TensorType.unaryMinus(): TensorType
//https://pytorch.org/cppdocs/notes/tensor_indexing.html
public fun TensorType.get(i: Int): TensorType
//https://pytorch.org/docs/stable/generated/torch.transpose.html
public fun TensorType.transpose(i: Int, j: Int): TensorType

View File

@ -3,4 +3,3 @@ package space.kscience.kmath.tensors
import space.kscience.kmath.nd.MutableNDStructure
public typealias TensorStructure<T> = MutableNDStructure<T>