forked from kscience/kmath
Examples for type-safe dimensions
This commit is contained in:
parent
a156feb397
commit
4060c70b17
@ -5,7 +5,7 @@ plugins {
|
|||||||
java
|
java
|
||||||
kotlin("jvm")
|
kotlin("jvm")
|
||||||
kotlin("plugin.allopen") version "1.3.61"
|
kotlin("plugin.allopen") version "1.3.61"
|
||||||
id("kotlinx.benchmark") version "0.2.0-dev-5"
|
id("kotlinx.benchmark") version "0.2.0-dev-7"
|
||||||
}
|
}
|
||||||
|
|
||||||
configure<AllOpenExtension> {
|
configure<AllOpenExtension> {
|
||||||
@ -13,10 +13,7 @@ configure<AllOpenExtension> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
|
||||||
maven("http://dl.bintray.com/kyonifer/maven")
|
maven("http://dl.bintray.com/kyonifer/maven")
|
||||||
maven ("https://dl.bintray.com/orangy/maven")
|
|
||||||
|
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -30,12 +27,10 @@ dependencies {
|
|||||||
implementation(project(":kmath-commons"))
|
implementation(project(":kmath-commons"))
|
||||||
implementation(project(":kmath-koma"))
|
implementation(project(":kmath-koma"))
|
||||||
implementation(project(":kmath-viktor"))
|
implementation(project(":kmath-viktor"))
|
||||||
|
implementation(project(":kmath-dimensions"))
|
||||||
implementation("com.kyonifer:koma-core-ejml:0.12")
|
implementation("com.kyonifer:koma-core-ejml:0.12")
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:${Scientifik.ioVersion}")
|
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:${Scientifik.ioVersion}")
|
||||||
|
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-7")
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-2")
|
|
||||||
|
|
||||||
|
|
||||||
"benchmarksCompile"(sourceSets.main.get().compileClasspath)
|
"benchmarksCompile"(sourceSets.main.get().compileClasspath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,28 +1,26 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import org.jetbrains.bio.viktor.F64Array
|
import org.jetbrains.bio.viktor.F64Array
|
||||||
|
import org.openjdk.jmh.annotations.Benchmark
|
||||||
|
import org.openjdk.jmh.annotations.Scope
|
||||||
|
import org.openjdk.jmh.annotations.State
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.viktor.ViktorNDField
|
import scientifik.kmath.viktor.ViktorNDField
|
||||||
|
|
||||||
|
|
||||||
fun main() {
|
@State(Scope.Benchmark)
|
||||||
val dim = 1000
|
class ViktorBenchmark {
|
||||||
val n = 400
|
final val dim = 1000
|
||||||
|
final val n = 100
|
||||||
|
|
||||||
// automatically build context most suited for given type.
|
// automatically build context most suited for given type.
|
||||||
val autoField = NDField.auto(RealField, dim, dim)
|
final val autoField = NDField.auto(RealField, dim, dim)
|
||||||
val realField = NDField.real(dim,dim)
|
final val realField = NDField.real(dim, dim)
|
||||||
|
|
||||||
val viktorField = ViktorNDField(intArrayOf(dim, dim))
|
final val viktorField = ViktorNDField(intArrayOf(dim, dim))
|
||||||
|
|
||||||
autoField.run {
|
@Benchmark
|
||||||
var res = one
|
fun `Automatic field addition`() {
|
||||||
repeat(n/2) {
|
|
||||||
res += 1.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
measureAndPrint("Automatic field addition") {
|
|
||||||
autoField.run {
|
autoField.run {
|
||||||
var res = one
|
var res = one
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
@ -31,14 +29,8 @@ fun main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
viktorField.run {
|
@Benchmark
|
||||||
var res = one
|
fun `Viktor field addition`() {
|
||||||
repeat(n/2) {
|
|
||||||
res += one
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
measureAndPrint("Viktor field addition") {
|
|
||||||
viktorField.run {
|
viktorField.run {
|
||||||
var res = one
|
var res = one
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
@ -47,7 +39,8 @@ fun main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
measureAndPrint("Raw Viktor") {
|
@Benchmark
|
||||||
|
fun `Raw Viktor`() {
|
||||||
val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim))
|
val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim))
|
||||||
var res = one
|
var res = one
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
@ -55,7 +48,8 @@ fun main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
measureAndPrint("Automatic field log") {
|
@Benchmark
|
||||||
|
fun `Real field log`() {
|
||||||
realField.run {
|
realField.run {
|
||||||
val fortyTwo = produce { 42.0 }
|
val fortyTwo = produce { 42.0 }
|
||||||
var res = one
|
var res = one
|
||||||
@ -66,7 +60,8 @@ fun main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
measureAndPrint("Raw Viktor log") {
|
@Benchmark
|
||||||
|
fun `Raw Viktor log`() {
|
||||||
val fortyTwo = F64Array.full(dim, dim, init = 42.0)
|
val fortyTwo = F64Array.full(dim, dim, init = 42.0)
|
||||||
var res: F64Array
|
var res: F64Array
|
||||||
repeat(n) {
|
repeat(n) {
|
@ -0,0 +1,8 @@
|
|||||||
|
package scientifik.kmath.utils
|
||||||
|
|
||||||
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
|
internal inline fun measureAndPrint(title: String, block: () -> Unit) {
|
||||||
|
val time = measureTimeMillis(block)
|
||||||
|
println("$title completed in $time millis")
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import scientifik.kmath.dimensions.D2
|
||||||
|
import scientifik.kmath.dimensions.D3
|
||||||
|
import scientifik.kmath.dimensions.DMatrixContext
|
||||||
|
import scientifik.kmath.dimensions.Dimension
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
|
||||||
|
fun DMatrixContext<Double, RealField>.simple() {
|
||||||
|
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
|
||||||
|
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
|
||||||
|
|
||||||
|
//Dimension-safe addition
|
||||||
|
m1.transpose() + m2
|
||||||
|
}
|
||||||
|
|
||||||
|
object D5: Dimension{
|
||||||
|
override val dim: UInt = 5u
|
||||||
|
}
|
||||||
|
|
||||||
|
fun DMatrixContext<Double, RealField>.custom() {
|
||||||
|
val m1 = produce<D2, D5> { i, j -> (i+j).toDouble() }
|
||||||
|
val m2 = produce<D5,D2> { i, j -> (i-j).toDouble() }
|
||||||
|
val m3 = produce<D2,D2> { i, j -> (i-j).toDouble() }
|
||||||
|
|
||||||
|
(m1 dot m2) + m3
|
||||||
|
}
|
||||||
|
|
||||||
|
fun main() {
|
||||||
|
DMatrixContext.real.run {
|
||||||
|
simple()
|
||||||
|
custom()
|
||||||
|
}
|
||||||
|
}
|
@ -1,5 +1,5 @@
|
|||||||
plugins {
|
plugins {
|
||||||
`npm-multiplatform`
|
id("scientifik.mpp")
|
||||||
}
|
}
|
||||||
|
|
||||||
description = "A proof of concept module for adding typ-safe dimensions to structures"
|
description = "A proof of concept module for adding typ-safe dimensions to structures"
|
||||||
|
@ -1,9 +1,14 @@
|
|||||||
package scientifik.kmath.dimensions
|
package scientifik.kmath.dimensions
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An interface which is not used in runtime. Designates a size of some structure.
|
||||||
|
* All descendants should be singleton objects.
|
||||||
|
*/
|
||||||
interface Dimension {
|
interface Dimension {
|
||||||
val dim: UInt
|
val dim: UInt
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
@Suppress("NOTHING_TO_INLINE")
|
||||||
inline fun of(dim: UInt): Dimension {
|
inline fun of(dim: UInt): Dimension {
|
||||||
return when (dim) {
|
return when (dim) {
|
||||||
1u -> D1
|
1u -> D1
|
||||||
@ -16,7 +21,7 @@ interface Dimension {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline fun <reified D : Dimension> dim(): UInt {
|
inline fun <reified D : Dimension> dim(): UInt {
|
||||||
return D::class.objectInstance!!.dim
|
return D::class.objectInstance?.dim ?: error("Dimension object must be a singleton")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,14 +8,21 @@ import scientifik.kmath.operations.Ring
|
|||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
import scientifik.kmath.structures.Structure2D
|
import scientifik.kmath.structures.Structure2D
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A matrix with compile-time controlled dimension
|
||||||
|
*/
|
||||||
interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
|
interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
|
||||||
companion object {
|
companion object {
|
||||||
/**
|
/**
|
||||||
* Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed
|
* Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed
|
||||||
*/
|
*/
|
||||||
inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> {
|
inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> {
|
||||||
if (structure.rowNum != Dimension.dim<R>().toInt()) error("Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}")
|
if (structure.rowNum != Dimension.dim<R>().toInt()) {
|
||||||
if (structure.colNum != Dimension.dim<C>().toInt()) error("Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}")
|
error("Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}")
|
||||||
|
}
|
||||||
|
if (structure.colNum != Dimension.dim<C>().toInt()) {
|
||||||
|
error("Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}")
|
||||||
|
}
|
||||||
return DMatrixWrapper(structure)
|
return DMatrixWrapper(structure)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -28,6 +35,9 @@ interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An inline wrapper for a Matrix
|
||||||
|
*/
|
||||||
inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
||||||
val structure: Structure2D<T>
|
val structure: Structure2D<T>
|
||||||
) : DMatrix<T, R, C> {
|
) : DMatrix<T, R, C> {
|
||||||
@ -35,10 +45,15 @@ inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
|||||||
override fun get(i: Int, j: Int): T = structure[i, j]
|
override fun get(i: Int, j: Int): T = structure[i, j]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dimension-safe point
|
||||||
|
*/
|
||||||
interface DPoint<T, D : Dimension> : Point<T> {
|
interface DPoint<T, D : Dimension> : Point<T> {
|
||||||
companion object {
|
companion object {
|
||||||
inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> {
|
inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> {
|
||||||
if (point.size != Dimension.dim<D>().toInt()) error("Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}")
|
if (point.size != Dimension.dim<D>().toInt()) {
|
||||||
|
error("Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}")
|
||||||
|
}
|
||||||
return DPointWrapper(point)
|
return DPointWrapper(point)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,6 +63,9 @@ interface DPoint<T, D : Dimension> : Point<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dimension-safe point wrapper
|
||||||
|
*/
|
||||||
inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) : DPoint<T, D> {
|
inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) : DPoint<T, D> {
|
||||||
override val size: Int get() = point.size
|
override val size: Int get() = point.size
|
||||||
|
|
||||||
@ -58,7 +76,7 @@ inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) : DPoint<T, D>
|
|||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Basic operations on matrices. Operates on [Matrix]
|
* Basic operations on dimension-safe matrices. Operates on [Matrix]
|
||||||
*/
|
*/
|
||||||
inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixContext<T, Ri>) {
|
inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixContext<T, Ri>) {
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user