From f8e0d4be17baac7e2f5930083b2bfeb7338f99e8 Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Mon, 15 Mar 2021 21:18:15 +0000 Subject: [PATCH] MutableStructure 2D & 1D --- .../space/kscience/kmath/nd/Structure1D.kt | 56 +++++++++++++++++++ .../space/kscience/kmath/nd/Structure2D.kt | 53 +++++++++++++++++- .../space/kscience/kmath/structures/Buffer.kt | 5 ++ .../kscience/kmath/tensors/BufferedTensor.kt | 1 - .../kmath/tensors/DoubleTensorAlgebra.kt | 4 ++ .../kscience/kmath/tensors/TensorAlgebra.kt | 2 + .../kscience/kmath/tensors/TensorStructure.kt | 1 - 7 files changed, 118 insertions(+), 4 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure1D.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure1D.kt index 1335a4933..2926b3d1b 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure1D.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure1D.kt @@ -1,6 +1,8 @@ package space.kscience.kmath.nd import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.structures.MutableBuffer +import space.kscience.kmath.structures.asMutableBuffer import space.kscience.kmath.structures.asSequence /** @@ -17,6 +19,16 @@ public interface Structure1D : NDStructure, Buffer { public override operator fun iterator(): Iterator = (0 until size).asSequence().map(::get).iterator() } +/** + * A mutable structure that is guaranteed to be one-dimensional + */ +public interface MutableStructure1D : Structure1D, MutableNDStructure, MutableBuffer { + public override operator fun set(index: IntArray, value: T) { + require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" } + set(index[0], value) + } +} + /** * A 1D wrapper for nd-structure */ @@ -28,6 +40,25 @@ private inline class Structure1DWrapper(val structure: NDStructure) : Stru override fun elements(): Sequence> = structure.elements() } +/** + * A 1D wrapper for a mutable nd-structure + */ +private inline class MutableStructure1DWrapper(val structure: MutableNDStructure) : MutableStructure1D { + override val shape: IntArray get() = structure.shape + override val size: Int get() = structure.shape[0] + override fun elements(): Sequence> { + TODO("Not yet implemented") + } + + override fun get(index: Int): T = structure[index] + override fun set(index: Int, value: T) { + set(index, value) + } + + override fun copy(): MutableBuffer = + structure.elements().map { it.second }.toMutableList().asMutableBuffer() +} + /** * A structure wrapper for buffer @@ -42,6 +73,21 @@ private inline class Buffer1DWrapper(val buffer: Buffer) : Structure1D override operator fun get(index: Int): T = buffer[index] } +private inline class MutableBuffer1DWrapper(val buffer: MutableBuffer) : MutableStructure1D { + override val shape: IntArray get() = intArrayOf(buffer.size) + override val size: Int get() = buffer.size + + override fun elements(): Sequence> = + buffer.asSequence().mapIndexed { index, value -> intArrayOf(index) to value } + + override operator fun get(index: Int): T = buffer[index] + override fun set(index: Int, value: T) { + buffer[index] = value + } + + override fun copy(): MutableBuffer = buffer.copy() +} + /** * Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch */ @@ -52,7 +98,17 @@ public fun NDStructure.as1D(): Structure1D = this as? Structure1D ? } } else error("Can't create 1d-structure from ${shape.size}d-structure") +public fun MutableNDStructure.as1D(): MutableStructure1D = + this as? MutableStructure1D ?: if (shape.size == 1) { + when (this) { + is MutableNDBuffer -> MutableBuffer1DWrapper(this.buffer) + else -> MutableStructure1DWrapper(this) + } + } else error("Can't create 1d-structure from ${shape.size}d-structure") + /** * Represent this buffer as 1D structure */ public fun Buffer.asND(): Structure1D = Buffer1DWrapper(this) + +public fun MutableBuffer.asND(): MutableStructure1D = MutableBuffer1DWrapper(this) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure2D.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure2D.kt index e9f8234e5..2f2fd653e 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure2D.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure2D.kt @@ -57,6 +57,20 @@ public interface Structure2D : NDStructure { public companion object } +/** + * Represents mutable [Structure2D]. + */ +public interface MutableStructure2D : Structure2D, MutableNDStructure { + /** + * Inserts an item at the specified indices. + * + * @param i the first index. + * @param j the second index. + * @param value the value. + */ + public operator fun set(i: Int, j: Int, value: T) +} + /** * A 2D wrapper for nd-structure */ @@ -79,11 +93,46 @@ private class Structure2DWrapper(val structure: NDStructure) : Structure2D } /** - * Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch + * A 2D wrapper for a mutable nd-structure + */ +private class MutableStructure2DWrapper(val structure: MutableNDStructure): MutableStructure2D +{ + override val shape: IntArray get() = structure.shape + + override val rowNum: Int get() = shape[0] + override val colNum: Int get() = shape[1] + + override operator fun get(i: Int, j: Int): T = structure[i, j] + + override fun set(index: IntArray, value: T) { + structure[index] = value + } + + override operator fun set(i: Int, j: Int, value: T){ + structure[intArrayOf(i, j)] = value + } + + override fun elements(): Sequence> = structure.elements() + + override fun equals(other: Any?): Boolean = false + + override fun hashCode(): Int = 0 +} + +/** + * Represent a [NDStructure] as [Structure2D]. Throw error in case of dimension mismatch */ public fun NDStructure.as2D(): Structure2D = this as? Structure2D ?: when (shape.size) { 2 -> Structure2DWrapper(this) else -> error("Can't create 2d-structure from ${shape.size}d-structure") } -internal fun Structure2D.unwrap(): NDStructure = if (this is Structure2DWrapper) structure else this \ No newline at end of file +internal fun Structure2D.unwrap(): NDStructure = if (this is Structure2DWrapper) structure else this + +public fun MutableNDStructure.as2D(): MutableStructure2D = this as? MutableStructure2D ?: when (shape.size) { + 2 -> MutableStructure2DWrapper(this) + else -> error("Can't create 2d-structure from ${shape.size}d-structure") +} + +internal fun MutableStructure2D.unwrap(): MutableNDStructure = + if (this is MutableStructure2DWrapper) structure else this diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt index 2bde18fce..c62fa30ba 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt @@ -236,6 +236,11 @@ public inline class MutableListBuffer(public val list: MutableList) : Muta override fun copy(): MutableBuffer = MutableListBuffer(ArrayList(list)) } +/** + * Returns an [MutableListBuffer] that wraps the original list. + */ +public fun MutableList.asMutableBuffer(): MutableListBuffer = MutableListBuffer(this) + /** * [MutableBuffer] implementation over [Array]. * diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferedTensor.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferedTensor.kt index d5adf380c..68fc0412e 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferedTensor.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferedTensor.kt @@ -25,7 +25,6 @@ public open class BufferedTensor( this[intArrayOf(i, j)] = value } - } //todo make generator mb nextMatrixIndex? diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt index 8b9701127..76a3c4c9c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt @@ -10,6 +10,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra> { public operator fun TensorType.timesAssign(other: TensorType): Unit public operator fun TensorType.unaryMinus(): TensorType + //https://pytorch.org/cppdocs/notes/tensor_indexing.html + public fun TensorType.get(i: Int): TensorType //https://pytorch.org/docs/stable/generated/torch.transpose.html public fun TensorType.transpose(i: Int, j: Int): TensorType diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorStructure.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorStructure.kt index 5463877ce..f5ea39d1b 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorStructure.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorStructure.kt @@ -3,4 +3,3 @@ package space.kscience.kmath.tensors import space.kscience.kmath.nd.MutableNDStructure public typealias TensorStructure = MutableNDStructure -