forked from kscience/kmath
Examples for type-safe dimensions
This commit is contained in:
parent
a156feb397
commit
4060c70b17
@ -5,7 +5,7 @@ plugins {
|
||||
java
|
||||
kotlin("jvm")
|
||||
kotlin("plugin.allopen") version "1.3.61"
|
||||
id("kotlinx.benchmark") version "0.2.0-dev-5"
|
||||
id("kotlinx.benchmark") version "0.2.0-dev-7"
|
||||
}
|
||||
|
||||
configure<AllOpenExtension> {
|
||||
@ -13,10 +13,7 @@ configure<AllOpenExtension> {
|
||||
}
|
||||
|
||||
repositories {
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||
maven("http://dl.bintray.com/kyonifer/maven")
|
||||
maven ("https://dl.bintray.com/orangy/maven")
|
||||
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
@ -30,12 +27,10 @@ dependencies {
|
||||
implementation(project(":kmath-commons"))
|
||||
implementation(project(":kmath-koma"))
|
||||
implementation(project(":kmath-viktor"))
|
||||
implementation(project(":kmath-dimensions"))
|
||||
implementation("com.kyonifer:koma-core-ejml:0.12")
|
||||
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:${Scientifik.ioVersion}")
|
||||
|
||||
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-2")
|
||||
|
||||
|
||||
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-7")
|
||||
"benchmarksCompile"(sourceSets.main.get().compileClasspath)
|
||||
}
|
||||
|
||||
|
@ -1,28 +1,26 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import org.jetbrains.bio.viktor.F64Array
|
||||
import org.openjdk.jmh.annotations.Benchmark
|
||||
import org.openjdk.jmh.annotations.Scope
|
||||
import org.openjdk.jmh.annotations.State
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.viktor.ViktorNDField
|
||||
|
||||
|
||||
fun main() {
|
||||
val dim = 1000
|
||||
val n = 400
|
||||
@State(Scope.Benchmark)
|
||||
class ViktorBenchmark {
|
||||
final val dim = 1000
|
||||
final val n = 100
|
||||
|
||||
// automatically build context most suited for given type.
|
||||
val autoField = NDField.auto(RealField, dim, dim)
|
||||
val realField = NDField.real(dim,dim)
|
||||
final val autoField = NDField.auto(RealField, dim, dim)
|
||||
final val realField = NDField.real(dim, dim)
|
||||
|
||||
val viktorField = ViktorNDField(intArrayOf(dim, dim))
|
||||
final val viktorField = ViktorNDField(intArrayOf(dim, dim))
|
||||
|
||||
autoField.run {
|
||||
var res = one
|
||||
repeat(n/2) {
|
||||
res += 1.0
|
||||
}
|
||||
}
|
||||
|
||||
measureAndPrint("Automatic field addition") {
|
||||
@Benchmark
|
||||
fun `Automatic field addition`() {
|
||||
autoField.run {
|
||||
var res = one
|
||||
repeat(n) {
|
||||
@ -31,14 +29,8 @@ fun main() {
|
||||
}
|
||||
}
|
||||
|
||||
viktorField.run {
|
||||
var res = one
|
||||
repeat(n/2) {
|
||||
res += one
|
||||
}
|
||||
}
|
||||
|
||||
measureAndPrint("Viktor field addition") {
|
||||
@Benchmark
|
||||
fun `Viktor field addition`() {
|
||||
viktorField.run {
|
||||
var res = one
|
||||
repeat(n) {
|
||||
@ -47,7 +39,8 @@ fun main() {
|
||||
}
|
||||
}
|
||||
|
||||
measureAndPrint("Raw Viktor") {
|
||||
@Benchmark
|
||||
fun `Raw Viktor`() {
|
||||
val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim))
|
||||
var res = one
|
||||
repeat(n) {
|
||||
@ -55,7 +48,8 @@ fun main() {
|
||||
}
|
||||
}
|
||||
|
||||
measureAndPrint("Automatic field log") {
|
||||
@Benchmark
|
||||
fun `Real field log`() {
|
||||
realField.run {
|
||||
val fortyTwo = produce { 42.0 }
|
||||
var res = one
|
||||
@ -66,7 +60,8 @@ fun main() {
|
||||
}
|
||||
}
|
||||
|
||||
measureAndPrint("Raw Viktor log") {
|
||||
@Benchmark
|
||||
fun `Raw Viktor log`() {
|
||||
val fortyTwo = F64Array.full(dim, dim, init = 42.0)
|
||||
var res: F64Array
|
||||
repeat(n) {
|
@ -0,0 +1,8 @@
|
||||
package scientifik.kmath.utils
|
||||
|
||||
import kotlin.system.measureTimeMillis
|
||||
|
||||
internal inline fun measureAndPrint(title: String, block: () -> Unit) {
|
||||
val time = measureTimeMillis(block)
|
||||
println("$title completed in $time millis")
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.dimensions.D2
|
||||
import scientifik.kmath.dimensions.D3
|
||||
import scientifik.kmath.dimensions.DMatrixContext
|
||||
import scientifik.kmath.dimensions.Dimension
|
||||
import scientifik.kmath.operations.RealField
|
||||
|
||||
fun DMatrixContext<Double, RealField>.simple() {
|
||||
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
|
||||
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
|
||||
|
||||
//Dimension-safe addition
|
||||
m1.transpose() + m2
|
||||
}
|
||||
|
||||
object D5: Dimension{
|
||||
override val dim: UInt = 5u
|
||||
}
|
||||
|
||||
fun DMatrixContext<Double, RealField>.custom() {
|
||||
val m1 = produce<D2, D5> { i, j -> (i+j).toDouble() }
|
||||
val m2 = produce<D5,D2> { i, j -> (i-j).toDouble() }
|
||||
val m3 = produce<D2,D2> { i, j -> (i-j).toDouble() }
|
||||
|
||||
(m1 dot m2) + m3
|
||||
}
|
||||
|
||||
fun main() {
|
||||
DMatrixContext.real.run {
|
||||
simple()
|
||||
custom()
|
||||
}
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
plugins {
|
||||
`npm-multiplatform`
|
||||
id("scientifik.mpp")
|
||||
}
|
||||
|
||||
description = "A proof of concept module for adding typ-safe dimensions to structures"
|
||||
|
@ -1,9 +1,14 @@
|
||||
package scientifik.kmath.dimensions
|
||||
|
||||
/**
|
||||
* An interface which is not used in runtime. Designates a size of some structure.
|
||||
* All descendants should be singleton objects.
|
||||
*/
|
||||
interface Dimension {
|
||||
val dim: UInt
|
||||
|
||||
companion object {
|
||||
@Suppress("NOTHING_TO_INLINE")
|
||||
inline fun of(dim: UInt): Dimension {
|
||||
return when (dim) {
|
||||
1u -> D1
|
||||
@ -16,7 +21,7 @@ interface Dimension {
|
||||
}
|
||||
|
||||
inline fun <reified D : Dimension> dim(): UInt {
|
||||
return D::class.objectInstance!!.dim
|
||||
return D::class.objectInstance?.dim ?: error("Dimension object must be a singleton")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -8,14 +8,21 @@ import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import scientifik.kmath.structures.Structure2D
|
||||
|
||||
/**
|
||||
* A matrix with compile-time controlled dimension
|
||||
*/
|
||||
interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
|
||||
companion object {
|
||||
/**
|
||||
* Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed
|
||||
*/
|
||||
inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> {
|
||||
if (structure.rowNum != Dimension.dim<R>().toInt()) error("Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}")
|
||||
if (structure.colNum != Dimension.dim<C>().toInt()) error("Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}")
|
||||
if (structure.rowNum != Dimension.dim<R>().toInt()) {
|
||||
error("Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}")
|
||||
}
|
||||
if (structure.colNum != Dimension.dim<C>().toInt()) {
|
||||
error("Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}")
|
||||
}
|
||||
return DMatrixWrapper(structure)
|
||||
}
|
||||
|
||||
@ -28,6 +35,9 @@ interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* An inline wrapper for a Matrix
|
||||
*/
|
||||
inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
||||
val structure: Structure2D<T>
|
||||
) : DMatrix<T, R, C> {
|
||||
@ -35,10 +45,15 @@ inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
||||
override fun get(i: Int, j: Int): T = structure[i, j]
|
||||
}
|
||||
|
||||
/**
|
||||
* Dimension-safe point
|
||||
*/
|
||||
interface DPoint<T, D : Dimension> : Point<T> {
|
||||
companion object {
|
||||
inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> {
|
||||
if (point.size != Dimension.dim<D>().toInt()) error("Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}")
|
||||
if (point.size != Dimension.dim<D>().toInt()) {
|
||||
error("Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}")
|
||||
}
|
||||
return DPointWrapper(point)
|
||||
}
|
||||
|
||||
@ -48,6 +63,9 @@ interface DPoint<T, D : Dimension> : Point<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Dimension-safe point wrapper
|
||||
*/
|
||||
inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) : DPoint<T, D> {
|
||||
override val size: Int get() = point.size
|
||||
|
||||
@ -58,7 +76,7 @@ inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) : DPoint<T, D>
|
||||
|
||||
|
||||
/**
|
||||
* Basic operations on matrices. Operates on [Matrix]
|
||||
* Basic operations on dimension-safe matrices. Operates on [Matrix]
|
||||
*/
|
||||
inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixContext<T, Ri>) {
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user