Examples for type-safe dimensions

This commit is contained in:
Alexander Nozik 2019-12-09 19:52:00 +03:00
parent a156feb397
commit 4060c70b17
7 changed files with 96 additions and 41 deletions

View File

@ -5,7 +5,7 @@ plugins {
java java
kotlin("jvm") kotlin("jvm")
kotlin("plugin.allopen") version "1.3.61" kotlin("plugin.allopen") version "1.3.61"
id("kotlinx.benchmark") version "0.2.0-dev-5" id("kotlinx.benchmark") version "0.2.0-dev-7"
} }
configure<AllOpenExtension> { configure<AllOpenExtension> {
@ -13,10 +13,7 @@ configure<AllOpenExtension> {
} }
repositories { repositories {
maven("https://dl.bintray.com/kotlin/kotlin-eap")
maven("http://dl.bintray.com/kyonifer/maven") maven("http://dl.bintray.com/kyonifer/maven")
maven ("https://dl.bintray.com/orangy/maven")
mavenCentral() mavenCentral()
} }
@ -30,12 +27,10 @@ dependencies {
implementation(project(":kmath-commons")) implementation(project(":kmath-commons"))
implementation(project(":kmath-koma")) implementation(project(":kmath-koma"))
implementation(project(":kmath-viktor")) implementation(project(":kmath-viktor"))
implementation(project(":kmath-dimensions"))
implementation("com.kyonifer:koma-core-ejml:0.12") implementation("com.kyonifer:koma-core-ejml:0.12")
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:${Scientifik.ioVersion}") implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:${Scientifik.ioVersion}")
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-7")
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-2")
"benchmarksCompile"(sourceSets.main.get().compileClasspath) "benchmarksCompile"(sourceSets.main.get().compileClasspath)
} }

View File

@ -1,28 +1,26 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import org.jetbrains.bio.viktor.F64Array import org.jetbrains.bio.viktor.F64Array
import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.viktor.ViktorNDField import scientifik.kmath.viktor.ViktorNDField
fun main() { @State(Scope.Benchmark)
val dim = 1000 class ViktorBenchmark {
val n = 400 final val dim = 1000
final val n = 100
// automatically build context most suited for given type. // automatically build context most suited for given type.
val autoField = NDField.auto(RealField, dim, dim) final val autoField = NDField.auto(RealField, dim, dim)
val realField = NDField.real(dim,dim) final val realField = NDField.real(dim, dim)
val viktorField = ViktorNDField(intArrayOf(dim, dim)) final val viktorField = ViktorNDField(intArrayOf(dim, dim))
autoField.run { @Benchmark
var res = one fun `Automatic field addition`() {
repeat(n/2) {
res += 1.0
}
}
measureAndPrint("Automatic field addition") {
autoField.run { autoField.run {
var res = one var res = one
repeat(n) { repeat(n) {
@ -31,14 +29,8 @@ fun main() {
} }
} }
viktorField.run { @Benchmark
var res = one fun `Viktor field addition`() {
repeat(n/2) {
res += one
}
}
measureAndPrint("Viktor field addition") {
viktorField.run { viktorField.run {
var res = one var res = one
repeat(n) { repeat(n) {
@ -47,7 +39,8 @@ fun main() {
} }
} }
measureAndPrint("Raw Viktor") { @Benchmark
fun `Raw Viktor`() {
val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim)) val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim))
var res = one var res = one
repeat(n) { repeat(n) {
@ -55,7 +48,8 @@ fun main() {
} }
} }
measureAndPrint("Automatic field log") { @Benchmark
fun `Real field log`() {
realField.run { realField.run {
val fortyTwo = produce { 42.0 } val fortyTwo = produce { 42.0 }
var res = one var res = one
@ -66,7 +60,8 @@ fun main() {
} }
} }
measureAndPrint("Raw Viktor log") { @Benchmark
fun `Raw Viktor log`() {
val fortyTwo = F64Array.full(dim, dim, init = 42.0) val fortyTwo = F64Array.full(dim, dim, init = 42.0)
var res: F64Array var res: F64Array
repeat(n) { repeat(n) {

View File

@ -0,0 +1,8 @@
package scientifik.kmath.utils
import kotlin.system.measureTimeMillis
internal inline fun measureAndPrint(title: String, block: () -> Unit) {
val time = measureTimeMillis(block)
println("$title completed in $time millis")
}

View File

@ -0,0 +1,34 @@
package scientifik.kmath.structures
import scientifik.kmath.dimensions.D2
import scientifik.kmath.dimensions.D3
import scientifik.kmath.dimensions.DMatrixContext
import scientifik.kmath.dimensions.Dimension
import scientifik.kmath.operations.RealField
fun DMatrixContext<Double, RealField>.simple() {
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
//Dimension-safe addition
m1.transpose() + m2
}
object D5: Dimension{
override val dim: UInt = 5u
}
fun DMatrixContext<Double, RealField>.custom() {
val m1 = produce<D2, D5> { i, j -> (i+j).toDouble() }
val m2 = produce<D5,D2> { i, j -> (i-j).toDouble() }
val m3 = produce<D2,D2> { i, j -> (i-j).toDouble() }
(m1 dot m2) + m3
}
fun main() {
DMatrixContext.real.run {
simple()
custom()
}
}

View File

@ -1,5 +1,5 @@
plugins { plugins {
`npm-multiplatform` id("scientifik.mpp")
} }
description = "A proof of concept module for adding typ-safe dimensions to structures" description = "A proof of concept module for adding typ-safe dimensions to structures"

View File

@ -1,9 +1,14 @@
package scientifik.kmath.dimensions package scientifik.kmath.dimensions
/**
* An interface which is not used in runtime. Designates a size of some structure.
* All descendants should be singleton objects.
*/
interface Dimension { interface Dimension {
val dim: UInt val dim: UInt
companion object { companion object {
@Suppress("NOTHING_TO_INLINE")
inline fun of(dim: UInt): Dimension { inline fun of(dim: UInt): Dimension {
return when (dim) { return when (dim) {
1u -> D1 1u -> D1
@ -15,8 +20,8 @@ interface Dimension {
} }
} }
inline fun <reified D: Dimension> dim(): UInt{ inline fun <reified D : Dimension> dim(): UInt {
return D::class.objectInstance!!.dim return D::class.objectInstance?.dim ?: error("Dimension object must be a singleton")
} }
} }
} }

View File

@ -8,14 +8,21 @@ import scientifik.kmath.operations.Ring
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.Structure2D import scientifik.kmath.structures.Structure2D
/**
* A matrix with compile-time controlled dimension
*/
interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> { interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
companion object { companion object {
/** /**
* Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed * Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed
*/ */
inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> { inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> {
if (structure.rowNum != Dimension.dim<R>().toInt()) error("Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}") if (structure.rowNum != Dimension.dim<R>().toInt()) {
if (structure.colNum != Dimension.dim<C>().toInt()) error("Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}") error("Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}")
}
if (structure.colNum != Dimension.dim<C>().toInt()) {
error("Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}")
}
return DMatrixWrapper(structure) return DMatrixWrapper(structure)
} }
@ -28,6 +35,9 @@ interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
} }
} }
/**
* An inline wrapper for a Matrix
*/
inline class DMatrixWrapper<T, R : Dimension, C : Dimension>( inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
val structure: Structure2D<T> val structure: Structure2D<T>
) : DMatrix<T, R, C> { ) : DMatrix<T, R, C> {
@ -35,10 +45,15 @@ inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
override fun get(i: Int, j: Int): T = structure[i, j] override fun get(i: Int, j: Int): T = structure[i, j]
} }
/**
* Dimension-safe point
*/
interface DPoint<T, D : Dimension> : Point<T> { interface DPoint<T, D : Dimension> : Point<T> {
companion object { companion object {
inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> { inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> {
if (point.size != Dimension.dim<D>().toInt()) error("Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}") if (point.size != Dimension.dim<D>().toInt()) {
error("Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}")
}
return DPointWrapper(point) return DPointWrapper(point)
} }
@ -48,6 +63,9 @@ interface DPoint<T, D : Dimension> : Point<T> {
} }
} }
/**
* Dimension-safe point wrapper
*/
inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) : DPoint<T, D> { inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) : DPoint<T, D> {
override val size: Int get() = point.size override val size: Int get() = point.size
@ -58,7 +76,7 @@ inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) : DPoint<T, D>
/** /**
* Basic operations on matrices. Operates on [Matrix] * Basic operations on dimension-safe matrices. Operates on [Matrix]
*/ */
inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixContext<T, Ri>) { inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixContext<T, Ri>) {