Experimental type-safe dimensions
This commit is contained in:
parent
5a8a45cf6a
commit
00f6c027b4
@ -1,4 +1,4 @@
|
||||
val kmathVersion by extra("0.1.3-dev-1")
|
||||
val kmathVersion by extra("0.1.3-dev-2")
|
||||
|
||||
allprojects {
|
||||
repositories {
|
||||
|
@ -7,7 +7,7 @@ repositories {
|
||||
jcenter()
|
||||
}
|
||||
|
||||
val kotlinVersion = "1.3.31"
|
||||
val kotlinVersion = "1.3.40"
|
||||
|
||||
// Add plugins used in buildSrc as dependencies, also we should specify version only here
|
||||
dependencies {
|
||||
|
14
kmath-dimensions/build.gradle.kts
Normal file
14
kmath-dimensions/build.gradle.kts
Normal file
@ -0,0 +1,14 @@
|
||||
plugins {
|
||||
`npm-multiplatform`
|
||||
}
|
||||
|
||||
description = "A proof of concept module for adding typ-safe dimensions to structures"
|
||||
|
||||
kotlin.sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
implementation(kotlin("reflect"))
|
||||
api(project(":kmath-core"))
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
package scientifik.kmath.dimensions
|
||||
|
||||
interface Dimension {
|
||||
val dim: UInt
|
||||
|
||||
companion object {
|
||||
inline fun of(dim: UInt): Dimension {
|
||||
return when (dim) {
|
||||
1u -> D1
|
||||
2u -> D2
|
||||
3u -> D3
|
||||
else -> object : Dimension {
|
||||
override val dim: UInt = dim
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline fun <reified D: Dimension> dim(): UInt{
|
||||
return D::class.objectInstance!!.dim
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object D1 : Dimension {
|
||||
override val dim: UInt = 1u
|
||||
}
|
||||
|
||||
object D2 : Dimension {
|
||||
override val dim: UInt = 2u
|
||||
}
|
||||
|
||||
object D3 : Dimension {
|
||||
override val dim: UInt = 3u
|
||||
}
|
@ -0,0 +1,130 @@
|
||||
package scientifik.kmath.dimensions
|
||||
|
||||
import scientifik.kmath.linear.GenericMatrixContext
|
||||
import scientifik.kmath.linear.MatrixContext
|
||||
import scientifik.kmath.linear.Point
|
||||
import scientifik.kmath.linear.transpose
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import scientifik.kmath.structures.Structure2D
|
||||
|
||||
interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
|
||||
companion object {
|
||||
/**
|
||||
* Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed
|
||||
*/
|
||||
inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> {
|
||||
if (structure.rowNum != Dimension.dim<R>().toInt()) error("Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}")
|
||||
if (structure.colNum != Dimension.dim<C>().toInt()) error("Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}")
|
||||
return DMatrixWrapper(structure)
|
||||
}
|
||||
|
||||
/**
|
||||
* The same as [coerce] but without dimension checks. Use with caution
|
||||
*/
|
||||
inline fun <T, reified R : Dimension, reified C : Dimension> coerceUnsafe(structure: Structure2D<T>): DMatrix<T, R, C> {
|
||||
return DMatrixWrapper(structure)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
||||
val structure: Structure2D<T>
|
||||
) : DMatrix<T, R, C> {
|
||||
override val shape: IntArray get() = structure.shape
|
||||
override fun get(i: Int, j: Int): T = structure[i, j]
|
||||
}
|
||||
|
||||
interface DPoint<T, D : Dimension> : Point<T> {
|
||||
companion object {
|
||||
inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> {
|
||||
if (point.size != Dimension.dim<D>().toInt()) error("Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}")
|
||||
return DPointWrapper(point)
|
||||
}
|
||||
|
||||
inline fun <T, reified D : Dimension> coerceUnsafe(point: Point<T>): DPoint<T, D> {
|
||||
return DPointWrapper(point)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) : DPoint<T, D> {
|
||||
override val size: Int get() = point.size
|
||||
|
||||
override fun get(index: Int): T = point[index]
|
||||
|
||||
override fun iterator(): Iterator<T> = point.iterator()
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Basic operations on matrices. Operates on [Matrix]
|
||||
*/
|
||||
inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixContext<T, Ri>) {
|
||||
|
||||
inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, C, R> =
|
||||
DMatrix.coerceUnsafe(this)
|
||||
|
||||
/**
|
||||
* Produce a matrix with this context and given dimensions
|
||||
*/
|
||||
inline fun <reified R : Dimension, reified C : Dimension> produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix<T, R, C> {
|
||||
val rows = Dimension.dim<R>()
|
||||
val cols = Dimension.dim<C>()
|
||||
return context.produce(rows.toInt(), cols.toInt(), initializer).coerce()
|
||||
}
|
||||
|
||||
inline fun <reified D : Dimension> point(noinline initializer: (Int) -> T): DPoint<T, D> {
|
||||
val size = Dimension.dim<D>()
|
||||
return DPoint.coerceUnsafe(context.point(size.toInt(), initializer))
|
||||
}
|
||||
|
||||
inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
|
||||
other: DMatrix<T, C1, C2>
|
||||
): DMatrix<T, R1, C2> {
|
||||
return context.run { this@dot dot other }.coerce()
|
||||
}
|
||||
|
||||
inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> {
|
||||
return DPoint.coerceUnsafe(context.run { this@dot dot vector })
|
||||
}
|
||||
|
||||
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.times(value: T): DMatrix<T, R, C> {
|
||||
return context.run { this@times.times(value) }.coerce()
|
||||
}
|
||||
|
||||
inline operator fun <reified R : Dimension, reified C : Dimension> T.times(m: DMatrix<T, R, C>): DMatrix<T, R, C> =
|
||||
m * this
|
||||
|
||||
|
||||
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.plus(other: DMatrix<T, C, R>): DMatrix<T, C, R> {
|
||||
return context.run { this@plus + other }.coerce()
|
||||
}
|
||||
|
||||
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.minus(other: DMatrix<T, C, R>): DMatrix<T, C, R> {
|
||||
return context.run { this@minus + other }.coerce()
|
||||
}
|
||||
|
||||
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.unaryMinus(): DMatrix<T, C, R> {
|
||||
return context.run { this@unaryMinus.unaryMinus() }.coerce()
|
||||
}
|
||||
|
||||
inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> {
|
||||
return context.run { (this@transpose as Matrix<T>).transpose() }.coerce()
|
||||
}
|
||||
|
||||
/**
|
||||
* A square unit matrix
|
||||
*/
|
||||
inline fun <reified D : Dimension> one(): DMatrix<T, D, D> = produce { i, j ->
|
||||
if (i == j) context.elementContext.one else context.elementContext.zero
|
||||
}
|
||||
|
||||
inline fun <reified R : Dimension, reified C : Dimension> zero(): DMatrix<T, R, C> = produce { _, _ ->
|
||||
context.elementContext.zero
|
||||
}
|
||||
|
||||
companion object {
|
||||
val real = DMatrixContext(MatrixContext.real)
|
||||
}
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
package scientifik.kmath.dimensions
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
class DMatrixContextTest {
|
||||
@Test
|
||||
fun testDimensionSafeMatrix() {
|
||||
val res = DMatrixContext.real.run {
|
||||
val m = produce<D2, D2> { i, j -> (i + j).toDouble() }
|
||||
|
||||
//The dimension of `one()` is inferred from type
|
||||
(m + one())
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTypeCheck() {
|
||||
val res = DMatrixContext.real.run {
|
||||
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
|
||||
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
|
||||
|
||||
//Dimension-safe addition
|
||||
m1.transpose() + m2
|
||||
}
|
||||
}
|
||||
}
|
@ -29,5 +29,6 @@ include(
|
||||
":kmath-commons",
|
||||
":kmath-koma",
|
||||
":kmath-prob",
|
||||
":kmath-dimensions",
|
||||
":examples"
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user