MutableStructure 2D & 1D
This commit is contained in:
parent
b227a82a80
commit
f8e0d4be17
@ -1,6 +1,8 @@
|
|||||||
package space.kscience.kmath.nd
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
|
import space.kscience.kmath.structures.asMutableBuffer
|
||||||
import space.kscience.kmath.structures.asSequence
|
import space.kscience.kmath.structures.asSequence
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -17,6 +19,16 @@ public interface Structure1D<T> : NDStructure<T>, Buffer<T> {
|
|||||||
public override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(::get).iterator()
|
public override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(::get).iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A mutable structure that is guaranteed to be one-dimensional
|
||||||
|
*/
|
||||||
|
public interface MutableStructure1D<T> : Structure1D<T>, MutableNDStructure<T>, MutableBuffer<T> {
|
||||||
|
public override operator fun set(index: IntArray, value: T) {
|
||||||
|
require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" }
|
||||||
|
set(index[0], value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A 1D wrapper for nd-structure
|
* A 1D wrapper for nd-structure
|
||||||
*/
|
*/
|
||||||
@ -28,6 +40,25 @@ private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Stru
|
|||||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A 1D wrapper for a mutable nd-structure
|
||||||
|
*/
|
||||||
|
private inline class MutableStructure1DWrapper<T>(val structure: MutableNDStructure<T>) : MutableStructure1D<T> {
|
||||||
|
override val shape: IntArray get() = structure.shape
|
||||||
|
override val size: Int get() = structure.shape[0]
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun get(index: Int): T = structure[index]
|
||||||
|
override fun set(index: Int, value: T) {
|
||||||
|
set(index, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun copy(): MutableBuffer<T> =
|
||||||
|
structure.elements().map { it.second }.toMutableList().asMutableBuffer()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A structure wrapper for buffer
|
* A structure wrapper for buffer
|
||||||
@ -42,6 +73,21 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
|
|||||||
override operator fun get(index: Int): T = buffer[index]
|
override operator fun get(index: Int): T = buffer[index]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private inline class MutableBuffer1DWrapper<T>(val buffer: MutableBuffer<T>) : MutableStructure1D<T> {
|
||||||
|
override val shape: IntArray get() = intArrayOf(buffer.size)
|
||||||
|
override val size: Int get() = buffer.size
|
||||||
|
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||||
|
buffer.asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
|
||||||
|
|
||||||
|
override operator fun get(index: Int): T = buffer[index]
|
||||||
|
override fun set(index: Int, value: T) {
|
||||||
|
buffer[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun copy(): MutableBuffer<T> = buffer.copy()
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||||
*/
|
*/
|
||||||
@ -52,7 +98,17 @@ public fun <T> NDStructure<T>.as1D(): Structure1D<T> = this as? Structure1D<T> ?
|
|||||||
}
|
}
|
||||||
} else error("Can't create 1d-structure from ${shape.size}d-structure")
|
} else error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||||
|
|
||||||
|
public fun <T> MutableNDStructure<T>.as1D(): MutableStructure1D<T> =
|
||||||
|
this as? MutableStructure1D<T> ?: if (shape.size == 1) {
|
||||||
|
when (this) {
|
||||||
|
is MutableNDBuffer -> MutableBuffer1DWrapper(this.buffer)
|
||||||
|
else -> MutableStructure1DWrapper(this)
|
||||||
|
}
|
||||||
|
} else error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represent this buffer as 1D structure
|
* Represent this buffer as 1D structure
|
||||||
*/
|
*/
|
||||||
public fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
|
public fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
|
||||||
|
|
||||||
|
public fun <T> MutableBuffer<T>.asND(): MutableStructure1D<T> = MutableBuffer1DWrapper(this)
|
||||||
|
@ -57,6 +57,20 @@ public interface Structure2D<T> : NDStructure<T> {
|
|||||||
public companion object
|
public companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents mutable [Structure2D].
|
||||||
|
*/
|
||||||
|
public interface MutableStructure2D<T> : Structure2D<T>, MutableNDStructure<T> {
|
||||||
|
/**
|
||||||
|
* Inserts an item at the specified indices.
|
||||||
|
*
|
||||||
|
* @param i the first index.
|
||||||
|
* @param j the second index.
|
||||||
|
* @param value the value.
|
||||||
|
*/
|
||||||
|
public operator fun set(i: Int, j: Int, value: T)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A 2D wrapper for nd-structure
|
* A 2D wrapper for nd-structure
|
||||||
*/
|
*/
|
||||||
@ -79,7 +93,34 @@ private class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
* A 2D wrapper for a mutable nd-structure
|
||||||
|
*/
|
||||||
|
private class MutableStructure2DWrapper<T>(val structure: MutableNDStructure<T>): MutableStructure2D<T>
|
||||||
|
{
|
||||||
|
override val shape: IntArray get() = structure.shape
|
||||||
|
|
||||||
|
override val rowNum: Int get() = shape[0]
|
||||||
|
override val colNum: Int get() = shape[1]
|
||||||
|
|
||||||
|
override operator fun get(i: Int, j: Int): T = structure[i, j]
|
||||||
|
|
||||||
|
override fun set(index: IntArray, value: T) {
|
||||||
|
structure[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun set(i: Int, j: Int, value: T){
|
||||||
|
structure[intArrayOf(i, j)] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean = false
|
||||||
|
|
||||||
|
override fun hashCode(): Int = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represent a [NDStructure] as [Structure2D]. Throw error in case of dimension mismatch
|
||||||
*/
|
*/
|
||||||
public fun <T> NDStructure<T>.as2D(): Structure2D<T> = this as? Structure2D<T> ?: when (shape.size) {
|
public fun <T> NDStructure<T>.as2D(): Structure2D<T> = this as? Structure2D<T> ?: when (shape.size) {
|
||||||
2 -> Structure2DWrapper(this)
|
2 -> Structure2DWrapper(this)
|
||||||
@ -87,3 +128,11 @@ public fun <T> NDStructure<T>.as2D(): Structure2D<T> = this as? Structure2D<T> ?
|
|||||||
}
|
}
|
||||||
|
|
||||||
internal fun <T> Structure2D<T>.unwrap(): NDStructure<T> = if (this is Structure2DWrapper) structure else this
|
internal fun <T> Structure2D<T>.unwrap(): NDStructure<T> = if (this is Structure2DWrapper) structure else this
|
||||||
|
|
||||||
|
public fun <T> MutableNDStructure<T>.as2D(): MutableStructure2D<T> = this as? MutableStructure2D<T> ?: when (shape.size) {
|
||||||
|
2 -> MutableStructure2DWrapper(this)
|
||||||
|
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun <T> MutableStructure2D<T>.unwrap(): MutableNDStructure<T> =
|
||||||
|
if (this is MutableStructure2DWrapper) structure else this
|
||||||
|
@ -236,6 +236,11 @@ public inline class MutableListBuffer<T>(public val list: MutableList<T>) : Muta
|
|||||||
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
|
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an [MutableListBuffer] that wraps the original list.
|
||||||
|
*/
|
||||||
|
public fun <T> MutableList<T>.asMutableBuffer(): MutableListBuffer<T> = MutableListBuffer(this)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [MutableBuffer] implementation over [Array].
|
* [MutableBuffer] implementation over [Array].
|
||||||
*
|
*
|
||||||
|
@ -25,7 +25,6 @@ public open class BufferedTensor<T>(
|
|||||||
this[intArrayOf(i, j)] = value
|
this[intArrayOf(i, j)] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//todo make generator mb nextMatrixIndex?
|
//todo make generator mb nextMatrixIndex?
|
||||||
|
@ -10,6 +10,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
return this.buffer.unsafeToDoubleArray()[0]
|
return this.buffer.unsafeToDoubleArray()[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.get(i: Int): DoubleTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
override fun zeros(shape: IntArray): DoubleTensor {
|
override fun zeros(shape: IntArray): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
@ -41,6 +41,8 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
public operator fun TensorType.timesAssign(other: TensorType): Unit
|
public operator fun TensorType.timesAssign(other: TensorType): Unit
|
||||||
public operator fun TensorType.unaryMinus(): TensorType
|
public operator fun TensorType.unaryMinus(): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
||||||
|
public fun TensorType.get(i: Int): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.transpose.html
|
//https://pytorch.org/docs/stable/generated/torch.transpose.html
|
||||||
public fun TensorType.transpose(i: Int, j: Int): TensorType
|
public fun TensorType.transpose(i: Int, j: Int): TensorType
|
||||||
|
@ -3,4 +3,3 @@ package space.kscience.kmath.tensors
|
|||||||
import space.kscience.kmath.nd.MutableNDStructure
|
import space.kscience.kmath.nd.MutableNDStructure
|
||||||
|
|
||||||
public typealias TensorStructure<T> = MutableNDStructure<T>
|
public typealias TensorStructure<T> = MutableNDStructure<T>
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user