forked from kscience/kmath
MutableStructure 2D & 1D
This commit is contained in:
parent
b227a82a80
commit
f8e0d4be17
@ -1,6 +1,8 @@
|
||||
package space.kscience.kmath.nd
|
||||
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.MutableBuffer
|
||||
import space.kscience.kmath.structures.asMutableBuffer
|
||||
import space.kscience.kmath.structures.asSequence
|
||||
|
||||
/**
|
||||
@ -17,6 +19,16 @@ public interface Structure1D<T> : NDStructure<T>, Buffer<T> {
|
||||
public override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(::get).iterator()
|
||||
}
|
||||
|
||||
/**
|
||||
* A mutable structure that is guaranteed to be one-dimensional
|
||||
*/
|
||||
public interface MutableStructure1D<T> : Structure1D<T>, MutableNDStructure<T>, MutableBuffer<T> {
|
||||
public override operator fun set(index: IntArray, value: T) {
|
||||
require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" }
|
||||
set(index[0], value)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A 1D wrapper for nd-structure
|
||||
*/
|
||||
@ -28,6 +40,25 @@ private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Stru
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
}
|
||||
|
||||
/**
|
||||
* A 1D wrapper for a mutable nd-structure
|
||||
*/
|
||||
private inline class MutableStructure1DWrapper<T>(val structure: MutableNDStructure<T>) : MutableStructure1D<T> {
|
||||
override val shape: IntArray get() = structure.shape
|
||||
override val size: Int get() = structure.shape[0]
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun get(index: Int): T = structure[index]
|
||||
override fun set(index: Int, value: T) {
|
||||
set(index, value)
|
||||
}
|
||||
|
||||
override fun copy(): MutableBuffer<T> =
|
||||
structure.elements().map { it.second }.toMutableList().asMutableBuffer()
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A structure wrapper for buffer
|
||||
@ -42,6 +73,21 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
|
||||
override operator fun get(index: Int): T = buffer[index]
|
||||
}
|
||||
|
||||
private inline class MutableBuffer1DWrapper<T>(val buffer: MutableBuffer<T>) : MutableStructure1D<T> {
|
||||
override val shape: IntArray get() = intArrayOf(buffer.size)
|
||||
override val size: Int get() = buffer.size
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||
buffer.asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
|
||||
|
||||
override operator fun get(index: Int): T = buffer[index]
|
||||
override fun set(index: Int, value: T) {
|
||||
buffer[index] = value
|
||||
}
|
||||
|
||||
override fun copy(): MutableBuffer<T> = buffer.copy()
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||
*/
|
||||
@ -52,7 +98,17 @@ public fun <T> NDStructure<T>.as1D(): Structure1D<T> = this as? Structure1D<T> ?
|
||||
}
|
||||
} else error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||
|
||||
public fun <T> MutableNDStructure<T>.as1D(): MutableStructure1D<T> =
|
||||
this as? MutableStructure1D<T> ?: if (shape.size == 1) {
|
||||
when (this) {
|
||||
is MutableNDBuffer -> MutableBuffer1DWrapper(this.buffer)
|
||||
else -> MutableStructure1DWrapper(this)
|
||||
}
|
||||
} else error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||
|
||||
/**
|
||||
* Represent this buffer as 1D structure
|
||||
*/
|
||||
public fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
|
||||
|
||||
public fun <T> MutableBuffer<T>.asND(): MutableStructure1D<T> = MutableBuffer1DWrapper(this)
|
||||
|
@ -57,6 +57,20 @@ public interface Structure2D<T> : NDStructure<T> {
|
||||
public companion object
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents mutable [Structure2D].
|
||||
*/
|
||||
public interface MutableStructure2D<T> : Structure2D<T>, MutableNDStructure<T> {
|
||||
/**
|
||||
* Inserts an item at the specified indices.
|
||||
*
|
||||
* @param i the first index.
|
||||
* @param j the second index.
|
||||
* @param value the value.
|
||||
*/
|
||||
public operator fun set(i: Int, j: Int, value: T)
|
||||
}
|
||||
|
||||
/**
|
||||
* A 2D wrapper for nd-structure
|
||||
*/
|
||||
@ -79,11 +93,46 @@ private class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||
* A 2D wrapper for a mutable nd-structure
|
||||
*/
|
||||
private class MutableStructure2DWrapper<T>(val structure: MutableNDStructure<T>): MutableStructure2D<T>
|
||||
{
|
||||
override val shape: IntArray get() = structure.shape
|
||||
|
||||
override val rowNum: Int get() = shape[0]
|
||||
override val colNum: Int get() = shape[1]
|
||||
|
||||
override operator fun get(i: Int, j: Int): T = structure[i, j]
|
||||
|
||||
override fun set(index: IntArray, value: T) {
|
||||
structure[index] = value
|
||||
}
|
||||
|
||||
override operator fun set(i: Int, j: Int, value: T){
|
||||
structure[intArrayOf(i, j)] = value
|
||||
}
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
|
||||
override fun equals(other: Any?): Boolean = false
|
||||
|
||||
override fun hashCode(): Int = 0
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a [NDStructure] as [Structure2D]. Throw error in case of dimension mismatch
|
||||
*/
|
||||
public fun <T> NDStructure<T>.as2D(): Structure2D<T> = this as? Structure2D<T> ?: when (shape.size) {
|
||||
2 -> Structure2DWrapper(this)
|
||||
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
|
||||
internal fun <T> Structure2D<T>.unwrap(): NDStructure<T> = if (this is Structure2DWrapper) structure else this
|
||||
internal fun <T> Structure2D<T>.unwrap(): NDStructure<T> = if (this is Structure2DWrapper) structure else this
|
||||
|
||||
public fun <T> MutableNDStructure<T>.as2D(): MutableStructure2D<T> = this as? MutableStructure2D<T> ?: when (shape.size) {
|
||||
2 -> MutableStructure2DWrapper(this)
|
||||
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
|
||||
internal fun <T> MutableStructure2D<T>.unwrap(): MutableNDStructure<T> =
|
||||
if (this is MutableStructure2DWrapper) structure else this
|
||||
|
@ -236,6 +236,11 @@ public inline class MutableListBuffer<T>(public val list: MutableList<T>) : Muta
|
||||
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an [MutableListBuffer] that wraps the original list.
|
||||
*/
|
||||
public fun <T> MutableList<T>.asMutableBuffer(): MutableListBuffer<T> = MutableListBuffer(this)
|
||||
|
||||
/**
|
||||
* [MutableBuffer] implementation over [Array].
|
||||
*
|
||||
|
@ -25,7 +25,6 @@ public open class BufferedTensor<T>(
|
||||
this[intArrayOf(i, j)] = value
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
//todo make generator mb nextMatrixIndex?
|
||||
|
@ -10,6 +10,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
return this.buffer.unsafeToDoubleArray()[0]
|
||||
}
|
||||
|
||||
override fun DoubleTensor.get(i: Int): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun zeros(shape: IntArray): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
@ -41,6 +41,8 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
public operator fun TensorType.timesAssign(other: TensorType): Unit
|
||||
public operator fun TensorType.unaryMinus(): TensorType
|
||||
|
||||
//https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
||||
public fun TensorType.get(i: Int): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.transpose.html
|
||||
public fun TensorType.transpose(i: Int, j: Int): TensorType
|
||||
|
@ -3,4 +3,3 @@ package space.kscience.kmath.tensors
|
||||
import space.kscience.kmath.nd.MutableNDStructure
|
||||
|
||||
public typealias TensorStructure<T> = MutableNDStructure<T>
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user