From 00f6c027b41fae02e24e55055d821a40ce21e093 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 21 Jun 2019 12:34:04 +0300 Subject: [PATCH] Experimental type-safe dimensions --- build.gradle.kts | 2 +- buildSrc/build.gradle.kts | 2 +- kmath-dimensions/build.gradle.kts | 14 ++ .../scientifik/kmath/dimensions/Dimensions.kt | 34 +++++ .../scientifik/kmath/dimensions/Wrappers.kt | 130 ++++++++++++++++++ .../kmath/dimensions/DMatrixContextTest.kt | 27 ++++ settings.gradle.kts | 1 + 7 files changed, 208 insertions(+), 2 deletions(-) create mode 100644 kmath-dimensions/build.gradle.kts create mode 100644 kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt create mode 100644 kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt create mode 100644 kmath-dimensions/src/commonTest/kotlin/scientifik/kmath/dimensions/DMatrixContextTest.kt diff --git a/build.gradle.kts b/build.gradle.kts index d7b7ac78a..33fe7c33a 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,4 +1,4 @@ -val kmathVersion by extra("0.1.3-dev-1") +val kmathVersion by extra("0.1.3-dev-2") allprojects { repositories { diff --git a/buildSrc/build.gradle.kts b/buildSrc/build.gradle.kts index 318189162..a2760f609 100644 --- a/buildSrc/build.gradle.kts +++ b/buildSrc/build.gradle.kts @@ -7,7 +7,7 @@ repositories { jcenter() } -val kotlinVersion = "1.3.31" +val kotlinVersion = "1.3.40" // Add plugins used in buildSrc as dependencies, also we should specify version only here dependencies { diff --git a/kmath-dimensions/build.gradle.kts b/kmath-dimensions/build.gradle.kts new file mode 100644 index 000000000..59a4c0fc3 --- /dev/null +++ b/kmath-dimensions/build.gradle.kts @@ -0,0 +1,14 @@ +plugins { + `npm-multiplatform` +} + +description = "A proof of concept module for adding typ-safe dimensions to structures" + +kotlin.sourceSets { + commonMain { + dependencies { + implementation(kotlin("reflect")) + api(project(":kmath-core")) + } + } +} \ No newline at end of file diff --git a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt new file mode 100644 index 000000000..421ccb853 --- /dev/null +++ b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt @@ -0,0 +1,34 @@ +package scientifik.kmath.dimensions + +interface Dimension { + val dim: UInt + + companion object { + inline fun of(dim: UInt): Dimension { + return when (dim) { + 1u -> D1 + 2u -> D2 + 3u -> D3 + else -> object : Dimension { + override val dim: UInt = dim + } + } + } + + inline fun dim(): UInt{ + return D::class.objectInstance!!.dim + } + } +} + +object D1 : Dimension { + override val dim: UInt = 1u +} + +object D2 : Dimension { + override val dim: UInt = 2u +} + +object D3 : Dimension { + override val dim: UInt = 3u +} \ No newline at end of file diff --git a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt new file mode 100644 index 000000000..3138558b8 --- /dev/null +++ b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt @@ -0,0 +1,130 @@ +package scientifik.kmath.dimensions + +import scientifik.kmath.linear.GenericMatrixContext +import scientifik.kmath.linear.MatrixContext +import scientifik.kmath.linear.Point +import scientifik.kmath.linear.transpose +import scientifik.kmath.operations.Ring +import scientifik.kmath.structures.Matrix +import scientifik.kmath.structures.Structure2D + +interface DMatrix : Structure2D { + companion object { + /** + * Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed + */ + inline fun coerce(structure: Structure2D): DMatrix { + if (structure.rowNum != Dimension.dim().toInt()) error("Row number mismatch: expected ${Dimension.dim()} but found ${structure.rowNum}") + if (structure.colNum != Dimension.dim().toInt()) error("Column number mismatch: expected ${Dimension.dim()} but found ${structure.colNum}") + return DMatrixWrapper(structure) + } + + /** + * The same as [coerce] but without dimension checks. Use with caution + */ + inline fun coerceUnsafe(structure: Structure2D): DMatrix { + return DMatrixWrapper(structure) + } + } +} + +inline class DMatrixWrapper( + val structure: Structure2D +) : DMatrix { + override val shape: IntArray get() = structure.shape + override fun get(i: Int, j: Int): T = structure[i, j] +} + +interface DPoint : Point { + companion object { + inline fun coerce(point: Point): DPoint { + if (point.size != Dimension.dim().toInt()) error("Vector dimension mismatch: expected ${Dimension.dim()}, but found ${point.size}") + return DPointWrapper(point) + } + + inline fun coerceUnsafe(point: Point): DPoint { + return DPointWrapper(point) + } + } +} + +inline class DPointWrapper(val point: Point) : DPoint { + override val size: Int get() = point.size + + override fun get(index: Int): T = point[index] + + override fun iterator(): Iterator = point.iterator() +} + + +/** + * Basic operations on matrices. Operates on [Matrix] + */ +inline class DMatrixContext>(val context: GenericMatrixContext) { + + inline fun Matrix.coerce(): DMatrix = + DMatrix.coerceUnsafe(this) + + /** + * Produce a matrix with this context and given dimensions + */ + inline fun produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix { + val rows = Dimension.dim() + val cols = Dimension.dim() + return context.produce(rows.toInt(), cols.toInt(), initializer).coerce() + } + + inline fun point(noinline initializer: (Int) -> T): DPoint { + val size = Dimension.dim() + return DPoint.coerceUnsafe(context.point(size.toInt(), initializer)) + } + + inline infix fun DMatrix.dot( + other: DMatrix + ): DMatrix { + return context.run { this@dot dot other }.coerce() + } + + inline infix fun DMatrix.dot(vector: DPoint): DPoint { + return DPoint.coerceUnsafe(context.run { this@dot dot vector }) + } + + inline operator fun DMatrix.times(value: T): DMatrix { + return context.run { this@times.times(value) }.coerce() + } + + inline operator fun T.times(m: DMatrix): DMatrix = + m * this + + + inline operator fun DMatrix.plus(other: DMatrix): DMatrix { + return context.run { this@plus + other }.coerce() + } + + inline operator fun DMatrix.minus(other: DMatrix): DMatrix { + return context.run { this@minus + other }.coerce() + } + + inline operator fun DMatrix.unaryMinus(): DMatrix { + return context.run { this@unaryMinus.unaryMinus() }.coerce() + } + + inline fun DMatrix.transpose(): DMatrix { + return context.run { (this@transpose as Matrix).transpose() }.coerce() + } + + /** + * A square unit matrix + */ + inline fun one(): DMatrix = produce { i, j -> + if (i == j) context.elementContext.one else context.elementContext.zero + } + + inline fun zero(): DMatrix = produce { _, _ -> + context.elementContext.zero + } + + companion object { + val real = DMatrixContext(MatrixContext.real) + } +} \ No newline at end of file diff --git a/kmath-dimensions/src/commonTest/kotlin/scientifik/kmath/dimensions/DMatrixContextTest.kt b/kmath-dimensions/src/commonTest/kotlin/scientifik/kmath/dimensions/DMatrixContextTest.kt new file mode 100644 index 000000000..8d4bae810 --- /dev/null +++ b/kmath-dimensions/src/commonTest/kotlin/scientifik/kmath/dimensions/DMatrixContextTest.kt @@ -0,0 +1,27 @@ +package scientifik.kmath.dimensions + +import kotlin.test.Test + + +class DMatrixContextTest { + @Test + fun testDimensionSafeMatrix() { + val res = DMatrixContext.real.run { + val m = produce { i, j -> (i + j).toDouble() } + + //The dimension of `one()` is inferred from type + (m + one()) + } + } + + @Test + fun testTypeCheck() { + val res = DMatrixContext.real.run { + val m1 = produce { i, j -> (i + j).toDouble() } + val m2 = produce { i, j -> (i + j).toDouble() } + + //Dimension-safe addition + m1.transpose() + m2 + } + } +} \ No newline at end of file diff --git a/settings.gradle.kts b/settings.gradle.kts index 004b432fd..4afae465f 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -29,5 +29,6 @@ include( ":kmath-commons", ":kmath-koma", ":kmath-prob", + ":kmath-dimensions", ":examples" )