diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt b/examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt index 9169c7d7b..fdc09ed5d 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt @@ -14,6 +14,7 @@ fun DMatrixContext.simple() { m1.transpose() + m2 } + object D5 : Dimension { override val dim: UInt = 5u } diff --git a/kmath-dimensions/build.gradle.kts b/kmath-dimensions/build.gradle.kts index 50fb41391..dda6cd2f0 100644 --- a/kmath-dimensions/build.gradle.kts +++ b/kmath-dimensions/build.gradle.kts @@ -7,8 +7,13 @@ description = "A proof of concept module for adding typ-safe dimensions to struc kotlin.sourceSets { commonMain { dependencies { - implementation(kotlin("reflect")) api(project(":kmath-core")) } } + + jvmMain{ + dependencies{ + api(kotlin("reflect")) + } + } } \ No newline at end of file diff --git a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt index 2eef69a0c..37e89c111 100644 --- a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt +++ b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt @@ -1,37 +1,35 @@ package scientifik.kmath.dimensions +import kotlin.reflect.KClass + /** - * An interface which is not used in runtime. Designates a size of some structure. - * All descendants should be singleton objects. + * An abstract class which is not used in runtime. Designates a size of some structure. + * Could be replaced later by fully inline constructs */ interface Dimension { - val dim: UInt + val dim: UInt companion object { - @Suppress("NOTHING_TO_INLINE") - inline fun of(dim: UInt): Dimension { - return when (dim) { - 1u -> D1 - 2u -> D2 - 3u -> D3 - else -> object : Dimension { - override val dim: UInt = dim - } - } - } + } } -expect inline fun Dimension.Companion.dim(): UInt +fun KClass.dim(): UInt = Dimension.resolve(this).dim + +expect fun Dimension.Companion.resolve(type: KClass): D + +expect fun Dimension.Companion.of(dim: UInt): Dimension + +inline fun Dimension.Companion.dim(): UInt = D::class.dim() object D1 : Dimension { - override val dim: UInt = 1u + override val dim: UInt get() = 1U } object D2 : Dimension { - override val dim: UInt = 2u + override val dim: UInt get() = 2U } object D3 : Dimension { - override val dim: UInt = 3u + override val dim: UInt get() = 31U } \ No newline at end of file diff --git a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt index c78018753..f447866c0 100644 --- a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt +++ b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt @@ -29,7 +29,7 @@ interface DMatrix : Structure2D { /** * The same as [coerce] but without dimension checks. Use with caution */ - inline fun coerceUnsafe(structure: Structure2D): DMatrix { + fun coerceUnsafe(structure: Structure2D): DMatrix { return DMatrixWrapper(structure) } } @@ -57,7 +57,7 @@ interface DPoint : Point { return DPointWrapper(point) } - inline fun coerceUnsafe(point: Point): DPoint { + fun coerceUnsafe(point: Point): DPoint { return DPointWrapper(point) } } @@ -81,8 +81,15 @@ inline class DPointWrapper(val point: Point) : */ inline class DMatrixContext>(val context: GenericMatrixContext) { - inline fun Matrix.coerce(): DMatrix = - DMatrix.coerceUnsafe(this) + inline fun Matrix.coerce(): DMatrix { + if (rowNum != Dimension.dim().toInt()) { + error("Row number mismatch: expected ${Dimension.dim()} but found $rowNum") + } + if (colNum != Dimension.dim().toInt()) { + error("Column number mismatch: expected ${Dimension.dim()} but found $colNum") + } + return DMatrix.coerceUnsafe(this) + } /** * Produce a matrix with this context and given dimensions @@ -90,7 +97,7 @@ inline class DMatrixContext>(val context: GenericMatrixCon inline fun produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix { val rows = Dimension.dim() val cols = Dimension.dim() - return context.produce(rows.toInt(), cols.toInt(), initializer).coerce() + return context.produce(rows.toInt(), cols.toInt(), initializer).coerce() } inline fun point(noinline initializer: (Int) -> T): DPoint { diff --git a/kmath-dimensions/src/jvmTest/kotlin/scientifik/kmath/dimensions/DMatrixContextTest.kt b/kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt similarity index 79% rename from kmath-dimensions/src/jvmTest/kotlin/scientifik/kmath/dimensions/DMatrixContextTest.kt rename to kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt index 8d4bae810..74d20205c 100644 --- a/kmath-dimensions/src/jvmTest/kotlin/scientifik/kmath/dimensions/DMatrixContextTest.kt +++ b/kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt @@ -1,5 +1,8 @@ -package scientifik.kmath.dimensions +package scientifik.dimensions +import scientifik.kmath.dimensions.D2 +import scientifik.kmath.dimensions.D3 +import scientifik.kmath.dimensions.DMatrixContext import kotlin.test.Test diff --git a/kmath-dimensions/src/jsMain/kotlin/scientifik/kmath/dimensions/dim.kt b/kmath-dimensions/src/jsMain/kotlin/scientifik/kmath/dimensions/dim.kt index 557234d84..bbd580629 100644 --- a/kmath-dimensions/src/jsMain/kotlin/scientifik/kmath/dimensions/dim.kt +++ b/kmath-dimensions/src/jsMain/kotlin/scientifik/kmath/dimensions/dim.kt @@ -1,5 +1,22 @@ package scientifik.kmath.dimensions -actual inline fun Dimension.Companion.dim(): UInt { - TODO("KClass::objectInstance does not work") +import kotlin.reflect.KClass + +private val dimensionMap = hashMapOf( + 1u to D1, + 2u to D2, + 3u to D3 +) + +@Suppress("UNCHECKED_CAST") +actual fun Dimension.Companion.resolve(type: KClass): D { + return dimensionMap.entries.find { it.value::class == type }?.value as? D ?: error("Can't resolve dimension $type") +} + +actual fun Dimension.Companion.of(dim: UInt): Dimension { + return dimensionMap.getOrPut(dim) { + object : Dimension { + override val dim: UInt get() = dim + } + } } \ No newline at end of file diff --git a/kmath-dimensions/src/jvmMain/kotlin/scientifik/kmath/dimensions/dim.kt b/kmath-dimensions/src/jvmMain/kotlin/scientifik/kmath/dimensions/dim.kt index a16dd2266..e8fe8f59b 100644 --- a/kmath-dimensions/src/jvmMain/kotlin/scientifik/kmath/dimensions/dim.kt +++ b/kmath-dimensions/src/jvmMain/kotlin/scientifik/kmath/dimensions/dim.kt @@ -1,4 +1,18 @@ package scientifik.kmath.dimensions -actual inline fun Dimension.Companion.dim(): UInt = - D::class.objectInstance?.dim ?: error("Dimension object must be a singleton") \ No newline at end of file +import kotlin.reflect.KClass + +actual fun Dimension.Companion.resolve(type: KClass): D{ + return type.objectInstance ?: error("No object instance for dimension class") +} + +actual fun Dimension.Companion.of(dim: UInt): Dimension{ + return when(dim){ + 1u -> D1 + 2u -> D2 + 3u -> D3 + else -> object : Dimension { + override val dim: UInt get() = dim + } + } +} \ No newline at end of file