Specialized 1d and 2d NDStructures

This commit is contained in:
Alexander Nozik 2019-01-31 18:19:02 +03:00
parent a2ef50ab47
commit 388d016fdf
5 changed files with 118 additions and 7 deletions

View File

@ -1,3 +1,5 @@
import org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension
buildscript {
val kotlinVersion: String by rootProject.extra("1.3.20")
val ioVersion: String by rootProject.extra("0.1.2")
@ -27,11 +29,27 @@ allprojects {
version = "0.0.3-dev-4"
repositories {
maven("https://dl.bintray.com/kotlin/kotlin-eap")
//maven("https://dl.bintray.com/kotlin/kotlin-eap")
jcenter()
}
extensions.findByType<KotlinMultiplatformExtension>()?.apply {
jvm {
compilations.all {
kotlinOptions {
jvmTarget = "1.8"
}
}
}
targets.all {
sourceSets.all {
languageSettings.progressiveMode = true
languageSettings.enableLanguageFeature("+InlineClasses")
}
}
}
}
if (file("artifactory.gradle").exists()) {
apply(from = "artifactory.gradle")
}
}

View File

@ -1,3 +1,5 @@
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
plugins {
kotlin("multiplatform")
}
@ -8,13 +10,16 @@ kotlin {
compilations.all {
kotlinOptions {
jvmTarget = "1.8"
freeCompilerArgs += "-progressive"
}
}
}
js()
sourceSets {
all {
languageSettings.progressiveMode = true
}
val commonMain by getting {
dependencies {
api(kotlin("stdlib"))

View File

@ -110,7 +110,7 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
/**
* Specialized 2-d structure
*/
interface Matrix<T : Any> : NDStructure<T> {
interface Matrix<T : Any> : Structure2D<T> {
val rowNum: Int
val colNum: Int
@ -124,8 +124,6 @@ interface Matrix<T : Any> : NDStructure<T> {
*/
fun suggestFeature(vararg features: MatrixFeature): Matrix<T>
operator fun get(i: Int, j: Int): T
override fun get(index: IntArray): T = get(index[0], index[1])
override val shape: IntArray get() = intArrayOf(rowNum, colNum)

View File

@ -16,7 +16,9 @@ interface NDStructure<T> {
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
return when {
st1 === st2 -> true
st1 is BufferNDStructure<*> && st2 is BufferNDStructure<*> && st1.strides == st2.strides -> st1.buffer.contentEquals(st2.buffer)
st1 is BufferNDStructure<*> && st2 is BufferNDStructure<*> && st1.strides == st2.strides -> st1.buffer.contentEquals(
st2.buffer
)
else -> st1.elements().all { (index, value) -> value == st2[index] }
}
}

View File

@ -0,0 +1,88 @@
package scientifik.kmath.structures
/**
* A structure that is guaranteed to be one-dimensional
*/
interface Structure1D<T> : NDStructure<T>, Buffer<T> {
override val dimension: Int get() = 1
override fun get(index: IntArray): T {
if (index.size != 1) error("Index dimension mismatch. Expected 1 but found ${index.size}")
return get(index[0])
}
override fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
}
/**
* A 1D wrapper for nd-structure
*/
private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T> {
override val shape: IntArray get() = structure.shape
override val size: Int get() = structure.shape[0]
override fun get(index: Int): T = structure[index]
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
}
/**
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
*/
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
Structure1DWrapper(this)
} else {
error("Can't create 1d-structure from ${shape.size}d-structure")
}
/**
* A structure wrapper for buffer
*/
private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T> {
override val shape: IntArray get() = intArrayOf(buffer.size)
override val size: Int get() = buffer.size
override fun elements(): Sequence<Pair<IntArray, T>> =
asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
override fun get(index: Int): T = buffer.get(index)
}
/**
* Represent this buffer as 1D structure
*/
fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
/**
* A structure that is guaranteed to be two-dimensional
*/
interface Structure2D<T> : NDStructure<T> {
operator fun get(i: Int, j: Int): T
override fun get(index: IntArray): T {
if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}")
return get(index[0], index[1])
}
}
/**
* A 2D wrapper for nd-structure
*/
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
override fun get(i: Int, j: Int): T = structure[i, j]
override val shape: IntArray get() = structure.shape
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
}
/**
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
*/
fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
Structure2DWrapper(this)
} else {
error("Can't create 2d-structure from ${shape.size}d-structure")
}