forked from kscience/kmath
Experimental type-safe dimensions
This commit is contained in:
parent
5a8a45cf6a
commit
00f6c027b4
@ -1,4 +1,4 @@
|
|||||||
val kmathVersion by extra("0.1.3-dev-1")
|
val kmathVersion by extra("0.1.3-dev-2")
|
||||||
|
|
||||||
allprojects {
|
allprojects {
|
||||||
repositories {
|
repositories {
|
||||||
|
@ -7,7 +7,7 @@ repositories {
|
|||||||
jcenter()
|
jcenter()
|
||||||
}
|
}
|
||||||
|
|
||||||
val kotlinVersion = "1.3.31"
|
val kotlinVersion = "1.3.40"
|
||||||
|
|
||||||
// Add plugins used in buildSrc as dependencies, also we should specify version only here
|
// Add plugins used in buildSrc as dependencies, also we should specify version only here
|
||||||
dependencies {
|
dependencies {
|
||||||
|
14
kmath-dimensions/build.gradle.kts
Normal file
14
kmath-dimensions/build.gradle.kts
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
plugins {
|
||||||
|
`npm-multiplatform`
|
||||||
|
}
|
||||||
|
|
||||||
|
description = "A proof of concept module for adding typ-safe dimensions to structures"
|
||||||
|
|
||||||
|
kotlin.sourceSets {
|
||||||
|
commonMain {
|
||||||
|
dependencies {
|
||||||
|
implementation(kotlin("reflect"))
|
||||||
|
api(project(":kmath-core"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package scientifik.kmath.dimensions
|
||||||
|
|
||||||
|
interface Dimension {
|
||||||
|
val dim: UInt
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
inline fun of(dim: UInt): Dimension {
|
||||||
|
return when (dim) {
|
||||||
|
1u -> D1
|
||||||
|
2u -> D2
|
||||||
|
3u -> D3
|
||||||
|
else -> object : Dimension {
|
||||||
|
override val dim: UInt = dim
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun <reified D: Dimension> dim(): UInt{
|
||||||
|
return D::class.objectInstance!!.dim
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
object D1 : Dimension {
|
||||||
|
override val dim: UInt = 1u
|
||||||
|
}
|
||||||
|
|
||||||
|
object D2 : Dimension {
|
||||||
|
override val dim: UInt = 2u
|
||||||
|
}
|
||||||
|
|
||||||
|
object D3 : Dimension {
|
||||||
|
override val dim: UInt = 3u
|
||||||
|
}
|
@ -0,0 +1,130 @@
|
|||||||
|
package scientifik.kmath.dimensions
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.GenericMatrixContext
|
||||||
|
import scientifik.kmath.linear.MatrixContext
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
import scientifik.kmath.linear.transpose
|
||||||
|
import scientifik.kmath.operations.Ring
|
||||||
|
import scientifik.kmath.structures.Matrix
|
||||||
|
import scientifik.kmath.structures.Structure2D
|
||||||
|
|
||||||
|
interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
|
||||||
|
companion object {
|
||||||
|
/**
|
||||||
|
* Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed
|
||||||
|
*/
|
||||||
|
inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> {
|
||||||
|
if (structure.rowNum != Dimension.dim<R>().toInt()) error("Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}")
|
||||||
|
if (structure.colNum != Dimension.dim<C>().toInt()) error("Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}")
|
||||||
|
return DMatrixWrapper(structure)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The same as [coerce] but without dimension checks. Use with caution
|
||||||
|
*/
|
||||||
|
inline fun <T, reified R : Dimension, reified C : Dimension> coerceUnsafe(structure: Structure2D<T>): DMatrix<T, R, C> {
|
||||||
|
return DMatrixWrapper(structure)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
||||||
|
val structure: Structure2D<T>
|
||||||
|
) : DMatrix<T, R, C> {
|
||||||
|
override val shape: IntArray get() = structure.shape
|
||||||
|
override fun get(i: Int, j: Int): T = structure[i, j]
|
||||||
|
}
|
||||||
|
|
||||||
|
interface DPoint<T, D : Dimension> : Point<T> {
|
||||||
|
companion object {
|
||||||
|
inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> {
|
||||||
|
if (point.size != Dimension.dim<D>().toInt()) error("Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}")
|
||||||
|
return DPointWrapper(point)
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun <T, reified D : Dimension> coerceUnsafe(point: Point<T>): DPoint<T, D> {
|
||||||
|
return DPointWrapper(point)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) : DPoint<T, D> {
|
||||||
|
override val size: Int get() = point.size
|
||||||
|
|
||||||
|
override fun get(index: Int): T = point[index]
|
||||||
|
|
||||||
|
override fun iterator(): Iterator<T> = point.iterator()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Basic operations on matrices. Operates on [Matrix]
|
||||||
|
*/
|
||||||
|
inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixContext<T, Ri>) {
|
||||||
|
|
||||||
|
inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, C, R> =
|
||||||
|
DMatrix.coerceUnsafe(this)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Produce a matrix with this context and given dimensions
|
||||||
|
*/
|
||||||
|
inline fun <reified R : Dimension, reified C : Dimension> produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix<T, R, C> {
|
||||||
|
val rows = Dimension.dim<R>()
|
||||||
|
val cols = Dimension.dim<C>()
|
||||||
|
return context.produce(rows.toInt(), cols.toInt(), initializer).coerce()
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun <reified D : Dimension> point(noinline initializer: (Int) -> T): DPoint<T, D> {
|
||||||
|
val size = Dimension.dim<D>()
|
||||||
|
return DPoint.coerceUnsafe(context.point(size.toInt(), initializer))
|
||||||
|
}
|
||||||
|
|
||||||
|
inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
|
||||||
|
other: DMatrix<T, C1, C2>
|
||||||
|
): DMatrix<T, R1, C2> {
|
||||||
|
return context.run { this@dot dot other }.coerce()
|
||||||
|
}
|
||||||
|
|
||||||
|
inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> {
|
||||||
|
return DPoint.coerceUnsafe(context.run { this@dot dot vector })
|
||||||
|
}
|
||||||
|
|
||||||
|
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.times(value: T): DMatrix<T, R, C> {
|
||||||
|
return context.run { this@times.times(value) }.coerce()
|
||||||
|
}
|
||||||
|
|
||||||
|
inline operator fun <reified R : Dimension, reified C : Dimension> T.times(m: DMatrix<T, R, C>): DMatrix<T, R, C> =
|
||||||
|
m * this
|
||||||
|
|
||||||
|
|
||||||
|
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.plus(other: DMatrix<T, C, R>): DMatrix<T, C, R> {
|
||||||
|
return context.run { this@plus + other }.coerce()
|
||||||
|
}
|
||||||
|
|
||||||
|
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.minus(other: DMatrix<T, C, R>): DMatrix<T, C, R> {
|
||||||
|
return context.run { this@minus + other }.coerce()
|
||||||
|
}
|
||||||
|
|
||||||
|
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.unaryMinus(): DMatrix<T, C, R> {
|
||||||
|
return context.run { this@unaryMinus.unaryMinus() }.coerce()
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> {
|
||||||
|
return context.run { (this@transpose as Matrix<T>).transpose() }.coerce()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A square unit matrix
|
||||||
|
*/
|
||||||
|
inline fun <reified D : Dimension> one(): DMatrix<T, D, D> = produce { i, j ->
|
||||||
|
if (i == j) context.elementContext.one else context.elementContext.zero
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun <reified R : Dimension, reified C : Dimension> zero(): DMatrix<T, R, C> = produce { _, _ ->
|
||||||
|
context.elementContext.zero
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
val real = DMatrixContext(MatrixContext.real)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,27 @@
|
|||||||
|
package scientifik.kmath.dimensions
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
|
||||||
|
class DMatrixContextTest {
|
||||||
|
@Test
|
||||||
|
fun testDimensionSafeMatrix() {
|
||||||
|
val res = DMatrixContext.real.run {
|
||||||
|
val m = produce<D2, D2> { i, j -> (i + j).toDouble() }
|
||||||
|
|
||||||
|
//The dimension of `one()` is inferred from type
|
||||||
|
(m + one())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTypeCheck() {
|
||||||
|
val res = DMatrixContext.real.run {
|
||||||
|
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
|
||||||
|
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
|
||||||
|
|
||||||
|
//Dimension-safe addition
|
||||||
|
m1.transpose() + m2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -29,5 +29,6 @@ include(
|
|||||||
":kmath-commons",
|
":kmath-commons",
|
||||||
":kmath-koma",
|
":kmath-koma",
|
||||||
":kmath-prob",
|
":kmath-prob",
|
||||||
|
":kmath-dimensions",
|
||||||
":examples"
|
":examples"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user