Examples for type-safe dimensions

This commit is contained in:
Alexander Nozik 2019-12-09 19:52:00 +03:00
parent a156feb397
commit 4060c70b17
7 changed files with 96 additions and 41 deletions

View File

@ -5,7 +5,7 @@ plugins {
java
kotlin("jvm")
kotlin("plugin.allopen") version "1.3.61"
id("kotlinx.benchmark") version "0.2.0-dev-5"
id("kotlinx.benchmark") version "0.2.0-dev-7"
}
configure<AllOpenExtension> {
@ -13,10 +13,7 @@ configure<AllOpenExtension> {
}
repositories {
maven("https://dl.bintray.com/kotlin/kotlin-eap")
maven("http://dl.bintray.com/kyonifer/maven")
maven ("https://dl.bintray.com/orangy/maven")
mavenCentral()
}
@ -30,12 +27,10 @@ dependencies {
implementation(project(":kmath-commons"))
implementation(project(":kmath-koma"))
implementation(project(":kmath-viktor"))
implementation(project(":kmath-dimensions"))
implementation("com.kyonifer:koma-core-ejml:0.12")
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:${Scientifik.ioVersion}")
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-2")
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-7")
"benchmarksCompile"(sourceSets.main.get().compileClasspath)
}

View File

@ -1,28 +1,26 @@
package scientifik.kmath.structures
import org.jetbrains.bio.viktor.F64Array
import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State
import scientifik.kmath.operations.RealField
import scientifik.kmath.viktor.ViktorNDField
fun main() {
val dim = 1000
val n = 400
@State(Scope.Benchmark)
class ViktorBenchmark {
final val dim = 1000
final val n = 100
// automatically build context most suited for given type.
val autoField = NDField.auto(RealField, dim, dim)
val realField = NDField.real(dim,dim)
final val autoField = NDField.auto(RealField, dim, dim)
final val realField = NDField.real(dim, dim)
val viktorField = ViktorNDField(intArrayOf(dim, dim))
final val viktorField = ViktorNDField(intArrayOf(dim, dim))
autoField.run {
var res = one
repeat(n/2) {
res += 1.0
}
}
measureAndPrint("Automatic field addition") {
@Benchmark
fun `Automatic field addition`() {
autoField.run {
var res = one
repeat(n) {
@ -31,14 +29,8 @@ fun main() {
}
}
viktorField.run {
var res = one
repeat(n/2) {
res += one
}
}
measureAndPrint("Viktor field addition") {
@Benchmark
fun `Viktor field addition`() {
viktorField.run {
var res = one
repeat(n) {
@ -47,7 +39,8 @@ fun main() {
}
}
measureAndPrint("Raw Viktor") {
@Benchmark
fun `Raw Viktor`() {
val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim))
var res = one
repeat(n) {
@ -55,7 +48,8 @@ fun main() {
}
}
measureAndPrint("Automatic field log") {
@Benchmark
fun `Real field log`() {
realField.run {
val fortyTwo = produce { 42.0 }
var res = one
@ -66,7 +60,8 @@ fun main() {
}
}
measureAndPrint("Raw Viktor log") {
@Benchmark
fun `Raw Viktor log`() {
val fortyTwo = F64Array.full(dim, dim, init = 42.0)
var res: F64Array
repeat(n) {

View File

@ -0,0 +1,8 @@
package scientifik.kmath.utils
import kotlin.system.measureTimeMillis
internal inline fun measureAndPrint(title: String, block: () -> Unit) {
val time = measureTimeMillis(block)
println("$title completed in $time millis")
}

View File

@ -0,0 +1,34 @@
package scientifik.kmath.structures
import scientifik.kmath.dimensions.D2
import scientifik.kmath.dimensions.D3
import scientifik.kmath.dimensions.DMatrixContext
import scientifik.kmath.dimensions.Dimension
import scientifik.kmath.operations.RealField
fun DMatrixContext<Double, RealField>.simple() {
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
//Dimension-safe addition
m1.transpose() + m2
}
object D5: Dimension{
override val dim: UInt = 5u
}
fun DMatrixContext<Double, RealField>.custom() {
val m1 = produce<D2, D5> { i, j -> (i+j).toDouble() }
val m2 = produce<D5,D2> { i, j -> (i-j).toDouble() }
val m3 = produce<D2,D2> { i, j -> (i-j).toDouble() }
(m1 dot m2) + m3
}
fun main() {
DMatrixContext.real.run {
simple()
custom()
}
}

View File

@ -1,5 +1,5 @@
plugins {
`npm-multiplatform`
id("scientifik.mpp")
}
description = "A proof of concept module for adding typ-safe dimensions to structures"

View File

@ -1,9 +1,14 @@
package scientifik.kmath.dimensions
/**
* An interface which is not used in runtime. Designates a size of some structure.
* All descendants should be singleton objects.
*/
interface Dimension {
val dim: UInt
companion object {
@Suppress("NOTHING_TO_INLINE")
inline fun of(dim: UInt): Dimension {
return when (dim) {
1u -> D1
@ -16,7 +21,7 @@ interface Dimension {
}
inline fun <reified D : Dimension> dim(): UInt {
return D::class.objectInstance!!.dim
return D::class.objectInstance?.dim ?: error("Dimension object must be a singleton")
}
}
}

View File

@ -8,14 +8,21 @@ import scientifik.kmath.operations.Ring
import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.Structure2D
/**
* A matrix with compile-time controlled dimension
*/
interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
companion object {
/**
* Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed
*/
inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> {
if (structure.rowNum != Dimension.dim<R>().toInt()) error("Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}")
if (structure.colNum != Dimension.dim<C>().toInt()) error("Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}")
if (structure.rowNum != Dimension.dim<R>().toInt()) {
error("Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}")
}
if (structure.colNum != Dimension.dim<C>().toInt()) {
error("Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}")
}
return DMatrixWrapper(structure)
}
@ -28,6 +35,9 @@ interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
}
}
/**
* An inline wrapper for a Matrix
*/
inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
val structure: Structure2D<T>
) : DMatrix<T, R, C> {
@ -35,10 +45,15 @@ inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
override fun get(i: Int, j: Int): T = structure[i, j]
}
/**
* Dimension-safe point
*/
interface DPoint<T, D : Dimension> : Point<T> {
companion object {
inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> {
if (point.size != Dimension.dim<D>().toInt()) error("Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}")
if (point.size != Dimension.dim<D>().toInt()) {
error("Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}")
}
return DPointWrapper(point)
}
@ -48,6 +63,9 @@ interface DPoint<T, D : Dimension> : Point<T> {
}
}
/**
* Dimension-safe point wrapper
*/
inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) : DPoint<T, D> {
override val size: Int get() = point.size
@ -58,7 +76,7 @@ inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) : DPoint<T, D>
/**
* Basic operations on matrices. Operates on [Matrix]
* Basic operations on dimension-safe matrices. Operates on [Matrix]
*/
inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixContext<T, Ri>) {