Version update + some useful extensions
This commit is contained in:
parent
e5e9367c43
commit
1e70bebba6
@ -1,5 +1,5 @@
|
|||||||
plugins {
|
plugins {
|
||||||
id("scientifik.publish") version "0.1.6" apply false
|
id("scientifik.publish") version "0.2.5" apply false
|
||||||
}
|
}
|
||||||
|
|
||||||
val kmathVersion by extra("0.1.4-dev")
|
val kmathVersion by extra("0.1.4-dev")
|
||||||
|
@ -4,8 +4,8 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
|
|||||||
plugins {
|
plugins {
|
||||||
java
|
java
|
||||||
kotlin("jvm")
|
kotlin("jvm")
|
||||||
kotlin("plugin.allopen") version "1.3.41"
|
kotlin("plugin.allopen") version "1.3.60"
|
||||||
id("kotlinx.benchmark") version "0.2.0-dev-2"
|
id("kotlinx.benchmark") version "0.2.0-dev-5"
|
||||||
}
|
}
|
||||||
|
|
||||||
configure<AllOpenExtension> {
|
configure<AllOpenExtension> {
|
||||||
@ -59,6 +59,6 @@ benchmark {
|
|||||||
|
|
||||||
tasks.withType<KotlinCompile> {
|
tasks.withType<KotlinCompile> {
|
||||||
kotlinOptions {
|
kotlinOptions {
|
||||||
jvmTarget = "1.8"
|
jvmTarget = Scientifik.JVM_VERSION
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -2,17 +2,17 @@ package scientifik.kmath.commons.prob
|
|||||||
|
|
||||||
import kotlinx.coroutines.runBlocking
|
import kotlinx.coroutines.runBlocking
|
||||||
import scientifik.kmath.chains.Chain
|
import scientifik.kmath.chains.Chain
|
||||||
import scientifik.kmath.chains.mapWithState
|
import scientifik.kmath.chains.collectWithState
|
||||||
import scientifik.kmath.prob.Distribution
|
import scientifik.kmath.prob.Distribution
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
|
||||||
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
||||||
|
|
||||||
fun Chain<Double>.mean(): Chain<Double> = mapWithState(AveragingChainState(),{it.copy()}){chain->
|
fun Chain<Double>.mean(): Chain<Double> = collectWithState(AveragingChainState(),{it.copy()}){ chain->
|
||||||
val next = chain.next()
|
val next = chain.next()
|
||||||
num++
|
num++
|
||||||
value += next
|
value += next
|
||||||
return@mapWithState value / num
|
return@collectWithState value / num
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,7 +2,8 @@ package scientifik.kmath.structures
|
|||||||
|
|
||||||
import scientifik.kmath.linear.transpose
|
import scientifik.kmath.linear.transpose
|
||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.operations.toComplex
|
import scientifik.kmath.operations.ComplexField
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
@ -39,19 +40,21 @@ fun main() {
|
|||||||
|
|
||||||
fun complexExample() {
|
fun complexExample() {
|
||||||
//Create a context for 2-d structure with complex values
|
//Create a context for 2-d structure with complex values
|
||||||
NDField.complex(4, 8).run {
|
ComplexField {
|
||||||
//a constant real-valued structure
|
nd(4, 8) {
|
||||||
val x = one * 2.5
|
//a constant real-valued structure
|
||||||
operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im)
|
val x = one * 2.5
|
||||||
//a structure generator specific to this context
|
operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im)
|
||||||
val matrix = produce { (k, l) ->
|
//a structure generator specific to this context
|
||||||
k + l*i
|
val matrix = produce { (k, l) ->
|
||||||
|
k + l * i
|
||||||
|
}
|
||||||
|
|
||||||
|
//Perform sum
|
||||||
|
val sum = matrix + x + 1.0
|
||||||
|
|
||||||
|
//Represent the sum as 2d-structure and transpose
|
||||||
|
sum.as2D().transpose()
|
||||||
}
|
}
|
||||||
|
|
||||||
//Perform sum
|
|
||||||
val sum = matrix + x + 1.0
|
|
||||||
|
|
||||||
//Represent the sum as 2d-structure and transpose
|
|
||||||
sum.as2D().transpose()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,31 +0,0 @@
|
|||||||
apply plugin: "com.jfrog.artifactory"
|
|
||||||
|
|
||||||
artifactory {
|
|
||||||
def artifactory_user = project.hasProperty('artifactoryUser') ? project.property('artifactoryUser') : ""
|
|
||||||
def artifactory_password = project.hasProperty('artifactoryPassword') ? project.property('artifactoryPassword') : ""
|
|
||||||
def artifactory_contextUrl = 'http://npm.mipt.ru:8081/artifactory'
|
|
||||||
|
|
||||||
contextUrl = artifactory_contextUrl //The base Artifactory URL if not overridden by the publisher/resolver
|
|
||||||
publish {
|
|
||||||
repository {
|
|
||||||
repoKey = 'gradle-dev-local'
|
|
||||||
username = artifactory_user
|
|
||||||
password = artifactory_password
|
|
||||||
}
|
|
||||||
|
|
||||||
defaults {
|
|
||||||
publications('jvm', 'js', 'kotlinMultiplatform', 'metadata')
|
|
||||||
publishBuildInfo = false
|
|
||||||
publishArtifacts = true
|
|
||||||
publishPom = true
|
|
||||||
publishIvy = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
resolve {
|
|
||||||
repository {
|
|
||||||
repoKey = 'gradle-dev'
|
|
||||||
username = artifactory_user
|
|
||||||
password = artifactory_password
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,85 +0,0 @@
|
|||||||
apply plugin: 'com.jfrog.bintray'
|
|
||||||
|
|
||||||
def vcs = "https://github.com/mipt-npm/kmath"
|
|
||||||
|
|
||||||
def pomConfig = {
|
|
||||||
licenses {
|
|
||||||
license {
|
|
||||||
name "The Apache Software License, Version 2.0"
|
|
||||||
url "http://www.apache.org/licenses/LICENSE-2.0.txt"
|
|
||||||
distribution "repo"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
developers {
|
|
||||||
developer {
|
|
||||||
id "MIPT-NPM"
|
|
||||||
name "MIPT nuclear physics methods laboratory"
|
|
||||||
organization "MIPT"
|
|
||||||
organizationUrl "http://npm.mipt.ru"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
scm {
|
|
||||||
url vcs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
project.ext.configureMavenCentralMetadata = { pom ->
|
|
||||||
def root = asNode()
|
|
||||||
root.appendNode('name', project.name)
|
|
||||||
root.appendNode('description', project.description)
|
|
||||||
root.appendNode('url', vcs)
|
|
||||||
root.children().last() + pomConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
project.ext.configurePom = pomConfig
|
|
||||||
|
|
||||||
|
|
||||||
// Configure publishing
|
|
||||||
publishing {
|
|
||||||
repositories {
|
|
||||||
maven {
|
|
||||||
url = "https://bintray.com/mipt-npm/scientifik"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process each publication we have in this project
|
|
||||||
publications.all { publication ->
|
|
||||||
// apply changes to pom.xml files, see pom.gradle
|
|
||||||
pom.withXml(configureMavenCentralMetadata)
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bintray {
|
|
||||||
user = project.hasProperty('bintrayUser') ? project.property('bintrayUser') : System.getenv('BINTRAY_USER')
|
|
||||||
key = project.hasProperty('bintrayApiKey') ? project.property('bintrayApiKey') : System.getenv('BINTRAY_API_KEY')
|
|
||||||
publish = true
|
|
||||||
override = true // for multi-platform Kotlin/Native publishing
|
|
||||||
|
|
||||||
pkg {
|
|
||||||
userOrg = "mipt-npm"
|
|
||||||
repo = "scientifik"
|
|
||||||
name = "scientifik.kmath"
|
|
||||||
issueTrackerUrl = "https://github.com/mipt-npm/kmath/issues"
|
|
||||||
licenses = ['Apache-2.0']
|
|
||||||
vcsUrl = vcs
|
|
||||||
version {
|
|
||||||
name = project.version
|
|
||||||
vcsTag = project.version
|
|
||||||
released = new Date()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bintrayUpload.dependsOn publishToMavenLocal
|
|
||||||
|
|
||||||
// This is for easier debugging of bintray uploading problems
|
|
||||||
bintrayUpload.doFirst {
|
|
||||||
publications = project.publishing.publications.findAll {
|
|
||||||
!it.name.contains('-test') && it.name != 'kotlinMultiplatform'
|
|
||||||
}.collect {
|
|
||||||
println("Uploading artifact '$it.groupId:$it.artifactId:$it.version' from publication '$it.name'")
|
|
||||||
it.name//https://github.com/bintray/gradle-bintray-plugin/issues/256
|
|
||||||
}
|
|
||||||
}
|
|
BIN
gradle/wrapper/gradle-wrapper.jar
vendored
BIN
gradle/wrapper/gradle-wrapper.jar
vendored
Binary file not shown.
2
gradle/wrapper/gradle-wrapper.properties
vendored
2
gradle/wrapper/gradle-wrapper.properties
vendored
@ -1,5 +1,5 @@
|
|||||||
distributionBase=GRADLE_USER_HOME
|
distributionBase=GRADLE_USER_HOME
|
||||||
distributionPath=wrapper/dists
|
distributionPath=wrapper/dists
|
||||||
distributionUrl=https\://services.gradle.org/distributions/gradle-5.6-bin.zip
|
distributionUrl=https\://services.gradle.org/distributions/gradle-6.0-bin.zip
|
||||||
zipStoreBase=GRADLE_USER_HOME
|
zipStoreBase=GRADLE_USER_HOME
|
||||||
zipStorePath=wrapper/dists
|
zipStorePath=wrapper/dists
|
||||||
|
29
gradlew
vendored
29
gradlew
vendored
@ -154,19 +154,19 @@ if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
|
|||||||
else
|
else
|
||||||
eval `echo args$i`="\"$arg\""
|
eval `echo args$i`="\"$arg\""
|
||||||
fi
|
fi
|
||||||
i=$((i+1))
|
i=`expr $i + 1`
|
||||||
done
|
done
|
||||||
case $i in
|
case $i in
|
||||||
(0) set -- ;;
|
0) set -- ;;
|
||||||
(1) set -- "$args0" ;;
|
1) set -- "$args0" ;;
|
||||||
(2) set -- "$args0" "$args1" ;;
|
2) set -- "$args0" "$args1" ;;
|
||||||
(3) set -- "$args0" "$args1" "$args2" ;;
|
3) set -- "$args0" "$args1" "$args2" ;;
|
||||||
(4) set -- "$args0" "$args1" "$args2" "$args3" ;;
|
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
|
||||||
(5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
|
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
|
||||||
(6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
|
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
|
||||||
(7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
|
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
|
||||||
(8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
|
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
|
||||||
(9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
|
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
|
||||||
esac
|
esac
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@ -175,14 +175,9 @@ save () {
|
|||||||
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
|
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
|
||||||
echo " "
|
echo " "
|
||||||
}
|
}
|
||||||
APP_ARGS=$(save "$@")
|
APP_ARGS=`save "$@"`
|
||||||
|
|
||||||
# Collect all arguments for the java command, following the shell quoting and substitution rules
|
# Collect all arguments for the java command, following the shell quoting and substitution rules
|
||||||
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
|
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
|
||||||
|
|
||||||
# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
|
|
||||||
if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
|
|
||||||
cd "$(dirname "$0")"
|
|
||||||
fi
|
|
||||||
|
|
||||||
exec "$JAVACMD" "$@"
|
exec "$JAVACMD" "$@"
|
||||||
|
@ -5,6 +5,8 @@ package scientifik.kmath.operations
|
|||||||
*/
|
*/
|
||||||
interface Algebra
|
interface Algebra
|
||||||
|
|
||||||
|
inline operator fun <T : Algebra, R> T.invoke(block: T.() -> R): R = run(block)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Space-like operations without neutral element
|
* Space-like operations without neutral element
|
||||||
*/
|
*/
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package scientifik.kmath.operations
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow as kpow
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Advanced Number-like field that implements basic operations
|
* Advanced Number-like field that implements basic operations
|
||||||
@ -45,7 +45,7 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
|||||||
override inline fun sin(arg: Double) = kotlin.math.sin(arg)
|
override inline fun sin(arg: Double) = kotlin.math.sin(arg)
|
||||||
override inline fun cos(arg: Double) = kotlin.math.cos(arg)
|
override inline fun cos(arg: Double) = kotlin.math.cos(arg)
|
||||||
|
|
||||||
override inline fun power(arg: Double, pow: Number) = arg.pow(pow.toDouble())
|
override inline fun power(arg: Double, pow: Number) = arg.kpow(pow.toDouble())
|
||||||
|
|
||||||
override inline fun exp(arg: Double) = kotlin.math.exp(arg)
|
override inline fun exp(arg: Double) = kotlin.math.exp(arg)
|
||||||
override inline fun ln(arg: Double) = kotlin.math.ln(arg)
|
override inline fun ln(arg: Double) = kotlin.math.ln(arg)
|
||||||
|
@ -70,4 +70,13 @@ class BoxingNDField<T, F : Field<T>>(
|
|||||||
|
|
||||||
override fun NDBuffer<T>.toElement(): FieldElement<NDBuffer<T>, *, out BufferedNDField<T, F>> =
|
override fun NDBuffer<T>.toElement(): FieldElement<NDBuffer<T>, *, out BufferedNDField<T, F>> =
|
||||||
BufferedNDFieldElement(this@BoxingNDField, buffer)
|
BufferedNDFieldElement(this@BoxingNDField, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun <T : Any, F : Field<T>, R> F.nd(
|
||||||
|
noinline bufferFactory: BufferFactory<T>,
|
||||||
|
vararg shape: Int,
|
||||||
|
action: NDField<T, F, *>.() -> R
|
||||||
|
): R {
|
||||||
|
val ndfield: BoxingNDField<T, F> = NDField.boxing(this, *shape, bufferFactory = bufferFactory)
|
||||||
|
return ndfield.action()
|
||||||
}
|
}
|
@ -134,4 +134,11 @@ operator fun ComplexNDElement.minus(arg: Double) =
|
|||||||
fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape)
|
fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape)
|
||||||
|
|
||||||
fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(IntArray) -> Complex): ComplexNDElement =
|
fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(IntArray) -> Complex): ComplexNDElement =
|
||||||
NDField.complex(*shape).produce(initializer)
|
NDField.complex(*shape).produce(initializer)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Produce a context for n-dimensional operations inside this real field
|
||||||
|
*/
|
||||||
|
inline fun <R> ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R {
|
||||||
|
return NDField.complex(*shape).run(action)
|
||||||
|
}
|
@ -119,3 +119,10 @@ operator fun RealNDElement.plus(arg: Double) =
|
|||||||
*/
|
*/
|
||||||
operator fun RealNDElement.minus(arg: Double) =
|
operator fun RealNDElement.minus(arg: Double) =
|
||||||
map { it - arg }
|
map { it - arg }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Produce a context for n-dimensional operations inside this real field
|
||||||
|
*/
|
||||||
|
inline fun <R> RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R {
|
||||||
|
return NDField.real(*shape).run(action)
|
||||||
|
}
|
@ -6,7 +6,7 @@ import kotlin.test.assertEquals
|
|||||||
class RealFieldTest {
|
class RealFieldTest {
|
||||||
@Test
|
@Test
|
||||||
fun testSqrt() {
|
fun testSqrt() {
|
||||||
val sqrt = with(RealField) {
|
val sqrt = RealField {
|
||||||
sqrt(25 * one)
|
sqrt(25 * one)
|
||||||
}
|
}
|
||||||
assertEquals(5.0, sqrt)
|
assertEquals(5.0, sqrt)
|
||||||
|
@ -122,23 +122,23 @@ class ConstantChain<out T>(val value: T) : Chain<T> {
|
|||||||
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
||||||
* since mapped chain consumes tokens. Accepts regular transformation function
|
* since mapped chain consumes tokens. Accepts regular transformation function
|
||||||
*/
|
*/
|
||||||
fun <T, R> Chain<T>.pipe(func: suspend (T) -> R): Chain<R> = object : Chain<R> {
|
fun <T, R> Chain<T>.map(func: suspend (T) -> R): Chain<R> = object : Chain<R> {
|
||||||
override suspend fun next(): R = func(this@pipe.next())
|
override suspend fun next(): R = func(this@map.next())
|
||||||
override fun fork(): Chain<R> = this@pipe.fork().pipe(func)
|
override fun fork(): Chain<R> = this@map.fork().map(func)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map the whole chain
|
* Map the whole chain
|
||||||
*/
|
*/
|
||||||
fun <T, R> Chain<T>.map(mapper: suspend (Chain<T>) -> R): Chain<R> = object : Chain<R> {
|
fun <T, R> Chain<T>.collect(mapper: suspend (Chain<T>) -> R): Chain<R> = object : Chain<R> {
|
||||||
override suspend fun next(): R = mapper(this@map)
|
override suspend fun next(): R = mapper(this@collect)
|
||||||
override fun fork(): Chain<R> = this@map.fork().map(mapper)
|
override fun fork(): Chain<R> = this@collect.fork().collect(mapper)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T, S, R> Chain<T>.mapWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain<T>) -> R): Chain<R> =
|
fun <T, S, R> Chain<T>.collectWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain<T>) -> R): Chain<R> =
|
||||||
object : Chain<R> {
|
object : Chain<R> {
|
||||||
override suspend fun next(): R = state.mapper(this@mapWithState)
|
override suspend fun next(): R = state.mapper(this@collectWithState)
|
||||||
override fun fork(): Chain<R> = this@mapWithState.fork().mapWithState(stateFork(state), stateFork, mapper)
|
override fun fork(): Chain<R> = this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
package scientifik.kmath.prob
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
import scientifik.kmath.chains.Chain
|
||||||
import scientifik.kmath.chains.map
|
import scientifik.kmath.chains.collect
|
||||||
import kotlin.jvm.JvmName
|
import scientifik.kmath.structures.Buffer
|
||||||
|
import scientifik.kmath.structures.BufferFactory
|
||||||
|
|
||||||
interface Sampler<T : Any> {
|
interface Sampler<T : Any> {
|
||||||
fun sample(generator: RandomGenerator): Chain<T>
|
fun sample(generator: RandomGenerator): Chain<T>
|
||||||
@ -45,24 +46,27 @@ fun <T : Comparable<T>> UnivariateDistribution<T>.integral(from: T, to: T): Doub
|
|||||||
return cumulative(to) - cumulative(from)
|
return cumulative(to) - cumulative(from)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sample a bunch of values
|
* Sample a bunch of values
|
||||||
*/
|
*/
|
||||||
fun <T : Any> Sampler<T>.sampleBunch(generator: RandomGenerator, size: Int): Chain<List<T>> {
|
fun <T : Any> Sampler<T>.sampleBuffer(
|
||||||
|
generator: RandomGenerator,
|
||||||
|
size: Int,
|
||||||
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||||
|
): Chain<Buffer<T>> {
|
||||||
require(size > 1)
|
require(size > 1)
|
||||||
return sample(generator).map{chain ->
|
//creating temporary storage once
|
||||||
List(size){chain.next()}
|
val tmp = ArrayList<T>(size)
|
||||||
|
return sample(generator).collect { chain ->
|
||||||
|
for (i in tmp.indices) {
|
||||||
|
tmp[i] = chain.next()
|
||||||
|
}
|
||||||
|
bufferFactory(size) { tmp[it] }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate a bunch of samples from real distributions
|
* Generate a bunch of samples from real distributions
|
||||||
*/
|
*/
|
||||||
@JvmName("realSampleBunch")
|
fun Sampler<Double>.sampleBuffer(generator: RandomGenerator, size: Int) =
|
||||||
fun Sampler<Double>.sampleBunch(generator: RandomGenerator, size: Int): Chain<DoubleArray> {
|
sampleBuffer(generator, size, Buffer.Companion::real)
|
||||||
require(size > 1)
|
|
||||||
return sample(generator).map{chain ->
|
|
||||||
DoubleArray(size){chain.next()}
|
|
||||||
}
|
|
||||||
}
|
|
@ -23,4 +23,25 @@ class FactorizedDistribution<T>(val distributions: Collection<NamedDistribution<
|
|||||||
chains.fold(emptyMap()) { acc, chain -> acc + chain.next() }
|
chains.fold(emptyMap()) { acc, chain -> acc + chain.next() }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class NamedDistributionWrapper<T : Any>(val name: String, val distribution: Distribution<T>) : NamedDistribution<T> {
|
||||||
|
override fun probability(arg: Map<String, T>): Double = distribution.probability(
|
||||||
|
arg[name] ?: error("Argument with name $name not found in input parameters")
|
||||||
|
)
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<Map<String, T>> {
|
||||||
|
val chain = distribution.sample(generator)
|
||||||
|
return SimpleChain {
|
||||||
|
mapOf(name to chain.next())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class DistributionBuilder<T: Any>{
|
||||||
|
private val distributions = ArrayList<NamedDistribution<T>>()
|
||||||
|
|
||||||
|
infix fun String.to(distribution: Distribution<T>){
|
||||||
|
distributions.add(NamedDistributionWrapper(this,distribution))
|
||||||
|
}
|
||||||
}
|
}
|
@ -2,7 +2,7 @@ package scientifik.kmath.prob
|
|||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
import scientifik.kmath.chains.Chain
|
||||||
import scientifik.kmath.chains.ConstantChain
|
import scientifik.kmath.chains.ConstantChain
|
||||||
import scientifik.kmath.chains.pipe
|
import scientifik.kmath.chains.map
|
||||||
import scientifik.kmath.chains.zip
|
import scientifik.kmath.chains.zip
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
|
|
||||||
@ -26,6 +26,6 @@ class SamplerSpace<T : Any>(val space: Space<T>) : Space<Sampler<T>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Sampler<T>, k: Number): Sampler<T> = BasicSampler { generator ->
|
override fun multiply(a: Sampler<T>, k: Number): Sampler<T> = BasicSampler { generator ->
|
||||||
a.sample(generator).pipe { space.run { it * k.toDouble() } }
|
a.sample(generator).map { space.run { it * k.toDouble() } }
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -21,10 +21,16 @@ interface Statistic<T, R> {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* A statistic tha could be computed separately on different blocks of data and then composed
|
* A statistic tha could be computed separately on different blocks of data and then composed
|
||||||
|
* @param T - source type
|
||||||
|
* @param I - intermediate block type
|
||||||
|
* @param R - result type
|
||||||
*/
|
*/
|
||||||
interface ComposableStatistic<T, I, R> : Statistic<T, R> {
|
interface ComposableStatistic<T, I, R> : Statistic<T, R> {
|
||||||
|
//compute statistic on a single block
|
||||||
suspend fun computeIntermediate(data: Buffer<T>): I
|
suspend fun computeIntermediate(data: Buffer<T>): I
|
||||||
|
//Compose two blocks
|
||||||
suspend fun composeIntermediate(first: I, second: I): I
|
suspend fun composeIntermediate(first: I, second: I): I
|
||||||
|
//Transform block to result
|
||||||
suspend fun toResult(intermediate: I): R
|
suspend fun toResult(intermediate: I): R
|
||||||
|
|
||||||
override suspend fun invoke(data: Buffer<T>): R = toResult(computeIntermediate(data))
|
override suspend fun invoke(data: Buffer<T>): R = toResult(computeIntermediate(data))
|
||||||
@ -32,7 +38,7 @@ interface ComposableStatistic<T, I, R> : Statistic<T, R> {
|
|||||||
|
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
@ExperimentalCoroutinesApi
|
@ExperimentalCoroutinesApi
|
||||||
fun <T, I, R> ComposableStatistic<T, I, R>.flowIntermediate(
|
private fun <T, I, R> ComposableStatistic<T, I, R>.flowIntermediate(
|
||||||
flow: Flow<Buffer<T>>,
|
flow: Flow<Buffer<T>>,
|
||||||
dispatcher: CoroutineDispatcher = Dispatchers.Default
|
dispatcher: CoroutineDispatcher = Dispatchers.Default
|
||||||
): Flow<I> = flow
|
): Flow<I> = flow
|
||||||
|
@ -0,0 +1,31 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.chains.SimpleChain
|
||||||
|
|
||||||
|
class UniformDistribution(val range: ClosedFloatingPointRange<Double>) : UnivariateDistribution<Double> {
|
||||||
|
|
||||||
|
private val length = range.endInclusive - range.start
|
||||||
|
|
||||||
|
override fun probability(arg: Double): Double {
|
||||||
|
return if (arg in range) {
|
||||||
|
return 1.0 / length
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<Double> {
|
||||||
|
return SimpleChain {
|
||||||
|
range.start + generator.nextDouble() * length
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun cumulative(arg: Double): Double {
|
||||||
|
return when {
|
||||||
|
arg < range.start -> 0.0
|
||||||
|
arg >= range.endInclusive -> 1.0
|
||||||
|
else -> (arg - range.start) / length
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,10 +1,10 @@
|
|||||||
pluginManagement {
|
pluginManagement {
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
id("scientifik.mpp") version "0.1.6"
|
id("scientifik.mpp") version "0.2.5"
|
||||||
id("scientifik.jvm") version "0.1.6"
|
id("scientifik.jvm") version "0.2.5"
|
||||||
id("scientifik.atomic") version "0.1.6"
|
id("scientifik.atomic") version "0.2.5"
|
||||||
id("scientifik.publish") version "0.1.6"
|
id("scientifik.publish") version "0.2.5"
|
||||||
}
|
}
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
@ -25,8 +25,6 @@ pluginManagement {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
enableFeaturePreview("GRADLE_METADATA")
|
|
||||||
|
|
||||||
rootProject.name = "kmath"
|
rootProject.name = "kmath"
|
||||||
include(
|
include(
|
||||||
":kmath-memory",
|
":kmath-memory",
|
||||||
|
Loading…
Reference in New Issue
Block a user