forked from kscience/kmath
Specialized 1d and 2d NDStructures
This commit is contained in:
parent
a2ef50ab47
commit
388d016fdf
@ -1,3 +1,5 @@
|
|||||||
|
import org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension
|
||||||
|
|
||||||
buildscript {
|
buildscript {
|
||||||
val kotlinVersion: String by rootProject.extra("1.3.20")
|
val kotlinVersion: String by rootProject.extra("1.3.20")
|
||||||
val ioVersion: String by rootProject.extra("0.1.2")
|
val ioVersion: String by rootProject.extra("0.1.2")
|
||||||
@ -27,9 +29,25 @@ allprojects {
|
|||||||
version = "0.0.3-dev-4"
|
version = "0.0.3-dev-4"
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
//maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
jcenter()
|
jcenter()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extensions.findByType<KotlinMultiplatformExtension>()?.apply {
|
||||||
|
jvm {
|
||||||
|
compilations.all {
|
||||||
|
kotlinOptions {
|
||||||
|
jvmTarget = "1.8"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
targets.all {
|
||||||
|
sourceSets.all {
|
||||||
|
languageSettings.progressiveMode = true
|
||||||
|
languageSettings.enableLanguageFeature("+InlineClasses")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (file("artifactory.gradle").exists()) {
|
if (file("artifactory.gradle").exists()) {
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
kotlin("multiplatform")
|
kotlin("multiplatform")
|
||||||
}
|
}
|
||||||
@ -8,13 +10,16 @@ kotlin {
|
|||||||
compilations.all {
|
compilations.all {
|
||||||
kotlinOptions {
|
kotlinOptions {
|
||||||
jvmTarget = "1.8"
|
jvmTarget = "1.8"
|
||||||
freeCompilerArgs += "-progressive"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
js()
|
js()
|
||||||
|
|
||||||
sourceSets {
|
sourceSets {
|
||||||
|
all {
|
||||||
|
languageSettings.progressiveMode = true
|
||||||
|
}
|
||||||
|
|
||||||
val commonMain by getting {
|
val commonMain by getting {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(kotlin("stdlib"))
|
api(kotlin("stdlib"))
|
||||||
|
@ -110,7 +110,7 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|||||||
/**
|
/**
|
||||||
* Specialized 2-d structure
|
* Specialized 2-d structure
|
||||||
*/
|
*/
|
||||||
interface Matrix<T : Any> : NDStructure<T> {
|
interface Matrix<T : Any> : Structure2D<T> {
|
||||||
val rowNum: Int
|
val rowNum: Int
|
||||||
val colNum: Int
|
val colNum: Int
|
||||||
|
|
||||||
@ -124,8 +124,6 @@ interface Matrix<T : Any> : NDStructure<T> {
|
|||||||
*/
|
*/
|
||||||
fun suggestFeature(vararg features: MatrixFeature): Matrix<T>
|
fun suggestFeature(vararg features: MatrixFeature): Matrix<T>
|
||||||
|
|
||||||
operator fun get(i: Int, j: Int): T
|
|
||||||
|
|
||||||
override fun get(index: IntArray): T = get(index[0], index[1])
|
override fun get(index: IntArray): T = get(index[0], index[1])
|
||||||
|
|
||||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||||
|
@ -16,7 +16,9 @@ interface NDStructure<T> {
|
|||||||
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
||||||
return when {
|
return when {
|
||||||
st1 === st2 -> true
|
st1 === st2 -> true
|
||||||
st1 is BufferNDStructure<*> && st2 is BufferNDStructure<*> && st1.strides == st2.strides -> st1.buffer.contentEquals(st2.buffer)
|
st1 is BufferNDStructure<*> && st2 is BufferNDStructure<*> && st1.strides == st2.strides -> st1.buffer.contentEquals(
|
||||||
|
st2.buffer
|
||||||
|
)
|
||||||
else -> st1.elements().all { (index, value) -> value == st2[index] }
|
else -> st1.elements().all { (index, value) -> value == st2[index] }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,88 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A structure that is guaranteed to be one-dimensional
|
||||||
|
*/
|
||||||
|
interface Structure1D<T> : NDStructure<T>, Buffer<T> {
|
||||||
|
override val dimension: Int get() = 1
|
||||||
|
|
||||||
|
override fun get(index: IntArray): T {
|
||||||
|
if (index.size != 1) error("Index dimension mismatch. Expected 1 but found ${index.size}")
|
||||||
|
return get(index[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A 1D wrapper for nd-structure
|
||||||
|
*/
|
||||||
|
private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T> {
|
||||||
|
|
||||||
|
override val shape: IntArray get() = structure.shape
|
||||||
|
override val size: Int get() = structure.shape[0]
|
||||||
|
|
||||||
|
override fun get(index: Int): T = structure[index]
|
||||||
|
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||||
|
*/
|
||||||
|
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
||||||
|
Structure1DWrapper(this)
|
||||||
|
} else {
|
||||||
|
error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A structure wrapper for buffer
|
||||||
|
*/
|
||||||
|
private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T> {
|
||||||
|
override val shape: IntArray get() = intArrayOf(buffer.size)
|
||||||
|
|
||||||
|
override val size: Int get() = buffer.size
|
||||||
|
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||||
|
asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
|
||||||
|
|
||||||
|
override fun get(index: Int): T = buffer.get(index)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represent this buffer as 1D structure
|
||||||
|
*/
|
||||||
|
fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A structure that is guaranteed to be two-dimensional
|
||||||
|
*/
|
||||||
|
interface Structure2D<T> : NDStructure<T> {
|
||||||
|
operator fun get(i: Int, j: Int): T
|
||||||
|
|
||||||
|
override fun get(index: IntArray): T {
|
||||||
|
if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}")
|
||||||
|
return get(index[0], index[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A 2D wrapper for nd-structure
|
||||||
|
*/
|
||||||
|
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
|
||||||
|
override fun get(i: Int, j: Int): T = structure[i, j]
|
||||||
|
|
||||||
|
override val shape: IntArray get() = structure.shape
|
||||||
|
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||||
|
*/
|
||||||
|
fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
|
||||||
|
Structure2DWrapper(this)
|
||||||
|
} else {
|
||||||
|
error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user