Merge pull request #137 from mipt-npm/dev

0.1.4
This commit is contained in:
Alexander Nozik 2020-09-14 22:49:29 +03:00 committed by GitHub
commit 95d33c25d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
105 changed files with 1469 additions and 1030 deletions

View File

@ -1,6 +1,18 @@
# KMath # KMath
## [Unreleased] ## [Unreleased]
### Added
### Changed
### Deprecated
### Removed
### Fixed
### Security
## [0.1.4]
### Added ### Added
- Functional Expressions API - Functional Expressions API
@ -16,17 +28,23 @@
- Local coding conventions - Local coding conventions
- Geometric Domains API in `kmath-core` - Geometric Domains API in `kmath-core`
- Blocking chains in `kmath-coroutines` - Blocking chains in `kmath-coroutines`
- Full hyperbolic functions support and default implementations within `ExtendedField`
- Norm support for `Complex`
### Changed ### Changed
- `readAsMemory` now has `throws IOException` in JVM signature.
- Several functions taking functional types were made `inline`.
- Several functions taking functional types now have `callsInPlace` contracts.
- BigInteger and BigDecimal algebra: JBigDecimalField has companion object with default math context; minor optimizations - BigInteger and BigDecimal algebra: JBigDecimalField has companion object with default math context; minor optimizations
- `power(T, Int)` extension function has preconditions and supports `Field<T>` - `power(T, Int)` extension function has preconditions and supports `Field<T>`
- Memory objects have more preconditions (overflow checking) - Memory objects have more preconditions (overflow checking)
- `tg` function is renamed to `tan` (https://github.com/mipt-npm/kmath/pull/114) - `tg` function is renamed to `tan` (https://github.com/mipt-npm/kmath/pull/114)
- Gradle version: 6.3 -> 6.5.1 - Gradle version: 6.3 -> 6.6
- Moved probability distributions to commons-rng and to `kmath-prob`. - Moved probability distributions to commons-rng and to `kmath-prob`
### Fixed ### Fixed
- Missing copy method in Memory implementation on JS (https://github.com/mipt-npm/kmath/pull/106) - Missing copy method in Memory implementation on JS (https://github.com/mipt-npm/kmath/pull/106)
- D3.dim value in `kmath-dimensions` - D3.dim value in `kmath-dimensions`
- Multiplication in integer rings in `kmath-core` (https://github.com/mipt-npm/kmath/pull/101) - Multiplication in integer rings in `kmath-core` (https://github.com/mipt-npm/kmath/pull/101)
- Commons RNG compatibility (https://github.com/mipt-npm/kmath/issues/93) - Commons RNG compatibility (https://github.com/mipt-npm/kmath/issues/93)
- Multiplication of BigInt by scalar

View File

@ -1,8 +1,9 @@
plugins { plugins {
id("scientifik.publish") apply false id("scientifik.publish") apply false
id("org.jetbrains.changelog") version "0.4.0"
} }
val kmathVersion by extra("0.1.4-dev-8") val kmathVersion by extra("0.1.4")
val bintrayRepo by extra("scientifik") val bintrayRepo by extra("scientifik")
val githubProject by extra("kmath") val githubProject by extra("kmath")
@ -14,8 +15,18 @@ allprojects {
maven("https://dl.bintray.com/hotkeytlt/maven") maven("https://dl.bintray.com/hotkeytlt/maven")
} }
group = "scientifik" group = "kscience.kmath"
version = kmathVersion version = kmathVersion
afterEvaluate {
extensions.findByType<org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension>()?.run {
targets.all {
sourceSets.all {
languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts")
}
}
}
}
} }
subprojects { subprojects {

View File

@ -56,9 +56,16 @@ benchmark {
} }
} }
kotlin.sourceSets.all {
with(languageSettings) {
useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts")
useExperimentalAnnotation("kotlin.ExperimentalUnsignedTypes")
}
}
tasks.withType<KotlinCompile> { tasks.withType<KotlinCompile> {
kotlinOptions { kotlinOptions {
jvmTarget = Scientifik.JVM_TARGET.toString() jvmTarget = Scientifik.JVM_TARGET.toString()
freeCompilerArgs = freeCompilerArgs + "-Xopt-in=kotlin.RequiresOptIn"
} }
} }

View File

@ -4,46 +4,38 @@ import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State import org.openjdk.jmh.annotations.State
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
@State(Scope.Benchmark) @State(Scope.Benchmark)
class NDFieldBenchmark { class NDFieldBenchmark {
@Benchmark @Benchmark
fun autoFieldAdd() { fun autoFieldAdd() {
bufferedField.run { bufferedField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) { res += one }
res += one
}
} }
} }
@Benchmark @Benchmark
fun autoElementAdd() { fun autoElementAdd() {
var res = genericField.one var res = genericField.one
repeat(n) { repeat(n) { res += 1.0 }
res += 1.0
}
} }
@Benchmark @Benchmark
fun specializedFieldAdd() { fun specializedFieldAdd() {
specializedField.run { specializedField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) { res += 1.0 }
res += 1.0
}
} }
} }
@Benchmark @Benchmark
fun boxingFieldAdd() { fun boxingFieldAdd() {
genericField.run { genericField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) { res += one }
res += one
}
} }
} }

View File

@ -5,23 +5,22 @@ import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State import org.openjdk.jmh.annotations.State
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import scientifik.kmath.viktor.ViktorNDField import scientifik.kmath.viktor.ViktorNDField
@State(Scope.Benchmark) @State(Scope.Benchmark)
class ViktorBenchmark { class ViktorBenchmark {
final val dim = 1000 final val dim = 1000
final val n = 100 final val n = 100
// automatically build context most suited for given type. // automatically build context most suited for given type.
final val autoField = NDField.auto(RealField, dim, dim) final val autoField: BufferedNDField<Double, RealField> = NDField.auto(RealField, dim, dim)
final val realField = NDField.real(dim, dim) final val realField: RealNDField = NDField.real(dim, dim)
final val viktorField: ViktorNDField = ViktorNDField(intArrayOf(dim, dim))
final val viktorField = ViktorNDField(intArrayOf(dim, dim))
@Benchmark @Benchmark
fun automaticFieldAddition() { fun automaticFieldAddition() {
autoField.run { autoField {
var res = one var res = one
repeat(n) { res += one } repeat(n) { res += one }
} }
@ -29,7 +28,7 @@ class ViktorBenchmark {
@Benchmark @Benchmark
fun viktorFieldAddition() { fun viktorFieldAddition() {
viktorField.run { viktorField {
var res = one var res = one
repeat(n) { res += one } repeat(n) { res += one }
} }
@ -44,7 +43,7 @@ class ViktorBenchmark {
@Benchmark @Benchmark
fun realdFieldLog() { fun realdFieldLog() {
realField.run { realField {
val fortyTwo = produce { 42.0 } val fortyTwo = produce { 42.0 }
var res = one var res = one
repeat(n) { res = ln(fortyTwo) } repeat(n) { res = ln(fortyTwo) }

View File

@ -1,8 +1,11 @@
package scientifik.kmath.utils package scientifik.kmath.utils
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
internal inline fun measureAndPrint(title: String, block: () -> Unit) { internal inline fun measureAndPrint(title: String, block: () -> Unit) {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
val time = measureTimeMillis(block) val time = measureTimeMillis(block)
println("$title completed in $time millis") println("$title completed in $time millis")
} }

View File

@ -5,6 +5,7 @@ import scientifik.kmath.commons.linear.CMMatrixContext
import scientifik.kmath.commons.linear.inverse import scientifik.kmath.commons.linear.inverse
import scientifik.kmath.commons.linear.toCM import scientifik.kmath.commons.linear.toCM
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import kotlin.contracts.ExperimentalContracts import kotlin.contracts.ExperimentalContracts
import kotlin.random.Random import kotlin.random.Random
@ -21,29 +22,18 @@ fun main() {
val n = 5000 // iterations val n = 5000 // iterations
MatrixContext.real.run { MatrixContext.real {
repeat(50) { val res = inverse(matrix) }
repeat(50) { val inverseTime = measureTimeMillis { repeat(n) { val res = inverse(matrix) } }
val res = inverse(matrix)
}
val inverseTime = measureTimeMillis {
repeat(n) {
val res = inverse(matrix)
}
}
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis") println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
} }
//commons-math //commons-math
val commonsTime = measureTimeMillis { val commonsTime = measureTimeMillis {
CMMatrixContext.run { CMMatrixContext {
val cm = matrix.toCM() //avoid overhead on conversion val cm = matrix.toCM() //avoid overhead on conversion
repeat(n) { repeat(n) { val res = inverse(cm) }
val res = inverse(cm)
}
} }
} }
@ -53,7 +43,7 @@ fun main() {
//koma-ejml //koma-ejml
val komaTime = measureTimeMillis { val komaTime = measureTimeMillis {
KomaMatrixContext(EJMLMatrixFactory(), RealField).run { (KomaMatrixContext(EJMLMatrixFactory(), RealField)) {
val km = matrix.toKoma() //avoid overhead on conversion val km = matrix.toKoma() //avoid overhead on conversion
repeat(n) { repeat(n) {
val res = inverse(km) val res = inverse(km)

View File

@ -4,6 +4,7 @@ import koma.matrix.ejml.EJMLMatrixFactory
import scientifik.kmath.commons.linear.CMMatrixContext import scientifik.kmath.commons.linear.CMMatrixContext
import scientifik.kmath.commons.linear.toCM import scientifik.kmath.commons.linear.toCM
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import kotlin.random.Random import kotlin.random.Random
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
@ -18,7 +19,7 @@ fun main() {
// //warmup // //warmup
// matrix1 dot matrix2 // matrix1 dot matrix2
CMMatrixContext.run { CMMatrixContext {
val cmMatrix1 = matrix1.toCM() val cmMatrix1 = matrix1.toCM()
val cmMatrix2 = matrix2.toCM() val cmMatrix2 = matrix2.toCM()
@ -29,8 +30,7 @@ fun main() {
println("CM implementation time: $cmTime") println("CM implementation time: $cmTime")
} }
(KomaMatrixContext(EJMLMatrixFactory(), RealField)) {
KomaMatrixContext(EJMLMatrixFactory(), RealField).run {
val komaMatrix1 = matrix1.toKoma() val komaMatrix1 = matrix1.toKoma()
val komaMatrix2 = matrix2.toKoma() val komaMatrix2 = matrix2.toKoma()

View File

@ -0,0 +1,8 @@
package scientifik.kmath.operations
fun main() {
val res = BigIntField {
number(1) * 2
}
println("bigint:$res")
}

View File

@ -9,13 +9,11 @@ fun main() {
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble()) Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
} }
val compute = (NDField.complex(8)) {
val compute = NDField.complex(8).run {
val a = produce { (it) -> i * it - it.toDouble() } val a = produce { (it) -> i * it - it.toDouble() }
val b = 3 val b = 3
val c = Complex(1.0, 1.0) val c = Complex(1.0, 1.0)
(a pow b) + c (a pow b) + c
} }
} }

View File

@ -13,9 +13,8 @@ fun main() {
val realField = NDField.real(dim, dim) val realField = NDField.real(dim, dim)
val complexField = NDField.complex(dim, dim) val complexField = NDField.complex(dim, dim)
val realTime = measureTimeMillis { val realTime = measureTimeMillis {
realField.run { realField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) {
res += 1.0 res += 1.0
@ -26,18 +25,15 @@ fun main() {
println("Real addition completed in $realTime millis") println("Real addition completed in $realTime millis")
val complexTime = measureTimeMillis { val complexTime = measureTimeMillis {
complexField.run { complexField {
var res: NDBuffer<Complex> = one var res: NDBuffer<Complex> = one
repeat(n) { repeat(n) { res += 1.0 }
res += 1.0
}
} }
} }
println("Complex addition completed in $complexTime millis") println("Complex addition completed in $complexTime millis")
} }
fun complexExample() { fun complexExample() {
//Create a context for 2-d structure with complex values //Create a context for 2-d structure with complex values
ComplexField { ComplexField {
@ -46,10 +42,7 @@ fun complexExample() {
val x = one * 2.5 val x = one * 2.5
operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im) operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im)
//a structure generator specific to this context //a structure generator specific to this context
val matrix = produce { (k, l) -> val matrix = produce { (k, l) -> k + l * i }
k + l * i
}
//Perform sum //Perform sum
val sum = matrix + x + 1.0 val sum = matrix + x + 1.0

View File

@ -2,14 +2,18 @@ package scientifik.kmath.structures
import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.GlobalScope
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
internal inline fun measureAndPrint(title: String, block: () -> Unit) { internal inline fun measureAndPrint(title: String, block: () -> Unit) {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
val time = measureTimeMillis(block) val time = measureTimeMillis(block)
println("$title completed in $time millis") println("$title completed in $time millis")
} }
fun main() { fun main() {
val dim = 1000 val dim = 1000
val n = 1000 val n = 1000
@ -22,27 +26,21 @@ fun main() {
val genericField = NDField.boxing(RealField, dim, dim) val genericField = NDField.boxing(RealField, dim, dim)
measureAndPrint("Automatic field addition") { measureAndPrint("Automatic field addition") {
autoField.run { autoField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) { res += number(1.0) }
res += number(1.0)
}
} }
} }
measureAndPrint("Element addition") { measureAndPrint("Element addition") {
var res = genericField.one var res = genericField.one
repeat(n) { repeat(n) { res += 1.0 }
res += 1.0
}
} }
measureAndPrint("Specialized addition") { measureAndPrint("Specialized addition") {
specializedField.run { specializedField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) { res += 1.0 }
res += 1.0
}
} }
} }
@ -60,12 +58,11 @@ fun main() {
measureAndPrint("Generic addition") { measureAndPrint("Generic addition") {
//genericField.run(action) //genericField.run(action)
genericField.run { genericField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) {
res += one // con't avoid using `one` due to resolution ambiguity res += one // couldn't avoid using `one` due to resolution ambiguity }
} }
} }
} }
} }

View File

@ -23,13 +23,10 @@ fun DMatrixContext<Double, RealField>.custom() {
val m1 = produce<D2, D5> { i, j -> (i + j).toDouble() } val m1 = produce<D2, D5> { i, j -> (i + j).toDouble() }
val m2 = produce<D5, D2> { i, j -> (i - j).toDouble() } val m2 = produce<D5, D2> { i, j -> (i - j).toDouble() }
val m3 = produce<D2, D2> { i, j -> (i - j).toDouble() } val m3 = produce<D2, D2> { i, j -> (i - j).toDouble() }
(m1 dot m2) + m3 (m1 dot m2) + m3
} }
fun main() { fun main(): Unit = with(DMatrixContext.real) {
DMatrixContext.real.run { simple()
simple() custom()
custom()
}
} }

View File

@ -1,11 +1,8 @@
plugins { id("scientifik.mpp") } plugins { id("scientifik.mpp") }
kotlin.sourceSets { kotlin.sourceSets {
// all { all { languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") }
// languageSettings.apply{
// enableLanguageFeature("NewInference")
// }
// }
commonMain { commonMain {
dependencies { dependencies {
api(project(":kmath-core")) api(project(":kmath-core"))

View File

@ -84,9 +84,9 @@ object MstExtendedField : ExtendedField<MST> {
override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
override fun asin(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ASIN_OPERATION, arg) override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
override fun acos(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ACOS_OPERATION, arg) override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
override fun atan(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ATAN_OPERATION, arg) override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
override fun add(a: MST, b: MST): MST = MstField.add(a, b) override fun add(a: MST, b: MST): MST = MstField.add(a, b)
override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k) override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k)
override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b) override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b)

View File

@ -2,6 +2,9 @@ package scientifik.kmath.ast
import scientifik.kmath.expressions.* import scientifik.kmath.expressions.*
import scientifik.kmath.operations.* import scientifik.kmath.operations.*
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/** /**
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than
@ -24,7 +27,7 @@ class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
error("Numeric nodes are not supported by $this") error("Numeric nodes are not supported by $this")
} }
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst) override operator fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
} }
/** /**
@ -38,51 +41,63 @@ inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
/** /**
* Builds [MstExpression] over [Space]. * Builds [MstExpression] over [Space].
*/ */
inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> = inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
MstExpression(this, MstSpace.block()) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstSpace.block())
}
/** /**
* Builds [MstExpression] over [Ring]. * Builds [MstExpression] over [Ring].
*/ */
inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> = inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
MstExpression(this, MstRing.block()) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstRing.block())
}
/** /**
* Builds [MstExpression] over [Field]. * Builds [MstExpression] over [Field].
*/ */
inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> = inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> {
MstExpression(this, MstField.block()) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstField.block())
}
/** /**
* Builds [MstExpression] over [ExtendedField]. * Builds [MstExpression] over [ExtendedField].
*/ */
inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> = inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> {
MstExpression(this, MstExtendedField.block()) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstExtendedField.block())
}
/** /**
* Builds [MstExpression] over [FunctionalExpressionSpace]. * Builds [MstExpression] over [FunctionalExpressionSpace].
*/ */
inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace( inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
block: MstSpace.() -> MST contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
): MstExpression<T> = algebra.mstInSpace(block) return algebra.mstInSpace(block)
}
/** /**
* Builds [MstExpression] over [FunctionalExpressionRing]. * Builds [MstExpression] over [FunctionalExpressionRing].
*/ */
inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing( inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
block: MstRing.() -> MST contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
): MstExpression<T> = algebra.mstInRing(block) return algebra.mstInRing(block)
}
/** /**
* Builds [MstExpression] over [FunctionalExpressionField]. * Builds [MstExpression] over [FunctionalExpressionField].
*/ */
inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField( inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> {
block: MstField.() -> MST contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
): MstExpression<T> = algebra.mstInField(block) return algebra.mstInField(block)
}
/** /**
* Builds [MstExpression] over [FunctionalExpressionExtendedField]. * Builds [MstExpression] over [FunctionalExpressionExtendedField].
*/ */
inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField( inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> {
block: MstExtendedField.() -> MST contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
): MstExpression<T> = algebra.mstInExtendedField(block) return algebra.mstInExtendedField(block)
}

View File

@ -7,6 +7,9 @@ import scientifik.kmath.ast.MST
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import java.lang.reflect.Method import java.lang.reflect.Method
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.reflect.KClass import kotlin.reflect.KClass
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy { private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
@ -26,8 +29,10 @@ internal val KClass<*>.asm: Type
/** /**
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise. * Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
*/ */
internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> = internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> {
if (predicate(this)) arrayOf(this) else emptyArray() contract { callsInPlace(predicate, InvocationKind.EXACTLY_ONCE) }
return if (predicate(this)) arrayOf(this) else emptyArray()
}
/** /**
* Creates an [InstructionAdapter] from this [MethodVisitor]. * Creates an [InstructionAdapter] from this [MethodVisitor].
@ -37,8 +42,10 @@ private fun MethodVisitor.instructionAdapter(): InstructionAdapter = Instruction
/** /**
* Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it. * Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it.
*/ */
internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter {
instructionAdapter().apply(block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return instructionAdapter().apply(block)
}
/** /**
* Constructs a [Label], then applies it to this visitor. * Constructs a [Label], then applies it to this visitor.
@ -64,8 +71,10 @@ internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
} }
@Suppress("FunctionName") @Suppress("FunctionName")
internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter {
ClassWriter(flags).apply(block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return ClassWriter(flags).apply(block)
}
internal inline fun ClassWriter.visitField( internal inline fun ClassWriter.visitField(
access: Int, access: Int,
@ -74,7 +83,10 @@ internal inline fun ClassWriter.visitField(
signature: String?, signature: String?,
value: Any?, value: Any?,
block: FieldVisitor.() -> Unit block: FieldVisitor.() -> Unit
): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) ): FieldVisitor {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return visitField(access, name, descriptor, signature, value).apply(block)
}
private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, parameterTypes: Array<MstType>): Method? = private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, parameterTypes: Array<MstType>): Method? =
context.javaClass.methods.find { method -> context.javaClass.methods.find { method ->
@ -158,6 +170,7 @@ internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
parameterTypes: Array<MstType>, parameterTypes: Array<MstType>,
parameters: AsmBuilder<T>.() -> Unit parameters: AsmBuilder<T>.() -> Unit
) { ) {
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
val arity = parameterTypes.size val arity = parameterTypes.size
loadAlgebra() loadAlgebra()
if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name) if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name)

View File

@ -1,7 +1,4 @@
plugins { plugins { id("scientifik.jvm") }
id("scientifik.jvm")
}
description = "Commons math binding for kmath" description = "Commons math binding for kmath"
dependencies { dependencies {
@ -11,3 +8,5 @@ dependencies {
api(project(":kmath-functions")) api(project(":kmath-functions"))
api("org.apache.commons:commons-math3:3.6.1") api("org.apache.commons:commons-math3:3.6.1")
} }
kotlin.sourceSets.all { languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") }

View File

@ -5,6 +5,7 @@ import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.expressions.ExpressionAlgebra
import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.invoke
import kotlin.properties.ReadOnlyProperty import kotlin.properties.ReadOnlyProperty
import kotlin.reflect.KProperty import kotlin.reflect.KProperty
@ -15,26 +16,22 @@ class DerivativeStructureField(
val order: Int, val order: Int,
val parameters: Map<String, Double> val parameters: Map<String, Double>
) : ExtendedField<DerivativeStructure> { ) : ExtendedField<DerivativeStructure> {
override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) } override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) }
override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) } override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) }
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) -> private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value) DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
} }
val variable = object : ReadOnlyProperty<Any?, DerivativeStructure> { val variable: ReadOnlyProperty<Any?, DerivativeStructure> = object : ReadOnlyProperty<Any?, DerivativeStructure> {
override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure { override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure =
return variables[property.name] ?: error("A variable with name ${property.name} does not exist") variables[property.name] ?: error("A variable with name ${property.name} does not exist")
}
} }
fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure = fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure =
variables[name] ?: default ?: error("A variable with name $name does not exist") variables[name] ?: default ?: error("A variable with name $name does not exist")
fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
fun Number.const() = DerivativeStructure(order, parameters.size, toDouble())
fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double { fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
return deriv(mapOf(parName to order)) return deriv(mapOf(parName to order))
@ -60,10 +57,18 @@ class DerivativeStructureField(
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin() override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin()
override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos() override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos()
override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan() override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan()
override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh()
override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh()
override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh()
override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh()
override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh()
override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh()
override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
is Double -> arg.pow(pow) is Double -> arg.pow(pow)
is Int -> arg.pow(pow) is Int -> arg.pow(pow)
@ -71,23 +76,20 @@ class DerivativeStructureField(
} }
fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow) fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp() override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble()) override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
override operator fun Number.plus(b: DerivativeStructure) = b + this override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
override operator fun Number.minus(b: DerivativeStructure) = b - this override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
} }
/** /**
* A constructs that creates a derivative structure with required order on-demand * A constructs that creates a derivative structure with required order on-demand
*/ */
class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression<Double> { class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression<Double> {
override operator fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
override fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
0, 0,
arguments arguments
).run(function).value ).run(function).value
@ -96,45 +98,40 @@ class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStru
* Get the derivative expression with given orders * Get the derivative expression with given orders
* TODO make result [DiffExpression] * TODO make result [DiffExpression]
*/ */
fun derivative(orders: Map<String, Int>): Expression<Double> { fun derivative(orders: Map<String, Int>): Expression<Double> = object : Expression<Double> {
return object : Expression<Double> { override operator fun invoke(arguments: Map<String, Double>): Double =
override fun invoke(arguments: Map<String, Double>): Double = (DerivativeStructureField(orders.values.max() ?: 0, arguments)) { function().deriv(orders) }
DerivativeStructureField(orders.values.max() ?: 0, arguments)
.run {
function().deriv(orders)
}
}
} }
//TODO add gradient and maybe other vector operators //TODO add gradient and maybe other vector operators
} }
fun DiffExpression.derivative(vararg orders: Pair<String, Int>) = derivative(mapOf(*orders)) fun DiffExpression.derivative(vararg orders: Pair<String, Int>): Expression<Double> = derivative(mapOf(*orders))
fun DiffExpression.derivative(name: String) = derivative(name to 1) fun DiffExpression.derivative(name: String): Expression<Double> = derivative(name to 1)
/** /**
* A context for [DiffExpression] (not to be confused with [DerivativeStructure]) * A context for [DiffExpression] (not to be confused with [DerivativeStructure])
*/ */
object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> { object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
override fun variable(name: String, default: Double?) = override fun variable(name: String, default: Double?): DiffExpression =
DiffExpression { variable(name, default?.const()) } DiffExpression { variable(name, default?.const()) }
override fun const(value: Double): DiffExpression = override fun const(value: Double): DiffExpression =
DiffExpression { value.const() } DiffExpression { value.const() }
override fun add(a: DiffExpression, b: DiffExpression) = override fun add(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) + b.function(this) } DiffExpression { a.function(this) + b.function(this) }
override val zero = DiffExpression { 0.0.const() } override val zero: DiffExpression = DiffExpression { 0.0.const() }
override fun multiply(a: DiffExpression, k: Number) = override fun multiply(a: DiffExpression, k: Number): DiffExpression =
DiffExpression { a.function(this) * k } DiffExpression { a.function(this) * k }
override val one = DiffExpression { 1.0.const() } override val one: DiffExpression = DiffExpression { 1.0.const() }
override fun multiply(a: DiffExpression, b: DiffExpression) = override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) * b.function(this) } DiffExpression { a.function(this) * b.function(this) }
override fun divide(a: DiffExpression, b: DiffExpression) = override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) / b.function(this) } DiffExpression { a.function(this) / b.function(this) }
} }

View File

@ -1,8 +1,6 @@
package scientifik.kmath.commons.linear package scientifik.kmath.commons.linear
import org.apache.commons.math3.linear.* import org.apache.commons.math3.linear.*
import org.apache.commons.math3.linear.RealMatrix
import org.apache.commons.math3.linear.RealVector
import scientifik.kmath.linear.* import scientifik.kmath.linear.*
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.NDStructure import scientifik.kmath.structures.NDStructure
@ -14,12 +12,12 @@ class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) :
override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> { override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> {
if (origin is DiagonalMatrix) yield(DiagonalFeature) if (origin is DiagonalMatrix) yield(DiagonalFeature)
}.toSet() }.toHashSet()
override fun suggestFeature(vararg features: MatrixFeature) = override fun suggestFeature(vararg features: MatrixFeature): CMMatrix =
CMMatrix(origin, this.features + features) CMMatrix(origin, this.features + features)
override fun get(i: Int, j: Int): Double = origin.getEntry(i, j) override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false) return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
@ -40,24 +38,22 @@ fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
CMMatrix(Array2DRowRealMatrix(array)) CMMatrix(Array2DRowRealMatrix(array))
} }
fun RealMatrix.asMatrix() = CMMatrix(this) fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this)
class CMVector(val origin: RealVector) : Point<Double> { class CMVector(val origin: RealVector) : Point<Double> {
override val size: Int get() = origin.dimension override val size: Int get() = origin.dimension
override fun get(index: Int): Double = origin.getEntry(index) override operator fun get(index: Int): Double = origin.getEntry(index)
override fun iterator(): Iterator<Double> = origin.toArray().iterator() override operator fun iterator(): Iterator<Double> = origin.toArray().iterator()
} }
fun Point<Double>.toCM(): CMVector = if (this is CMVector) { fun Point<Double>.toCM(): CMVector = if (this is CMVector) this else {
this
} else {
val array = DoubleArray(size) { this[it] } val array = DoubleArray(size) { this[it] }
CMVector(ArrayRealVector(array)) CMVector(ArrayRealVector(array))
} }
fun RealVector.toPoint() = CMVector(this) fun RealVector.toPoint(): CMVector = CMVector(this)
object CMMatrixContext : MatrixContext<Double> { object CMMatrixContext : MatrixContext<Double> {
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix { override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix {
@ -65,30 +61,31 @@ object CMMatrixContext : MatrixContext<Double> {
return CMMatrix(Array2DRowRealMatrix(array)) return CMMatrix(Array2DRowRealMatrix(array))
} }
override fun Matrix<Double>.dot(other: Matrix<Double>) = override fun Matrix<Double>.dot(other: Matrix<Double>): CMMatrix =
CMMatrix(this.toCM().origin.multiply(other.toCM().origin)) CMMatrix(this.toCM().origin.multiply(other.toCM().origin))
override fun Matrix<Double>.dot(vector: Point<Double>): CMVector = override fun Matrix<Double>.dot(vector: Point<Double>): CMVector =
CMVector(this.toCM().origin.preMultiply(vector.toCM().origin)) CMVector(this.toCM().origin.preMultiply(vector.toCM().origin))
override fun Matrix<Double>.unaryMinus(): CMMatrix = override operator fun Matrix<Double>.unaryMinus(): CMMatrix =
produce(rowNum, colNum) { i, j -> -get(i, j) } produce(rowNum, colNum) { i, j -> -get(i, j) }
override fun add(a: Matrix<Double>, b: Matrix<Double>) = override fun add(a: Matrix<Double>, b: Matrix<Double>): CMMatrix =
CMMatrix(a.toCM().origin.multiply(b.toCM().origin)) CMMatrix(a.toCM().origin.multiply(b.toCM().origin))
override fun Matrix<Double>.minus(b: Matrix<Double>) = override operator fun Matrix<Double>.minus(b: Matrix<Double>): CMMatrix =
CMMatrix(this.toCM().origin.subtract(b.toCM().origin)) CMMatrix(this.toCM().origin.subtract(b.toCM().origin))
override fun multiply(a: Matrix<Double>, k: Number) = override fun multiply(a: Matrix<Double>, k: Number): CMMatrix =
CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble())) CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble()))
override fun Matrix<Double>.times(value: Double): Matrix<Double> = override operator fun Matrix<Double>.times(value: Double): Matrix<Double> =
produce(rowNum, colNum) { i, j -> get(i, j) * value } produce(rowNum, colNum) { i, j -> get(i, j) * value }
} }
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = operator fun CMMatrix.plus(other: CMMatrix): CMMatrix =
CMMatrix(this.origin.add(other.origin)) CMMatrix(this.origin.add(other.origin))
operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = operator fun CMMatrix.minus(other: CMMatrix): CMMatrix =
CMMatrix(this.origin.subtract(other.origin)) CMMatrix(this.origin.subtract(other.origin))

View File

@ -4,10 +4,9 @@ import scientifik.kmath.prob.RandomGenerator
class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) : class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) :
org.apache.commons.math3.random.RandomGenerator { org.apache.commons.math3.random.RandomGenerator {
private var generator = factory(intArrayOf()) private var generator: RandomGenerator = factory(intArrayOf())
override fun nextBoolean(): Boolean = generator.nextBoolean() override fun nextBoolean(): Boolean = generator.nextBoolean()
override fun nextFloat(): Float = generator.nextDouble().toFloat() override fun nextFloat(): Float = generator.nextDouble().toFloat()
override fun setSeed(seed: Int) { override fun setSeed(seed: Int) {
@ -27,12 +26,8 @@ class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) :
} }
override fun nextInt(): Int = generator.nextInt() override fun nextInt(): Int = generator.nextInt()
override fun nextInt(n: Int): Int = generator.nextInt(n) override fun nextInt(n: Int): Int = generator.nextInt(n)
override fun nextGaussian(): Double = TODO() override fun nextGaussian(): Double = TODO()
override fun nextDouble(): Double = generator.nextDouble() override fun nextDouble(): Double = generator.nextDouble()
override fun nextLong(): Long = generator.nextLong() override fun nextLong(): Long = generator.nextLong()
} }

View File

@ -1,11 +1,15 @@
package scientifik.kmath.commons.expressions package scientifik.kmath.commons.expressions
import scientifik.kmath.expressions.invoke import scientifik.kmath.expressions.invoke
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
inline fun <R> diff(order: Int, vararg parameters: Pair<String, Double>, block: DerivativeStructureField.() -> R) = inline fun <R> diff(order: Int, vararg parameters: Pair<String, Double>, block: DerivativeStructureField.() -> R): R {
DerivativeStructureField(order, mapOf(*parameters)).run(block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
}
class AutoDiffTest { class AutoDiffTest {
@Test @Test

View File

@ -1,7 +1,11 @@
plugins { id("scientifik.mpp") } plugins {
id("scientifik.mpp")
}
kotlin.sourceSets { kotlin.sourceSets {
commonMain { commonMain {
dependencies { api(project(":kmath-memory")) } dependencies {
api(project(":kmath-memory"))
}
} }
} }

View File

@ -4,28 +4,38 @@ import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/** /**
* Creates a functional expression with this [Space]. * Creates a functional expression with this [Space].
*/ */
fun <T> Space<T>.spaceExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> = inline fun <T> Space<T>.spaceExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> {
FunctionalExpressionSpace(this).run(block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionSpace(this).block()
}
/** /**
* Creates a functional expression with this [Ring]. * Creates a functional expression with this [Ring].
*/ */
fun <T> Ring<T>.ringExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> = inline fun <T> Ring<T>.ringExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> {
FunctionalExpressionRing(this).run(block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionRing(this).block()
}
/** /**
* Creates a functional expression with this [Field]. * Creates a functional expression with this [Field].
*/ */
fun <T> Field<T>.fieldExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> = inline fun <T> Field<T>.fieldExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> {
FunctionalExpressionField(this).run(block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionField(this).block()
}
/** /**
* Creates a functional expression with this [ExtendedField]. * Creates a functional expression with this [ExtendedField].
*/ */
fun <T> ExtendedField<T>.fieldExpression( inline fun <T> ExtendedField<T>.extendedFieldExpression(block: FunctionalExpressionExtendedField<T, ExtendedField<T>>.() -> Expression<T>): Expression<T> {
block: FunctionalExpressionExtendedField<T, ExtendedField<T>>.() -> Expression<T> contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
): Expression<T> = FunctionalExpressionExtendedField(this).run(block) return FunctionalExpressionExtendedField(this).block()
}

View File

@ -22,7 +22,7 @@ interface Expression<T> {
*/ */
fun <T> Algebra<T>.expression(block: Algebra<T>.(arguments: Map<String, T>) -> T): Expression<T> = fun <T> Algebra<T>.expression(block: Algebra<T>.(arguments: Map<String, T>) -> T): Expression<T> =
object : Expression<T> { object : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = block(arguments) override operator fun invoke(arguments: Map<String, T>): T = block(arguments)
} }
/** /**

View File

@ -4,7 +4,7 @@ import scientifik.kmath.operations.*
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) : internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
Expression<T> { Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.unaryOperation(name, expr.invoke(arguments)) override operator fun invoke(arguments: Map<String, T>): T = context.unaryOperation(name, expr.invoke(arguments))
} }
internal class FunctionalBinaryOperation<T>( internal class FunctionalBinaryOperation<T>(
@ -13,17 +13,17 @@ internal class FunctionalBinaryOperation<T>(
val first: Expression<T>, val first: Expression<T>,
val second: Expression<T> val second: Expression<T>
) : Expression<T> { ) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = override operator fun invoke(arguments: Map<String, T>): T =
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments)) context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
} }
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> { internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = override operator fun invoke(arguments: Map<String, T>): T =
arguments[name] ?: default ?: error("Parameter not found: $name") arguments[name] ?: default ?: error("Parameter not found: $name")
} }
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> { internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = value override operator fun invoke(arguments: Map<String, T>): T = value
} }
internal class FunctionalConstProductExpression<T>( internal class FunctionalConstProductExpression<T>(
@ -31,7 +31,7 @@ internal class FunctionalConstProductExpression<T>(
private val expr: Expression<T>, private val expr: Expression<T>,
val const: Number val const: Number
) : Expression<T> { ) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const) override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
} }
/** /**
@ -139,15 +139,9 @@ open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> { ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
override fun sin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) override fun sin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
override fun cos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) override fun cos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
override fun asin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
override fun asin(arg: Expression<T>): Expression<T> = override fun acos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
unaryOperation(InverseTrigonometricOperations.ASIN_OPERATION, arg) override fun atan(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
override fun acos(arg: Expression<T>): Expression<T> =
unaryOperation(InverseTrigonometricOperations.ACOS_OPERATION, arg)
override fun atan(arg: Expression<T>): Expression<T> =
unaryOperation(InverseTrigonometricOperations.ATAN_OPERATION, arg)
override fun power(arg: Expression<T>, pow: Number): Expression<T> = override fun power(arg: Expression<T>, pow: Number): Expression<T> =
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))

View File

@ -53,16 +53,12 @@ class BufferMatrix<T : Any>(
override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> = override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
BufferMatrix(rowNum, colNum, buffer, this.features + features) BufferMatrix(rowNum, colNum, buffer, this.features + features)
override fun get(index: IntArray): T = get(index[0], index[1]) override operator fun get(index: IntArray): T = get(index[0], index[1])
override fun get(i: Int, j: Int): T = buffer[i * colNum + j] override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j]
override fun elements(): Sequence<Pair<IntArray, T>> = sequence { override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
for (i in 0 until rowNum) { for (i in 0 until rowNum) for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j))
for (j in 0 until colNum) {
yield(intArrayOf(i, j) to get(i, j))
}
}
} }
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
@ -95,7 +91,7 @@ class BufferMatrix<T : Any>(
* Optimized dot product for real matrices * Optimized dot product for real matrices
*/ */
infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Double> { infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Double> {
if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})") require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
val array = DoubleArray(this.rowNum * other.colNum) val array = DoubleArray(this.rowNum * other.colNum)

View File

@ -4,6 +4,8 @@ import scientifik.kmath.operations.Ring
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.Structure2D import scientifik.kmath.structures.Structure2D
import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asBuffer
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.math.sqrt import kotlin.math.sqrt
/** /**
@ -26,15 +28,17 @@ interface FeaturedMatrix<T : Any> : Matrix<T> {
companion object companion object
} }
fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> = inline fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> {
MatrixContext.real.produce(rows, columns, initializer) contract { callsInPlace(initializer) }
return MatrixContext.real.produce(rows, columns, initializer)
}
/** /**
* Build a square matrix from given elements. * Build a square matrix from given elements.
*/ */
fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> { fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> {
val size: Int = sqrt(elements.size.toDouble()).toInt() val size: Int = sqrt(elements.size.toDouble()).toInt()
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square") require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" }
val buffer = elements.asBuffer() val buffer = elements.asBuffer()
return BufferMatrix(size, size, buffer) return BufferMatrix(size, size, buffer)
} }

View File

@ -3,6 +3,7 @@ package scientifik.kmath.linear
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.BufferAccessor2D import scientifik.kmath.structures.BufferAccessor2D
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.Structure2D import scientifik.kmath.structures.Structure2D
@ -60,15 +61,13 @@ class LUPDecomposition<T : Any>(
* @return determinant of the matrix * @return determinant of the matrix
*/ */
override val determinant: T by lazy { override val determinant: T by lazy {
with(elementContext) { elementContext { (0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] } }
(0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] }
}
} }
} }
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T): T = fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T): T =
if (value > elementContext.zero) value else with(elementContext) { -value } if (value > elementContext.zero) value else elementContext { -value }
/** /**
@ -88,43 +87,34 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
//TODO just waits for KEEP-176 //TODO just waits for KEEP-176
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run { BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
elementContext.run { elementContext {
val lu = create(matrix) val lu = create(matrix)
// Initialize permutation array and parity // Initialize permutation array and parity
for (row in 0 until m) { for (row in 0 until m) pivot[row] = row
pivot[row] = row
}
var even = true var even = true
// Initialize permutation array and parity // Initialize permutation array and parity
for (row in 0 until m) { for (row in 0 until m) pivot[row] = row
pivot[row] = row
}
// Loop over columns // Loop over columns
for (col in 0 until m) { for (col in 0 until m) {
// upper // upper
for (row in 0 until col) { for (row in 0 until col) {
val luRow = lu.row(row) val luRow = lu.row(row)
var sum = luRow[col] var sum = luRow[col]
for (i in 0 until row) { for (i in 0 until row) sum -= luRow[i] * lu[i, col]
sum -= luRow[i] * lu[i, col]
}
luRow[col] = sum luRow[col] = sum
} }
// lower // lower
var max = col // permutation row var max = col // permutation row
var largest = -one var largest = -one
for (row in col until m) { for (row in col until m) {
val luRow = lu.row(row) val luRow = lu.row(row)
var sum = luRow[col] var sum = luRow[col]
for (i in 0 until col) { for (i in 0 until col) sum -= luRow[i] * lu[i, col]
sum -= luRow[i] * lu[i, col]
}
luRow[col] = sum luRow[col] = sum
// maintain best permutation choice // maintain best permutation choice
@ -135,19 +125,19 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
} }
// Singularity check // Singularity check
if (checkSingular(this@lup.abs(lu[max, col]))) { check(!checkSingular(this@lup.abs(lu[max, col]))) { "The matrix is singular" }
error("The matrix is singular")
}
// Pivot if necessary // Pivot if necessary
if (max != col) { if (max != col) {
val luMax = lu.row(max) val luMax = lu.row(max)
val luCol = lu.row(col) val luCol = lu.row(col)
for (i in 0 until m) { for (i in 0 until m) {
val tmp = luMax[i] val tmp = luMax[i]
luMax[i] = luCol[i] luMax[i] = luCol[i]
luCol[i] = tmp luCol[i] = tmp
} }
val temp = pivot[max] val temp = pivot[max]
pivot[max] = pivot[col] pivot[max] = pivot[col]
pivot[col] = temp pivot[col] = temp
@ -156,9 +146,7 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
// Divide the lower elements by the "winning" diagonal elt. // Divide the lower elements by the "winning" diagonal elt.
val luDiag = lu[col, col] val luDiag = lu[col, col]
for (row in col + 1 until m) { for (row in col + 1 until m) lu[row, col] /= luDiag
lu[row, col] /= luDiag
}
} }
return LUPDecomposition(this@lup, lu.collect(), pivot, even) return LUPDecomposition(this@lup, lu.collect(), pivot, even)
@ -175,28 +163,23 @@ fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>): LUPDeco
lup(Double::class, matrix) { it < 1e-11 } lup(Double::class, matrix) { it < 1e-11 }
fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> { fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> {
require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" }
if (matrix.rowNum != pivot.size) {
error("Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}")
}
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run { BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
elementContext.run { elementContext {
// Apply permutations to b // Apply permutations to b
val bp = create { _, _ -> zero } val bp = create { _, _ -> zero }
for (row in pivot.indices) { for (row in pivot.indices) {
val bpRow = bp.row(row) val bpRow = bp.row(row)
val pRow = pivot[row] val pRow = pivot[row]
for (col in 0 until matrix.colNum) { for (col in 0 until matrix.colNum) bpRow[col] = matrix[pRow, col]
bpRow[col] = matrix[pRow, col]
}
} }
// Solve LY = b // Solve LY = b
for (col in pivot.indices) { for (col in pivot.indices) {
val bpCol = bp.row(col) val bpCol = bp.row(col)
for (i in col + 1 until pivot.size) { for (i in col + 1 until pivot.size) {
val bpI = bp.row(i) val bpI = bp.row(i)
val luICol = lu[i, col] val luICol = lu[i, col]
@ -210,17 +193,15 @@ fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Mat
for (col in pivot.size - 1 downTo 0) { for (col in pivot.size - 1 downTo 0) {
val bpCol = bp.row(col) val bpCol = bp.row(col)
val luDiag = lu[col, col] val luDiag = lu[col, col]
for (j in 0 until matrix.colNum) { for (j in 0 until matrix.colNum) bpCol[j] /= luDiag
bpCol[j] /= luDiag
}
for (i in 0 until col) { for (i in 0 until col) {
val bpI = bp.row(i) val bpI = bp.row(i)
val luICol = lu[i, col] val luICol = lu[i, col]
for (j in 0 until matrix.colNum) { for (j in 0 until matrix.colNum) bpI[j] -= bpCol[j] * luICol
bpI[j] -= bpCol[j] * luICol
}
} }
} }
return context.produce(pivot.size, matrix.colNum) { i, j -> bp[i, j] } return context.produce(pivot.size, matrix.colNum) { i, j -> bp[i, j] }
} }
} }

View File

@ -7,7 +7,7 @@ import scientifik.kmath.structures.asBuffer
class MatrixBuilder(val rows: Int, val columns: Int) { class MatrixBuilder(val rows: Int, val columns: Int) {
operator fun <T : Any> invoke(vararg elements: T): FeaturedMatrix<T> { operator fun <T : Any> invoke(vararg elements: T): FeaturedMatrix<T> {
if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns") require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" }
val buffer = elements.asBuffer() val buffer = elements.asBuffer()
return BufferMatrix(rows, columns, buffer) return BufferMatrix(rows, columns, buffer)
} }

View File

@ -2,6 +2,7 @@ package scientifik.kmath.linear
import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.SpaceOperations import scientifik.kmath.operations.SpaceOperations
import scientifik.kmath.operations.invoke
import scientifik.kmath.operations.sum import scientifik.kmath.operations.sum
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.BufferFactory import scientifik.kmath.structures.BufferFactory
@ -37,8 +38,7 @@ interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
fun <T : Any, R : Ring<T>> buffered( fun <T : Any, R : Ring<T>> buffered(
ring: R, ring: R,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
): GenericMatrixContext<T, R> = ): GenericMatrixContext<T, R> = BufferMatrixContext(ring, bufferFactory)
BufferMatrixContext(ring, bufferFactory)
/** /**
* Automatic buffered matrix, unboxed if it is possible * Automatic buffered matrix, unboxed if it is possible
@ -61,45 +61,49 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> { override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
//TODO add typed error //TODO add typed error
if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})") require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
return produce(rowNum, other.colNum) { i, j -> return produce(rowNum, other.colNum) { i, j ->
val row = rows[i] val row = rows[i]
val column = other.columns[j] val column = other.columns[j]
with(elementContext) { elementContext { sum(row.asSequence().zip(column.asSequence(), ::multiply)) }
sum(row.asSequence().zip(column.asSequence(), ::multiply))
}
} }
} }
override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> { override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
//TODO add typed error //TODO add typed error
if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})") require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
return point(rowNum) { i -> return point(rowNum) { i ->
val row = rows[i] val row = rows[i]
with(elementContext) { elementContext { sum(row.asSequence().zip(vector.asSequence(), ::multiply)) }
sum(row.asSequence().zip(vector.asSequence(), ::multiply))
}
} }
} }
override operator fun Matrix<T>.unaryMinus(): Matrix<T> = override operator fun Matrix<T>.unaryMinus(): Matrix<T> =
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } } produce(rowNum, colNum) { i, j -> elementContext { -get(i, j) } }
override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> { override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
if (a.rowNum != b.rowNum || a.colNum != b.colNum) error("Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]") require(a.rowNum == b.rowNum && a.colNum == b.colNum) {
return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] + b[i, j] } } "Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]"
}
return produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] + b[i, j] } }
} }
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> { override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]") require(rowNum == b.rowNum && colNum == b.colNum) {
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } } "Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]"
}
return produce(rowNum, colNum) { i, j -> elementContext { get(i, j) + b[i, j] } }
} }
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> = override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] * k } } produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } }
operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
override fun Matrix<T>.times(value: T): Matrix<T> = override operator fun Matrix<T>.times(value: T): Matrix<T> =
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } } produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } }
} }

View File

@ -2,6 +2,7 @@ package scientifik.kmath.linear
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.BufferFactory import scientifik.kmath.structures.BufferFactory
@ -10,10 +11,9 @@ import scientifik.kmath.structures.BufferFactory
* Could be used on any point-like structure * Could be used on any point-like structure
*/ */
interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> { interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
val size: Int val size: Int
val space: S val space: S
override val zero: Point<T> get() = produce { space.zero }
fun produce(initializer: (Int) -> T): Point<T> fun produce(initializer: (Int) -> T): Point<T>
@ -22,29 +22,24 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
*/ */
//fun produceElement(initializer: (Int) -> T): Vector<T, S> //fun produceElement(initializer: (Int) -> T): Vector<T, S>
override val zero: Point<T> get() = produce { space.zero } override fun add(a: Point<T>, b: Point<T>): Point<T> = produce { space { a[it] + b[it] } }
override fun add(a: Point<T>, b: Point<T>): Point<T> = produce { with(space) { a[it] + b[it] } } override fun multiply(a: Point<T>, k: Number): Point<T> = produce { space { a[it] * k } }
override fun multiply(a: Point<T>, k: Number): Point<T> = produce { with(space) { a[it] * k } }
//TODO add basis //TODO add basis
companion object { companion object {
private val realSpaceCache: MutableMap<Int, BufferVectorSpace<Double, RealField>> = hashMapOf()
private val realSpaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
/** /**
* Non-boxing double vector space * Non-boxing double vector space
*/ */
fun real(size: Int): BufferVectorSpace<Double, RealField> { fun real(size: Int): BufferVectorSpace<Double, RealField> = realSpaceCache.getOrPut(size) {
return realSpaceCache.getOrPut(size) { BufferVectorSpace(
BufferVectorSpace( size,
size, RealField,
RealField, Buffer.Companion::auto
Buffer.Companion::auto )
)
}
} }
/** /**

View File

@ -18,7 +18,7 @@ class VirtualMatrix<T : Any>(
override val shape: IntArray get() = intArrayOf(rowNum, colNum) override val shape: IntArray get() = intArrayOf(rowNum, colNum)
override fun get(i: Int, j: Int): T = generator(i, j) override operator fun get(i: Int, j: Int): T = generator(i, j)
override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix<T> = override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix<T> =
VirtualMatrix(rowNum, colNum, this.features + features, generator) VirtualMatrix(rowNum, colNum, this.features + features, generator)

View File

@ -3,8 +3,12 @@ package scientifik.kmath.misc
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.invoke
import scientifik.kmath.operations.sum import scientifik.kmath.operations.sum
import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asBuffer
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/* /*
* Implementation of backward-mode automatic differentiation. * Implementation of backward-mode automatic differentiation.
@ -27,15 +31,14 @@ class DerivationResult<T : Any>(
/** /**
* compute divergence * compute divergence
*/ */
fun div(): T = context.run { sum(deriv.values) } fun div(): T = context { sum(deriv.values) }
/** /**
* Compute a gradient for variables in given order * Compute a gradient for variables in given order
*/ */
fun grad(vararg variables: Variable<T>): Point<T> = if (variables.isEmpty()) { fun grad(vararg variables: Variable<T>): Point<T> {
error("Variable order is not provided for gradient construction") check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
} else { return variables.map(::deriv).asBuffer()
variables.map(::deriv).asBuffer()
} }
} }
@ -52,19 +55,27 @@ class DerivationResult<T : Any>(
* assertEquals(9.0, x.d) // dy/dx * assertEquals(9.0, x.d) // dy/dx
* ``` * ```
*/ */
fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> = inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> {
AutoDiffContext(this).run { contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
return (AutoDiffContext(this)) {
val result = body() val result = body()
result.d = context.one// computing derivative w.r.t result result.d = context.one // computing derivative w.r.t result
runBackwardPass() runBackwardPass()
DerivationResult(result.value, derivatives, this@deriv) DerivationResult(result.value, derivatives, this@deriv)
} }
}
abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> { abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
abstract val context: F abstract val context: F
/**
* A variable accessing inner state of derivatives.
* Use this function in inner builders to avoid creating additional derivative bindings
*/
abstract var Variable<T>.d: T
/** /**
* Performs update of derivative after the rest of the formula in the back-pass. * Performs update of derivative after the rest of the formula in the back-pass.
* *
@ -78,12 +89,6 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
*/ */
abstract fun <R> derive(value: R, block: F.(R) -> Unit): R abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
/**
* A variable accessing inner state of derivatives.
* Use this function in inner builders to avoid creating additional derivative bindings
*/
abstract var Variable<T>.d: T
abstract fun variable(value: T): Variable<T> abstract fun variable(value: T): Variable<T>
inline fun variable(block: F.() -> T): Variable<T> = variable(context.block()) inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
@ -98,46 +103,35 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this) override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
override operator fun Number.minus(b: Variable<T>): Variable<T> = override operator fun Number.minus(b: Variable<T>): Variable<T> =
derive(variable { this@minus.toDouble() * one - b.value }) { z -> derive(variable { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
b.d -= z.d
}
override operator fun Variable<T>.minus(b: Number): Variable<T> = override operator fun Variable<T>.minus(b: Number): Variable<T> =
derive(variable { this@minus.value - one * b.toDouble() }) { z -> derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
this@minus.d += z.d
}
} }
/** /**
* Automatic Differentiation context class. * Automatic Differentiation context class.
*/ */
private class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() { @PublishedApi
internal class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() {
// this stack contains pairs of blocks and values to apply them to // this stack contains pairs of blocks and values to apply them to
private var stack = arrayOfNulls<Any?>(8) private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
private var sp = 0 private var sp: Int = 0
val derivatives: MutableMap<Variable<T>, T> = hashMapOf()
internal val derivatives = HashMap<Variable<T>, T>() override val zero: Variable<T> get() = Variable(context.zero)
override val one: Variable<T> get() = Variable(context.one)
/** /**
* A variable coupled with its derivative. For internal use only * A variable coupled with its derivative. For internal use only
*/ */
private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x) private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x)
override fun variable(value: T): Variable<T> = override fun variable(value: T): Variable<T> =
VariableWithDeriv(value, context.zero) VariableWithDeriv(value, context.zero)
override var Variable<T>.d: T override var Variable<T>.d: T
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
set(value) { set(value) = if (this is VariableWithDeriv) d = value else derivatives[this] = value
if (this is VariableWithDeriv) {
d = value
} else {
derivatives[this] = value
}
}
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
override fun <R> derive(value: R, block: F.(R) -> Unit): R { override fun <R> derive(value: R, block: F.(R) -> Unit): R {
@ -160,67 +154,49 @@ private class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) :
// Basic math (+, -, *, /) // Basic math (+, -, *, /)
override fun add(a: Variable<T>, b: Variable<T>): Variable<T> = override fun add(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value + b.value }) { z ->
derive(variable { a.value + b.value }) { z -> a.d += z.d
a.d += z.d b.d += z.d
b.d += z.d }
}
override fun multiply(a: Variable<T>, b: Variable<T>): Variable<T> = override fun multiply(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value * b.value }) { z ->
derive(variable { a.value * b.value }) { z -> a.d += z.d * b.value
a.d += z.d * b.value b.d += z.d * a.value
b.d += z.d * a.value }
}
override fun divide(a: Variable<T>, b: Variable<T>): Variable<T> = override fun divide(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value / b.value }) { z ->
derive(variable { a.value / b.value }) { z -> a.d += z.d / b.value
a.d += z.d / b.value b.d -= z.d * a.value / (b.value * b.value)
b.d -= z.d * a.value / (b.value * b.value) }
}
override fun multiply(a: Variable<T>, k: Number): Variable<T> = override fun multiply(a: Variable<T>, k: Number): Variable<T> = derive(variable { k.toDouble() * a.value }) { z ->
derive(variable { k.toDouble() * a.value }) { z -> a.d += z.d * k.toDouble()
a.d += z.d * k.toDouble() }
}
override val zero: Variable<T> get() = Variable(context.zero)
override val one: Variable<T> get() = Variable(context.one)
} }
// Extensions for differentiation of various basic mathematical functions // Extensions for differentiation of various basic mathematical functions
// x ^ 2 // x ^ 2
fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> = fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
derive(variable { x.value * x.value }) { z -> derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
x.d += z.d * 2 * x.value
}
// x ^ 1/2 // x ^ 1/2
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> = fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
derive(variable { sqrt(x.value) }) { z -> derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
x.d += z.d * 0.5 / z.value
}
// x ^ y (const) // x ^ y (const)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> = fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
derive(variable { power(x.value, y) }) { z -> derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
x.d += z.d * y * power(x.value, y - 1)
}
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble()) fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble())
// exp(x) // exp(x)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> = fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
derive(variable { exp(x.value) }) { z -> derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value }
x.d += z.d * z.value
}
// ln(x) // ln(x)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> = derive( fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> =
variable { ln(x.value) } derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value }
) { z ->
x.d += z.d / x.value
}
// x ^ y (any) // x ^ y (any)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> = fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
@ -228,12 +204,8 @@ fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: V
// sin(x) // sin(x)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> = fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
derive(variable { sin(x.value) }) { z -> derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
x.d += z.d * cos(x.value)
}
// cos(x) // cos(x)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> = fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
derive(variable { cos(x.value) }) { z -> derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
x.d -= z.d * sin(x.value)
}

View File

@ -41,6 +41,6 @@ fun ClosedFloatingPointRange<Double>.toSequenceWithPoints(numPoints: Int): Seque
*/ */
@Deprecated("Replace by 'toSequenceWithPoints'") @Deprecated("Replace by 'toSequenceWithPoints'")
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray { fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
if (numPoints < 2) error("Can't create generic grid with less than two points") require(numPoints >= 2) { "Can't create generic grid with less than two points" }
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
} }

View File

@ -2,6 +2,8 @@ package scientifik.kmath.misc
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.operations.invoke import scientifik.kmath.operations.invoke
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
/** /**
@ -11,67 +13,68 @@ import kotlin.jvm.JvmName
* @param R the type of resulting iterable. * @param R the type of resulting iterable.
* @param initial lazy evaluated. * @param initial lazy evaluated.
*/ */
fun <T, R> Iterator<T>.cumulative(initial: R, operation: (R, T) -> R): Iterator<R> = object : Iterator<R> { inline fun <T, R> Iterator<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterator<R> {
var state: R = initial contract { callsInPlace(operation) }
override fun hasNext(): Boolean = this@cumulative.hasNext()
override fun next(): R { return object : Iterator<R> {
state = operation(state, this@cumulative.next()) var state: R = initial
return state
override fun hasNext(): Boolean = this@cumulative.hasNext()
override fun next(): R {
state = operation(state, this@cumulative.next())
return state
}
} }
} }
fun <T, R> Iterable<T>.cumulative(initial: R, operation: (R, T) -> R): Iterable<R> = object : Iterable<R> { inline fun <T, R> Iterable<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterable<R> =
override fun iterator(): Iterator<R> = this@cumulative.iterator().cumulative(initial, operation) Iterable { this@cumulative.iterator().cumulative(initial, operation) }
}
fun <T, R> Sequence<T>.cumulative(initial: R, operation: (R, T) -> R): Sequence<R> = object : Sequence<R> { inline fun <T, R> Sequence<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Sequence<R> = Sequence {
override fun iterator(): Iterator<R> = this@cumulative.iterator().cumulative(initial, operation) this@cumulative.iterator().cumulative(initial, operation)
} }
fun <T, R> List<T>.cumulative(initial: R, operation: (R, T) -> R): List<R> = fun <T, R> List<T>.cumulative(initial: R, operation: (R, T) -> R): List<R> =
this.iterator().cumulative(initial, operation).asSequence().toList() iterator().cumulative(initial, operation).asSequence().toList()
//Cumulative sum //Cumulative sum
/** /**
* Cumulative sum with custom space * Cumulative sum with custom space
*/ */
fun <T> Iterable<T>.cumulativeSum(space: Space<T>): Iterable<T> = space { fun <T> Iterable<T>.cumulativeSum(space: Space<T>): Iterable<T> =
cumulative(zero) { element: T, sum: T -> sum + element } space { cumulative(zero) { element: T, sum: T -> sum + element } }
}
@JvmName("cumulativeSumOfDouble") @JvmName("cumulativeSumOfDouble")
fun Iterable<Double>.cumulativeSum(): Iterable<Double> = this.cumulative(0.0) { element, sum -> sum + element } fun Iterable<Double>.cumulativeSum(): Iterable<Double> = cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt") @JvmName("cumulativeSumOfInt")
fun Iterable<Int>.cumulativeSum(): Iterable<Int> = this.cumulative(0) { element, sum -> sum + element } fun Iterable<Int>.cumulativeSum(): Iterable<Int> = cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong") @JvmName("cumulativeSumOfLong")
fun Iterable<Long>.cumulativeSum(): Iterable<Long> = this.cumulative(0L) { element, sum -> sum + element } fun Iterable<Long>.cumulativeSum(): Iterable<Long> = cumulative(0L) { element, sum -> sum + element }
fun <T> Sequence<T>.cumulativeSum(space: Space<T>): Sequence<T> = with(space) { fun <T> Sequence<T>.cumulativeSum(space: Space<T>): Sequence<T> =
cumulative(zero) { element: T, sum: T -> sum + element } space { cumulative(zero) { element: T, sum: T -> sum + element } }
}
@JvmName("cumulativeSumOfDouble") @JvmName("cumulativeSumOfDouble")
fun Sequence<Double>.cumulativeSum(): Sequence<Double> = this.cumulative(0.0) { element, sum -> sum + element } fun Sequence<Double>.cumulativeSum(): Sequence<Double> = cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt") @JvmName("cumulativeSumOfInt")
fun Sequence<Int>.cumulativeSum(): Sequence<Int> = this.cumulative(0) { element, sum -> sum + element } fun Sequence<Int>.cumulativeSum(): Sequence<Int> = cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong") @JvmName("cumulativeSumOfLong")
fun Sequence<Long>.cumulativeSum(): Sequence<Long> = this.cumulative(0L) { element, sum -> sum + element } fun Sequence<Long>.cumulativeSum(): Sequence<Long> = cumulative(0L) { element, sum -> sum + element }
fun <T> List<T>.cumulativeSum(space: Space<T>): List<T> = with(space) { fun <T> List<T>.cumulativeSum(space: Space<T>): List<T> =
cumulative(zero) { element: T, sum: T -> sum + element } space { cumulative(zero) { element: T, sum: T -> sum + element } }
}
@JvmName("cumulativeSumOfDouble") @JvmName("cumulativeSumOfDouble")
fun List<Double>.cumulativeSum(): List<Double> = this.cumulative(0.0) { element, sum -> sum + element } fun List<Double>.cumulativeSum(): List<Double> = cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt") @JvmName("cumulativeSumOfInt")
fun List<Int>.cumulativeSum(): List<Int> = this.cumulative(0) { element, sum -> sum + element } fun List<Int>.cumulativeSum(): List<Int> = cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong") @JvmName("cumulativeSumOfLong")
fun List<Long>.cumulativeSum(): List<Long> = this.cumulative(0L) { element, sum -> sum + element } fun List<Long>.cumulativeSum(): List<Long> = cumulative(0L) { element, sum -> sum + element }

View File

@ -3,12 +3,13 @@ package scientifik.kmath.operations
import scientifik.kmath.operations.BigInt.Companion.BASE import scientifik.kmath.operations.BigInt.Companion.BASE
import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE
import scientifik.kmath.structures.* import scientifik.kmath.structures.*
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.math.log2 import kotlin.math.log2
import kotlin.math.max import kotlin.math.max
import kotlin.math.min import kotlin.math.min
import kotlin.math.sign import kotlin.math.sign
typealias Magnitude = UIntArray typealias Magnitude = UIntArray
typealias TBase = ULong typealias TBase = ULong
@ -22,8 +23,9 @@ object BigIntField : Field<BigInt> {
override val one: BigInt = BigInt.ONE override val one: BigInt = BigInt.ONE
override fun add(a: BigInt, b: BigInt): BigInt = a.plus(b) override fun add(a: BigInt, b: BigInt): BigInt = a.plus(b)
override fun number(value: Number): BigInt = value.toLong().toBigInt()
override fun multiply(a: BigInt, k: Number): BigInt = a.times(k.toLong()) override fun multiply(a: BigInt, k: Number): BigInt = a.times(number(k))
override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b) override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b)
@ -430,8 +432,8 @@ fun ULong.toBigInt(): BigInt = BigInt(
* Create a [BigInt] with this array of magnitudes with protective copy * Create a [BigInt] with this array of magnitudes with protective copy
*/ */
fun UIntArray.toBigInt(sign: Byte): BigInt { fun UIntArray.toBigInt(sign: Byte): BigInt {
if (sign == 0.toByte() && isNotEmpty()) error("") require(sign != 0.toByte() || !isNotEmpty())
return BigInt(sign, this.copyOf()) return BigInt(sign, copyOf())
} }
val hexChToInt: MutableMap<Char, Int> = hashMapOf( val hexChToInt: MutableMap<Char, Int> = hashMapOf(
@ -484,11 +486,15 @@ fun String.parseBigInteger(): BigInt? {
return res * sign return res * sign
} }
inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> = inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> {
boxing(size, initializer) contract { callsInPlace(initializer) }
return boxing(size, initializer)
}
inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> = inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> {
boxing(size, initializer) contract { callsInPlace(initializer) }
return boxing(size, initializer)
}
fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> = fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> =
BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt) BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt)
@ -496,5 +502,4 @@ fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntFi
fun NDElement.Companion.bigInt( fun NDElement.Companion.bigInt(
vararg shape: Int, vararg shape: Int,
initializer: BigIntField.(IntArray) -> BigInt initializer: BigIntField.(IntArray) -> BigInt
): BufferedNDRingElement<BigInt, BigIntField> = ): BufferedNDRingElement<BigInt, BigIntField> = NDAlgebra.bigInt(*shape).produce(initializer)
NDAlgebra.bigInt(*shape).produce(initializer)

View File

@ -6,17 +6,45 @@ import scientifik.kmath.structures.MutableBuffer
import scientifik.memory.MemoryReader import scientifik.memory.MemoryReader
import scientifik.memory.MemorySpec import scientifik.memory.MemorySpec
import scientifik.memory.MemoryWriter import scientifik.memory.MemoryWriter
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.math.* import kotlin.math.*
/**
* This complex's conjugate.
*/
val Complex.conjugate: Complex
get() = Complex(re, -im)
/**
* This complex's reciprocal.
*/
val Complex.reciprocal: Complex
get() {
val scale = re * re + im * im
return Complex(re / scale, -im / scale)
}
/**
* Absolute value of complex number.
*/
val Complex.r: Double
get() = sqrt(re * re + im * im)
/**
* An angle between vector represented by complex number and X axis.
*/
val Complex.theta: Double
get() = atan(im / re)
private val PI_DIV_2 = Complex(PI / 2, 0) private val PI_DIV_2 = Complex(PI / 2, 0)
/** /**
* A field of [Complex]. * A field of [Complex].
*/ */
object ComplexField : ExtendedField<Complex> { object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
override val zero: Complex = Complex(0.0, 0.0) override val zero: Complex = 0.0.toComplex()
override val one: Complex = 1.0.toComplex()
override val one: Complex = Complex(1.0, 0.0)
/** /**
* The imaginary unit. * The imaginary unit.
@ -30,19 +58,53 @@ object ComplexField : ExtendedField<Complex> {
override fun multiply(a: Complex, b: Complex): Complex = override fun multiply(a: Complex, b: Complex): Complex =
Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re) Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
override fun divide(a: Complex, b: Complex): Complex { override fun divide(a: Complex, b: Complex): Complex = when {
val norm = b.re * b.re + b.im * b.im b.re.isNaN() || b.im.isNaN() -> Complex(Double.NaN, Double.NaN)
return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm)
(if (b.im < 0) -b.im else +b.im) < (if (b.re < 0) -b.re else +b.re) -> {
val wr = b.im / b.re
val wd = b.re + wr * b.im
if (wd.isNaN() || wd == 0.0)
Complex(Double.NaN, Double.NaN)
else
Complex((a.re + a.im * wr) / wd, (a.im - a.re * wr) / wd)
}
b.im == 0.0 -> Complex(Double.NaN, Double.NaN)
else -> {
val wr = b.re / b.im
val wd = b.im + wr * b.re
if (wd.isNaN() || wd == 0.0)
Complex(Double.NaN, Double.NaN)
else
Complex((a.re * wr + a.im) / wd, (a.im * wr - a.re) / wd)
}
} }
override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2 override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2
override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2 override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2
override fun asin(arg: Complex): Complex = -i * ln(sqrt(one - arg pow 2) + i * arg)
override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(one - arg pow 2) + i * arg)
override fun atan(arg: Complex): Complex = i * (ln(one - i * arg) - ln(one + i * arg)) / 2
override fun power(arg: Complex, pow: Number): Complex = override fun tan(arg: Complex): Complex {
arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta)) val e1 = exp(-i * arg)
val e2 = exp(i * arg)
return i * (e1 - e2) / (e1 + e2)
}
override fun asin(arg: Complex): Complex = -i * ln(sqrt(1 - (arg * arg)) + i * arg)
override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(1 - (arg * arg)) + i * arg)
override fun atan(arg: Complex): Complex {
val iArg = i * arg
return i * (ln(1 - iArg) - ln(1 + iArg)) / 2
}
override fun power(arg: Complex, pow: Number): Complex = if (arg.im == 0.0)
arg.re.pow(pow.toDouble()).toComplex()
else
exp(pow * ln(arg))
override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im)) override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im))
@ -93,6 +155,8 @@ object ComplexField : ExtendedField<Complex> {
*/ */
operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this) operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this)
override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg)
override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value) override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value)
} }
@ -105,12 +169,12 @@ object ComplexField : ExtendedField<Complex> {
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> { data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> {
constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble()) constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
override val context: ComplexField get() = ComplexField
override fun unwrap(): Complex = this override fun unwrap(): Complex = this
override fun Complex.wrap(): Complex = this override fun Complex.wrap(): Complex = this
override val context: ComplexField get() = ComplexField
override fun compareTo(other: Complex): Int = r.compareTo(other.r) override fun compareTo(other: Complex): Int = r.compareTo(other.r)
companion object : MemorySpec<Complex> { companion object : MemorySpec<Complex> {
@ -126,33 +190,20 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Compl
} }
} }
/**
* A complex conjugate
*/
val Complex.conjugate: Complex get() = Complex(re, -im)
/**
* Absolute value of complex number
*/
val Complex.r: Double get() = sqrt(re * re + im * im)
/**
* An angle between vector represented by complex number and X axis
*/
val Complex.theta: Double get() = atan(im / re)
/** /**
* Creates a complex number with real part equal to this real. * Creates a complex number with real part equal to this real.
* *
* @receiver the real part. * @receiver the real part.
* @return the new complex number. * @return the new complex number.
*/ */
fun Double.toComplex(): Complex = Complex(this, 0.0) fun Number.toComplex(): Complex = Complex(this, 0.0)
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> { inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
contract { callsInPlace(init) }
return MemoryBuffer.create(Complex, size, init) return MemoryBuffer.create(Complex, size, init)
} }
inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> { inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
contract { callsInPlace(init) }
return MemoryBuffer.create(Complex, size, init) return MemoryBuffer.create(Complex, size, init)
} }

View File

@ -1,5 +1,6 @@
package scientifik.kmath.operations package scientifik.kmath.operations
import scientifik.kmath.operations.RealField.pow
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.pow as kpow import kotlin.math.pow as kpow
@ -7,19 +8,28 @@ import kotlin.math.pow as kpow
* Advanced Number-like semifield that implements basic operations. * Advanced Number-like semifield that implements basic operations.
*/ */
interface ExtendedFieldOperations<T> : interface ExtendedFieldOperations<T> :
InverseTrigonometricOperations<T>, FieldOperations<T>,
TrigonometricOperations<T>,
HyperbolicOperations<T>,
PowerOperations<T>, PowerOperations<T>,
ExponentialOperations<T> { ExponentialOperations<T> {
override fun tan(arg: T): T = sin(arg) / cos(arg) override fun tan(arg: T): T = sin(arg) / cos(arg)
override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
override fun unaryOperation(operation: String, arg: T): T = when (operation) { override fun unaryOperation(operation: String, arg: T): T = when (operation) {
TrigonometricOperations.COS_OPERATION -> cos(arg) TrigonometricOperations.COS_OPERATION -> cos(arg)
TrigonometricOperations.SIN_OPERATION -> sin(arg) TrigonometricOperations.SIN_OPERATION -> sin(arg)
TrigonometricOperations.TAN_OPERATION -> tan(arg) TrigonometricOperations.TAN_OPERATION -> tan(arg)
InverseTrigonometricOperations.ACOS_OPERATION -> acos(arg) TrigonometricOperations.ACOS_OPERATION -> acos(arg)
InverseTrigonometricOperations.ASIN_OPERATION -> asin(arg) TrigonometricOperations.ASIN_OPERATION -> asin(arg)
InverseTrigonometricOperations.ATAN_OPERATION -> atan(arg) TrigonometricOperations.ATAN_OPERATION -> atan(arg)
HyperbolicOperations.COSH_OPERATION -> cosh(arg)
HyperbolicOperations.SINH_OPERATION -> sinh(arg)
HyperbolicOperations.TANH_OPERATION -> tanh(arg)
HyperbolicOperations.ACOSH_OPERATION -> acosh(arg)
HyperbolicOperations.ASINH_OPERATION -> asinh(arg)
HyperbolicOperations.ATANH_OPERATION -> atanh(arg)
PowerOperations.SQRT_OPERATION -> sqrt(arg) PowerOperations.SQRT_OPERATION -> sqrt(arg)
ExponentialOperations.EXP_OPERATION -> exp(arg) ExponentialOperations.EXP_OPERATION -> exp(arg)
ExponentialOperations.LN_OPERATION -> ln(arg) ExponentialOperations.LN_OPERATION -> ln(arg)
@ -32,6 +42,13 @@ interface ExtendedFieldOperations<T> :
* Advanced Number-like field that implements basic operations. * Advanced Number-like field that implements basic operations.
*/ */
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> { interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2
override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2
override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))
override fun asinh(arg: T): T = ln(sqrt(arg * arg + one) + arg)
override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
PowerOperations.POW_OPERATION -> power(left, right) PowerOperations.POW_OPERATION -> power(left, right)
else -> super.rightSideNumberOperation(operation, left, right) else -> super.rightSideNumberOperation(operation, left, right)
@ -46,12 +63,13 @@ interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586 * TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
*/ */
inline class Real(val value: Double) : FieldElement<Double, Real, RealField> { inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
override val context: RealField
get() = RealField
override fun unwrap(): Double = value override fun unwrap(): Double = value
override fun Double.wrap(): Real = Real(value) override fun Double.wrap(): Real = Real(value)
override val context: RealField get() = RealField
companion object companion object
} }
@ -60,12 +78,22 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object RealField : ExtendedField<Double>, Norm<Double, Double> { object RealField : ExtendedField<Double>, Norm<Double, Double> {
override val zero: Double = 0.0 override val zero: Double
get() = 0.0
override val one: Double
get() = 1.0
override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
}
override inline fun add(a: Double, b: Double): Double = a + b override inline fun add(a: Double, b: Double): Double = a + b
override inline fun multiply(a: Double, b: Double): Double = a * b
override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble() override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
override val one: Double = 1.0 override inline fun multiply(a: Double, b: Double): Double = a * b
override inline fun divide(a: Double, b: Double): Double = a / b override inline fun divide(a: Double, b: Double): Double = a / b
override inline fun sin(arg: Double): Double = kotlin.math.sin(arg) override inline fun sin(arg: Double): Double = kotlin.math.sin(arg)
@ -75,27 +103,24 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) override inline fun asin(arg: Double): Double = kotlin.math.asin(arg)
override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) override inline fun atan(arg: Double): Double = kotlin.math.atan(arg)
override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble()) override inline fun sinh(arg: Double): Double = kotlin.math.sinh(arg)
override inline fun cosh(arg: Double): Double = kotlin.math.cosh(arg)
override inline fun tanh(arg: Double): Double = kotlin.math.tanh(arg)
override inline fun asinh(arg: Double): Double = kotlin.math.asinh(arg)
override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg)
override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg)
override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble())
override inline fun exp(arg: Double): Double = kotlin.math.exp(arg) override inline fun exp(arg: Double): Double = kotlin.math.exp(arg)
override inline fun ln(arg: Double): Double = kotlin.math.ln(arg) override inline fun ln(arg: Double): Double = kotlin.math.ln(arg)
override inline fun norm(arg: Double): Double = abs(arg) override inline fun norm(arg: Double): Double = abs(arg)
override inline fun Double.unaryMinus(): Double = -this override inline fun Double.unaryMinus(): Double = -this
override inline fun Double.plus(b: Double): Double = this + b override inline fun Double.plus(b: Double): Double = this + b
override inline fun Double.minus(b: Double): Double = this - b override inline fun Double.minus(b: Double): Double = this - b
override inline fun Double.times(b: Double): Double = this * b override inline fun Double.times(b: Double): Double = this * b
override inline fun Double.div(b: Double): Double = this / b override inline fun Double.div(b: Double): Double = this / b
override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
}
} }
/** /**
@ -103,12 +128,22 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object FloatField : ExtendedField<Float>, Norm<Float, Float> { object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override val zero: Float = 0f override val zero: Float
get() = 0.0f
override val one: Float
get() = 1.0f
override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) {
PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
}
override inline fun add(a: Float, b: Float): Float = a + b override inline fun add(a: Float, b: Float): Float = a + b
override inline fun multiply(a: Float, b: Float): Float = a * b
override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat() override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat()
override val one: Float = 1f override inline fun multiply(a: Float, b: Float): Float = a * b
override inline fun divide(a: Float, b: Float): Float = a / b override inline fun divide(a: Float, b: Float): Float = a / b
override inline fun sin(arg: Float): Float = kotlin.math.sin(arg) override inline fun sin(arg: Float): Float = kotlin.math.sin(arg)
@ -118,108 +153,118 @@ object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override inline fun asin(arg: Float): Float = kotlin.math.asin(arg) override inline fun asin(arg: Float): Float = kotlin.math.asin(arg)
override inline fun atan(arg: Float): Float = kotlin.math.atan(arg) override inline fun atan(arg: Float): Float = kotlin.math.atan(arg)
override inline fun power(arg: Float, pow: Number): Float = arg.pow(pow.toFloat()) override inline fun sinh(arg: Float): Float = kotlin.math.sinh(arg)
override inline fun cosh(arg: Float): Float = kotlin.math.cosh(arg)
override inline fun tanh(arg: Float): Float = kotlin.math.tanh(arg)
override inline fun asinh(arg: Float): Float = kotlin.math.asinh(arg)
override inline fun acosh(arg: Float): Float = kotlin.math.acosh(arg)
override inline fun atanh(arg: Float): Float = kotlin.math.atanh(arg)
override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat())
override inline fun exp(arg: Float): Float = kotlin.math.exp(arg) override inline fun exp(arg: Float): Float = kotlin.math.exp(arg)
override inline fun ln(arg: Float): Float = kotlin.math.ln(arg) override inline fun ln(arg: Float): Float = kotlin.math.ln(arg)
override inline fun norm(arg: Float): Float = abs(arg) override inline fun norm(arg: Float): Float = abs(arg)
override inline fun Float.unaryMinus(): Float = -this override inline fun Float.unaryMinus(): Float = -this
override inline fun Float.plus(b: Float): Float = this + b override inline fun Float.plus(b: Float): Float = this + b
override inline fun Float.minus(b: Float): Float = this - b override inline fun Float.minus(b: Float): Float = this - b
override inline fun Float.times(b: Float): Float = this * b override inline fun Float.times(b: Float): Float = this * b
override inline fun Float.div(b: Float): Float = this / b override inline fun Float.div(b: Float): Float = this / b
} }
/** /**
* A field for [Int] without boxing. Does not produce corresponding field element * A field for [Int] without boxing. Does not produce corresponding ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object IntRing : Ring<Int>, Norm<Int, Int> { object IntRing : Ring<Int>, Norm<Int, Int> {
override val zero: Int = 0 override val zero: Int
get() = 0
override val one: Int
get() = 1
override inline fun add(a: Int, b: Int): Int = a + b override inline fun add(a: Int, b: Int): Int = a + b
override inline fun multiply(a: Int, b: Int): Int = a * b
override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a
override val one: Int = 1
override inline fun multiply(a: Int, b: Int): Int = a * b
override inline fun norm(arg: Int): Int = abs(arg) override inline fun norm(arg: Int): Int = abs(arg)
override inline fun Int.unaryMinus(): Int = -this override inline fun Int.unaryMinus(): Int = -this
override inline fun Int.plus(b: Int): Int = this + b override inline fun Int.plus(b: Int): Int = this + b
override inline fun Int.minus(b: Int): Int = this - b override inline fun Int.minus(b: Int): Int = this - b
override inline fun Int.times(b: Int): Int = this * b override inline fun Int.times(b: Int): Int = this * b
} }
/** /**
* A field for [Short] without boxing. Does not produce appropriate field element * A field for [Short] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object ShortRing : Ring<Short>, Norm<Short, Short> { object ShortRing : Ring<Short>, Norm<Short, Short> {
override val zero: Short = 0 override val zero: Short
get() = 0
override val one: Short
get() = 1
override inline fun add(a: Short, b: Short): Short = (a + b).toShort() override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort()
override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort() override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort()
override val one: Short = 1
override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort()
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
override inline fun Short.unaryMinus(): Short = (-this).toShort() override inline fun Short.unaryMinus(): Short = (-this).toShort()
override inline fun Short.plus(b: Short): Short = (this + b).toShort() override inline fun Short.plus(b: Short): Short = (this + b).toShort()
override inline fun Short.minus(b: Short): Short = (this - b).toShort() override inline fun Short.minus(b: Short): Short = (this - b).toShort()
override inline fun Short.times(b: Short): Short = (this * b).toShort() override inline fun Short.times(b: Short): Short = (this * b).toShort()
} }
/** /**
* A field for [Byte] values * A field for [Byte] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object ByteRing : Ring<Byte>, Norm<Byte, Byte> { object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
override val zero: Byte = 0 override val zero: Byte
get() = 0
override val one: Byte
get() = 1
override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte() override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte()
override val one: Byte = 1
override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
override inline fun Byte.unaryMinus(): Byte = (-this).toByte() override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte() override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte()
override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte() override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte()
override inline fun Byte.times(b: Byte): Byte = (this * b).toByte() override inline fun Byte.times(b: Byte): Byte = (this * b).toByte()
} }
/** /**
* A field for [Long] values * A field for [Double] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object LongRing : Ring<Long>, Norm<Long, Long> { object LongRing : Ring<Long>, Norm<Long, Long> {
override val zero: Long = 0 override val zero: Long
override inline fun add(a: Long, b: Long): Long = (a + b) get() = 0
override inline fun multiply(a: Long, b: Long): Long = (a * b)
override val one: Long
get() = 1
override inline fun add(a: Long, b: Long): Long = a + b
override inline fun multiply(a: Long, k: Number): Long = a * k.toLong() override inline fun multiply(a: Long, k: Number): Long = a * k.toLong()
override val one: Long = 1
override inline fun multiply(a: Long, b: Long): Long = a * b
override fun norm(arg: Long): Long = abs(arg) override fun norm(arg: Long): Long = abs(arg)
override inline fun Long.unaryMinus(): Long = (-this) override inline fun Long.unaryMinus(): Long = (-this)
override inline fun Long.plus(b: Long): Long = (this + b) override inline fun Long.plus(b: Long): Long = (this + b)
override inline fun Long.minus(b: Long): Long = (this - b) override inline fun Long.minus(b: Long): Long = (this - b)
override inline fun Long.times(b: Long): Long = (this * b) override inline fun Long.times(b: Long): Long = (this * b)
} }

View File

@ -1,12 +1,11 @@
package scientifik.kmath.operations package scientifik.kmath.operations
/** /**
* A container for trigonometric operations for specific type. They are limited to semifields. * A container for trigonometric operations for specific type.
* *
* The operations are not exposed to class directly to avoid method bloat but instead are declared in the field. * @param T the type of element of this structure.
* It also allows to override behavior for optional operations.
*/ */
interface TrigonometricOperations<T> : FieldOperations<T> { interface TrigonometricOperations<T> : Algebra<T> {
/** /**
* Computes the sine of [arg]. * Computes the sine of [arg].
*/ */
@ -22,31 +21,6 @@ interface TrigonometricOperations<T> : FieldOperations<T> {
*/ */
fun tan(arg: T): T fun tan(arg: T): T
companion object {
/**
* The identifier of sine.
*/
const val SIN_OPERATION: String = "sin"
/**
* The identifier of cosine.
*/
const val COS_OPERATION: String = "cos"
/**
* The identifier of tangent.
*/
const val TAN_OPERATION: String = "tan"
}
}
/**
* A container for inverse trigonometric operations for specific type. They are limited to semifields.
*
* The operations are not exposed to class directly to avoid method bloat but instead are declared in the field.
* It also allows to override behavior for optional operations.
*/
interface InverseTrigonometricOperations<T> : TrigonometricOperations<T> {
/** /**
* Computes the inverse sine of [arg]. * Computes the inverse sine of [arg].
*/ */
@ -63,6 +37,21 @@ interface InverseTrigonometricOperations<T> : TrigonometricOperations<T> {
fun atan(arg: T): T fun atan(arg: T): T
companion object { companion object {
/**
* The identifier of sine.
*/
const val SIN_OPERATION: String = "sin"
/**
* The identifier of cosine.
*/
const val COS_OPERATION: String = "cos"
/**
* The identifier of tangent.
*/
const val TAN_OPERATION: String = "tan"
/** /**
* The identifier of inverse sine. * The identifier of inverse sine.
*/ */
@ -98,20 +87,121 @@ fun <T : MathElement<out TrigonometricOperations<T>>> tan(arg: T): T = arg.conte
/** /**
* Computes the inverse sine of [arg]. * Computes the inverse sine of [arg].
*/ */
fun <T : MathElement<out InverseTrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg) fun <T : MathElement<out TrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg)
/** /**
* Computes the inverse cosine of [arg]. * Computes the inverse cosine of [arg].
*/ */
fun <T : MathElement<out InverseTrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg) fun <T : MathElement<out TrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg)
/** /**
* Computes the inverse tangent of [arg]. * Computes the inverse tangent of [arg].
*/ */
fun <T : MathElement<out InverseTrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg) fun <T : MathElement<out TrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg)
/**
* A container for hyperbolic trigonometric operations for specific type.
*
* @param T the type of element of this structure.
*/
interface HyperbolicOperations<T> : Algebra<T> {
/**
* Computes the hyperbolic sine of [arg].
*/
fun sinh(arg: T): T
/**
* Computes the hyperbolic cosine of [arg].
*/
fun cosh(arg: T): T
/**
* Computes the hyperbolic tangent of [arg].
*/
fun tanh(arg: T): T
/**
* Computes the inverse hyperbolic sine of [arg].
*/
fun asinh(arg: T): T
/**
* Computes the inverse hyperbolic cosine of [arg].
*/
fun acosh(arg: T): T
/**
* Computes the inverse hyperbolic tangent of [arg].
*/
fun atanh(arg: T): T
companion object {
/**
* The identifier of hyperbolic sine.
*/
const val SINH_OPERATION: String = "sinh"
/**
* The identifier of hyperbolic cosine.
*/
const val COSH_OPERATION: String = "cosh"
/**
* The identifier of hyperbolic tangent.
*/
const val TANH_OPERATION: String = "tanh"
/**
* The identifier of inverse hyperbolic sine.
*/
const val ASINH_OPERATION: String = "asinh"
/**
* The identifier of inverse hyperbolic cosine.
*/
const val ACOSH_OPERATION: String = "acosh"
/**
* The identifier of inverse hyperbolic tangent.
*/
const val ATANH_OPERATION: String = "atanh"
}
}
/**
* Computes the hyperbolic sine of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> sinh(arg: T): T = arg.context.sinh(arg)
/**
* Computes the hyperbolic cosine of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> cosh(arg: T): T = arg.context.cosh(arg)
/**
* Computes the hyperbolic tangent of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> tanh(arg: T): T = arg.context.tanh(arg)
/**
* Computes the inverse hyperbolic sine of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> asinh(arg: T): T = arg.context.asinh(arg)
/**
* Computes the inverse hyperbolic cosine of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> acosh(arg: T): T = arg.context.acosh(arg)
/**
* Computes the inverse hyperbolic tangent of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> atanh(arg: T): T = arg.context.atanh(arg)
/** /**
* A context extension to include power operations based on exponentiation. * A context extension to include power operations based on exponentiation.
*
* @param T the type of element of this structure.
*/ */
interface PowerOperations<T> : Algebra<T> { interface PowerOperations<T> : Algebra<T> {
/** /**
@ -163,6 +253,8 @@ fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
/** /**
* A container for operations related to `exp` and `ln` functions. * A container for operations related to `exp` and `ln` functions.
*
* @param T the type of element of this structure.
*/ */
interface ExponentialOperations<T> : Algebra<T> { interface ExponentialOperations<T> : Algebra<T> {
/** /**
@ -200,6 +292,9 @@ fun <T : MathElement<out ExponentialOperations<T>>> ln(arg: T): T = arg.context.
/** /**
* A container for norm functional on element. * A container for norm functional on element.
*
* @param T the type of element having norm defined.
* @param R the type of norm.
*/ */
interface Norm<in T : Any, out R> { interface Norm<in T : Any, out R> {
/** /**

View File

@ -8,19 +8,17 @@ class BoxingNDField<T, F : Field<T>>(
override val elementContext: F, override val elementContext: F,
val bufferFactory: BufferFactory<T> val bufferFactory: BufferFactory<T>
) : BufferedNDField<T, F> { ) : BufferedNDField<T, F> {
override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
override val one: BufferedNDFieldElement<T, F> by lazy { produce { one } }
override val strides: Strides = DefaultStrides(shape) override val strides: Strides = DefaultStrides(shape)
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
bufferFactory(size, initializer) bufferFactory(size, initializer)
override fun check(vararg elements: NDBuffer<T>) { override fun check(vararg elements: NDBuffer<T>) {
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
} }
override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
override val one: BufferedNDFieldElement<T, F> by lazy { produce { one } }
override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> = override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
BufferedNDFieldElement( BufferedNDFieldElement(
this, this,
@ -28,6 +26,7 @@ class BoxingNDField<T, F : Field<T>>(
override fun map(arg: NDBuffer<T>, transform: F.(T) -> T): BufferedNDFieldElement<T, F> { override fun map(arg: NDBuffer<T>, transform: F.(T) -> T): BufferedNDFieldElement<T, F> {
check(arg) check(arg)
return BufferedNDFieldElement( return BufferedNDFieldElement(
this, this,
buildBuffer(arg.strides.linearSize) { offset -> elementContext.transform(arg.buffer[offset]) }) buildBuffer(arg.strides.linearSize) { offset -> elementContext.transform(arg.buffer[offset]) })

View File

@ -8,19 +8,16 @@ class BoxingNDRing<T, R : Ring<T>>(
override val elementContext: R, override val elementContext: R,
val bufferFactory: BufferFactory<T> val bufferFactory: BufferFactory<T>
) : BufferedNDRing<T, R> { ) : BufferedNDRing<T, R> {
override val strides: Strides = DefaultStrides(shape) override val strides: Strides = DefaultStrides(shape)
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
bufferFactory(size, initializer)
override fun check(vararg elements: NDBuffer<T>) {
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
}
override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } } override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } }
override val one: BufferedNDRingElement<T, R> by lazy { produce { one } } override val one: BufferedNDRingElement<T, R> by lazy { produce { one } }
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = bufferFactory(size, initializer)
override fun check(vararg elements: NDBuffer<T>) {
require(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
}
override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement<T, R> = override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement<T, R> =
BufferedNDRingElement( BufferedNDRingElement(
this, this,

View File

@ -6,7 +6,6 @@ import kotlin.reflect.KClass
* A context that allows to operate on a [MutableBuffer] as on 2d array * A context that allows to operate on a [MutableBuffer] as on 2d array
*/ */
class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) { class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) {
operator fun Buffer<T>.get(i: Int, j: Int): T = get(i + colNum * j) operator fun Buffer<T>.get(i: Int, j: Int): T = get(i + colNum * j)
operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) { operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
@ -26,15 +25,14 @@ class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum
inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> { inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> {
override val size: Int get() = colNum override val size: Int get() = colNum
override fun get(index: Int): T = buffer[rowIndex, index] override operator fun get(index: Int): T = buffer[rowIndex, index]
override fun set(index: Int, value: T) { override operator fun set(index: Int, value: T) {
buffer[rowIndex, index] = value buffer[rowIndex, index] = value
} }
override fun copy(): MutableBuffer<T> = MutableBuffer.auto(type, colNum) { get(it) } override fun copy(): MutableBuffer<T> = MutableBuffer.auto(type, colNum) { get(it) }
override operator fun iterator(): Iterator<T> = (0 until colNum).map(::get).iterator()
override fun iterator(): Iterator<T> = (0 until colNum).map(::get).iterator()
} }

View File

@ -5,9 +5,8 @@ import scientifik.kmath.operations.*
interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> { interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
val strides: Strides val strides: Strides
override fun check(vararg elements: NDBuffer<T>) { override fun check(vararg elements: NDBuffer<T>): Unit =
if (!elements.all { it.strides == this.strides }) error("Strides mismatch") require(elements.all { it.strides == strides }) { ("Strides mismatch") }
}
/** /**
* Convert any [NDStructure] to buffered structure using strides from this context. * Convert any [NDStructure] to buffered structure using strides from this context.

View File

@ -30,7 +30,6 @@ class BufferedNDRingElement<T, R : Ring<T>>(
override val context: BufferedNDRing<T, R>, override val context: BufferedNDRing<T, R>,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : BufferedNDElement<T, R>(), RingElement<NDBuffer<T>, BufferedNDRingElement<T, R>, BufferedNDRing<T, R>> { ) : BufferedNDElement<T, R>(), RingElement<NDBuffer<T>, BufferedNDRingElement<T, R>, BufferedNDRing<T, R>> {
override fun unwrap(): NDBuffer<T> = this override fun unwrap(): NDBuffer<T> = this
override fun NDBuffer<T>.wrap(): BufferedNDRingElement<T, R> { override fun NDBuffer<T>.wrap(): BufferedNDRingElement<T, R> {
@ -43,7 +42,6 @@ class BufferedNDFieldElement<T, F : Field<T>>(
override val context: BufferedNDField<T, F>, override val context: BufferedNDField<T, F>,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : BufferedNDElement<T, F>(), FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> { ) : BufferedNDElement<T, F>(), FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> {
override fun unwrap(): NDBuffer<T> = this override fun unwrap(): NDBuffer<T> = this
override fun NDBuffer<T>.wrap(): BufferedNDFieldElement<T, F> { override fun NDBuffer<T>.wrap(): BufferedNDFieldElement<T, F> {

View File

@ -2,6 +2,8 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.Complex import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.complex import scientifik.kmath.operations.complex
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
@ -117,15 +119,14 @@ interface MutableBuffer<T> : Buffer<T> {
MutableListBuffer(MutableList(size, initializer)) MutableListBuffer(MutableList(size, initializer))
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> { inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> =
return when (type) { when (type) {
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T> Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T> Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T> Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T> Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
else -> boxing(size, initializer) else -> boxing(size, initializer)
} }
}
/** /**
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible * Create most appropriate mutable buffer for given type avoiding boxing wherever possible
@ -150,9 +151,8 @@ inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
override val size: Int override val size: Int
get() = list.size get() = list.size
override fun get(index: Int): T = list[index] override operator fun get(index: Int): T = list[index]
override operator fun iterator(): Iterator<T> = list.iterator()
override fun iterator(): Iterator<T> = list.iterator()
} }
/** /**
@ -167,7 +167,10 @@ fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an array element given its index. * It should return the value for an array element given its index.
*/ */
inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> = List(size, init).asBuffer() inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> {
contract { callsInPlace(init) }
return List(size, init).asBuffer()
}
/** /**
* [MutableBuffer] implementation over [MutableList]. * [MutableBuffer] implementation over [MutableList].
@ -176,17 +179,16 @@ inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> = List(siz
* @property list The underlying list. * @property list The underlying list.
*/ */
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> { inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
override val size: Int override val size: Int
get() = list.size get() = list.size
override fun get(index: Int): T = list[index] override operator fun get(index: Int): T = list[index]
override fun set(index: Int, value: T) { override operator fun set(index: Int, value: T) {
list[index] = value list[index] = value
} }
override fun iterator(): Iterator<T> = list.iterator() override operator fun iterator(): Iterator<T> = list.iterator()
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list)) override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
} }
@ -201,14 +203,13 @@ class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
override val size: Int override val size: Int
get() = array.size get() = array.size
override fun get(index: Int): T = array[index] override operator fun get(index: Int): T = array[index]
override fun set(index: Int, value: T) { override operator fun set(index: Int, value: T) {
array[index] = value array[index] = value
} }
override fun iterator(): Iterator<T> = array.iterator() override operator fun iterator(): Iterator<T> = array.iterator()
override fun copy(): MutableBuffer<T> = ArrayBuffer(array.copyOf()) override fun copy(): MutableBuffer<T> = ArrayBuffer(array.copyOf())
} }
@ -226,9 +227,9 @@ fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this)
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> { inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
override val size: Int get() = buffer.size override val size: Int get() = buffer.size
override fun get(index: Int): T = buffer[index] override operator fun get(index: Int): T = buffer[index]
override fun iterator(): Iterator<T> = buffer.iterator() override operator fun iterator(): Iterator<T> = buffer.iterator()
} }
/** /**
@ -238,12 +239,12 @@ inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
* @param T the type of elements provided by the buffer. * @param T the type of elements provided by the buffer.
*/ */
class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> { class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> {
override fun get(index: Int): T { override operator fun get(index: Int): T {
if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index") if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index")
return generator(index) return generator(index)
} }
override fun iterator(): Iterator<T> = (0 until size).asSequence().map(generator).iterator() override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(generator).iterator()
override fun contentEquals(other: Buffer<*>): Boolean { override fun contentEquals(other: Buffer<*>): Boolean {
return if (other is VirtualBuffer) { return if (other is VirtualBuffer) {

View File

@ -4,6 +4,9 @@ import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.ComplexField
import scientifik.kmath.operations.FieldElement import scientifik.kmath.operations.FieldElement
import scientifik.kmath.operations.complex import scientifik.kmath.operations.complex
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
typealias ComplexNDElement = BufferedNDFieldElement<Complex, ComplexField> typealias ComplexNDElement = BufferedNDFieldElement<Complex, ComplexField>
@ -15,7 +18,6 @@ class ComplexNDField(override val shape: IntArray) :
ExtendedNDField<Complex, ComplexField, NDBuffer<Complex>> { ExtendedNDField<Complex, ComplexField, NDBuffer<Complex>> {
override val strides: Strides = DefaultStrides(shape) override val strides: Strides = DefaultStrides(shape)
override val elementContext: ComplexField get() = ComplexField override val elementContext: ComplexField get() = ComplexField
override val zero: ComplexNDElement by lazy { produce { zero } } override val zero: ComplexNDElement by lazy { produce { zero } }
override val one: ComplexNDElement by lazy { produce { one } } override val one: ComplexNDElement by lazy { produce { one } }
@ -45,6 +47,7 @@ class ComplexNDField(override val shape: IntArray) :
transform: ComplexField.(index: IntArray, Complex) -> Complex transform: ComplexField.(index: IntArray, Complex) -> Complex
): ComplexNDElement { ): ComplexNDElement {
check(arg) check(arg)
return BufferedNDFieldElement( return BufferedNDFieldElement(
this, this,
buildBuffer(arg.strides.linearSize) { offset -> buildBuffer(arg.strides.linearSize) { offset ->
@ -61,6 +64,7 @@ class ComplexNDField(override val shape: IntArray) :
transform: ComplexField.(Complex, Complex) -> Complex transform: ComplexField.(Complex, Complex) -> Complex
): ComplexNDElement { ): ComplexNDElement {
check(a, b) check(a, b)
return BufferedNDFieldElement( return BufferedNDFieldElement(
this, this,
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }) buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })
@ -69,23 +73,25 @@ class ComplexNDField(override val shape: IntArray) :
override fun NDBuffer<Complex>.toElement(): FieldElement<NDBuffer<Complex>, *, out BufferedNDField<Complex, ComplexField>> = override fun NDBuffer<Complex>.toElement(): FieldElement<NDBuffer<Complex>, *, out BufferedNDField<Complex, ComplexField>> =
BufferedNDFieldElement(this@ComplexNDField, buffer) BufferedNDFieldElement(this@ComplexNDField, buffer)
override fun power(arg: NDBuffer<Complex>, pow: Number): ComplexNDElement = map(arg) { power(it, pow) } override fun power(arg: NDBuffer<Complex>, pow: Number): ComplexNDElement =
map(arg) { power(it, pow) }
override fun exp(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { exp(it) } override fun exp(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { exp(it) }
override fun ln(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { ln(it) } override fun ln(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { ln(it) }
override fun sin(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { sin(it) } override fun sin(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { sin(it) }
override fun cos(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { cos(it) } override fun cos(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { cos(it) }
override fun tan(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { tan(it) } override fun tan(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { tan(it) }
override fun asin(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { asin(it) } override fun asin(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { asin(it) }
override fun acos(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { acos(it) } override fun acos(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { acos(it) }
override fun atan(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { atan(it) } override fun atan(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { atan(it) }
override fun sinh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { sinh(it) }
override fun cosh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { cosh(it) }
override fun tanh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { tanh(it) }
override fun asinh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { asinh(it) }
override fun acosh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { acosh(it) }
override fun atanh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { atanh(it) }
} }
@ -107,6 +113,7 @@ inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(inde
* Map one [ComplexNDElement] using function without indices. * Map one [ComplexNDElement] using function without indices.
*/ */
inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement { inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement {
contract { callsInPlace(transform) }
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) } val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) }
return BufferedNDFieldElement(context, buffer) return BufferedNDFieldElement(context, buffer)
} }
@ -146,5 +153,6 @@ fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(In
* Produce a context for n-dimensional operations inside this real field * Produce a context for n-dimensional operations inside this real field
*/ */
inline fun <R> ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R { inline fun <R> ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R {
return NDField.complex(*shape).run(action) contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return NDField.complex(*shape).action()
} }

View File

@ -1,5 +1,7 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.experimental.and import kotlin.experimental.and
/** /**
@ -57,17 +59,18 @@ class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : Flagged
override val size: Int get() = values.size override val size: Int get() = values.size
override fun get(index: Int): Double? = if (isValid(index)) values[index] else null override operator fun get(index: Int): Double? = if (isValid(index)) values[index] else null
override fun iterator(): Iterator<Double?> = values.indices.asSequence().map { override operator fun iterator(): Iterator<Double?> = values.indices.asSequence().map {
if (isValid(it)) values[it] else null if (isValid(it)) values[it] else null
}.iterator() }.iterator()
} }
inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) { inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
for (i in indices) { contract { callsInPlace(block) }
if (isValid(i)) {
block(values[i]) indices
} .asSequence()
} .filter(::isValid)
.forEach { block(values[it]) }
} }

View File

@ -1,5 +1,8 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
/** /**
* Specialized [MutableBuffer] implementation over [FloatArray]. * Specialized [MutableBuffer] implementation over [FloatArray].
* *
@ -8,13 +11,13 @@ package scientifik.kmath.structures
inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> { inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> {
override val size: Int get() = array.size override val size: Int get() = array.size
override fun get(index: Int): Float = array[index] override operator fun get(index: Int): Float = array[index]
override fun set(index: Int, value: Float) { override operator fun set(index: Int, value: Float) {
array[index] = value array[index] = value
} }
override fun iterator(): FloatIterator = array.iterator() override operator fun iterator(): FloatIterator = array.iterator()
override fun copy(): MutableBuffer<Float> = override fun copy(): MutableBuffer<Float> =
FloatBuffer(array.copyOf()) FloatBuffer(array.copyOf())
@ -27,7 +30,10 @@ inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> {
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index. * It should return the value for an buffer element given its index.
*/ */
inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer = FloatBuffer(FloatArray(size) { init(it) }) inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer {
contract { callsInPlace(init) }
return FloatBuffer(FloatArray(size) { init(it) })
}
/** /**
* Returns a new [FloatBuffer] of given elements. * Returns a new [FloatBuffer] of given elements.

View File

@ -1,5 +1,9 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/** /**
* Specialized [MutableBuffer] implementation over [IntArray]. * Specialized [MutableBuffer] implementation over [IntArray].
* *
@ -8,17 +12,16 @@ package scientifik.kmath.structures
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> { inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
override val size: Int get() = array.size override val size: Int get() = array.size
override fun get(index: Int): Int = array[index] override operator fun get(index: Int): Int = array[index]
override fun set(index: Int, value: Int) { override operator fun set(index: Int, value: Int) {
array[index] = value array[index] = value
} }
override fun iterator(): IntIterator = array.iterator() override operator fun iterator(): IntIterator = array.iterator()
override fun copy(): MutableBuffer<Int> = override fun copy(): MutableBuffer<Int> =
IntBuffer(array.copyOf()) IntBuffer(array.copyOf())
} }
/** /**
@ -28,7 +31,10 @@ inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index. * It should return the value for an buffer element given its index.
*/ */
inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer = IntBuffer(IntArray(size) { init(it) }) inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer {
contract { callsInPlace(init) }
return IntBuffer(IntArray(size) { init(it) })
}
/** /**
* Returns a new [IntBuffer] of given elements. * Returns a new [IntBuffer] of given elements.

View File

@ -1,5 +1,8 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
/** /**
* Specialized [MutableBuffer] implementation over [LongArray]. * Specialized [MutableBuffer] implementation over [LongArray].
* *
@ -8,13 +11,13 @@ package scientifik.kmath.structures
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> { inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
override val size: Int get() = array.size override val size: Int get() = array.size
override fun get(index: Int): Long = array[index] override operator fun get(index: Int): Long = array[index]
override fun set(index: Int, value: Long) { override operator fun set(index: Int, value: Long) {
array[index] = value array[index] = value
} }
override fun iterator(): LongIterator = array.iterator() override operator fun iterator(): LongIterator = array.iterator()
override fun copy(): MutableBuffer<Long> = override fun copy(): MutableBuffer<Long> =
LongBuffer(array.copyOf()) LongBuffer(array.copyOf())
@ -28,7 +31,10 @@ inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index. * It should return the value for an buffer element given its index.
*/ */
inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer = LongBuffer(LongArray(size) { init(it) }) inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer {
contract { callsInPlace(init) }
return LongBuffer(LongArray(size) { init(it) })
}
/** /**
* Returns a new [LongBuffer] of given elements. * Returns a new [LongBuffer] of given elements.

View File

@ -14,10 +14,8 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
private val reader: MemoryReader = memory.reader() private val reader: MemoryReader = memory.reader()
override fun get(index: Int): T = reader.read(spec, spec.objectSize * index) override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
override fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
companion object { companion object {
fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> = fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =
@ -48,8 +46,7 @@ class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : Memory
private val writer: MemoryWriter = memory.writer() private val writer: MemoryWriter = memory.writer()
override fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value) override operator fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value)
override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec) override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec)
companion object { companion object {

View File

@ -26,19 +26,20 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement = fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement =
NDField.real(*shape).produce(initializer) NDField.real(*shape).produce(initializer)
inline fun real1D(dim: Int, crossinline initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement =
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement =
real(intArrayOf(dim)) { initializer(it[0]) } real(intArrayOf(dim)) { initializer(it[0]) }
inline fun real2D(
dim1: Int,
dim2: Int,
crossinline initializer: (Int, Int) -> Double = { _, _ -> 0.0 }
): RealNDElement = real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): RealNDElement = inline fun real3D(
real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
fun real3D(
dim1: Int, dim1: Int,
dim2: Int, dim2: Int,
dim3: Int, dim3: Int,
initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 } crossinline initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }
): RealNDElement = real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } ): RealNDElement = real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
@ -72,7 +73,6 @@ fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T): NDElement<T, C, N> = fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T): NDElement<T, C, N> =
context.map(unwrap(), transform).wrap() context.map(unwrap(), transform).wrap()
/** /**
* Element by element application of any operation on elements to the whole [NDElement] * Element by element application of any operation on elements to the whole [NDElement]
*/ */
@ -107,7 +107,6 @@ operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg:
operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T): NDElement<T, F, N> = operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T): NDElement<T, F, N> =
map { value -> arg / value } map { value -> arg / value }
// /** // /**
// * Reverse sum operation // * Reverse sum operation
// */ // */

View File

@ -1,5 +1,7 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
import kotlin.reflect.KClass import kotlin.reflect.KClass
@ -139,9 +141,8 @@ interface MutableNDStructure<T> : NDStructure<T> {
} }
inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) { inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
elements().forEach { (index, oldValue) -> contract { callsInPlace(action) }
this[index] = action(index, oldValue) elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
}
} }
/** /**
@ -200,14 +201,12 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
}.toList() }.toList()
} }
override fun offset(index: IntArray): Int { override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
return index.mapIndexed { i, value -> if (value < 0 || value >= this.shape[i])
if (value < 0 || value >= this.shape[i]) { throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
throw RuntimeException("Index $value out of shape bounds: (0,${this.shape[i]})")
} value * strides[i]
value * strides[i] }.sum()
}.sum()
}
override fun index(offset: Int): IntArray { override fun index(offset: Int): IntArray {
val res = IntArray(shape.size) val res = IntArray(shape.size)
@ -259,7 +258,7 @@ abstract class NDBuffer<T> : NDStructure<T> {
*/ */
abstract val strides: Strides abstract val strides: Strides
override fun get(index: IntArray): T = buffer[strides.offset(index)] override operator fun get(index: IntArray): T = buffer[strides.offset(index)]
override val shape: IntArray get() = strides.shape override val shape: IntArray get() = strides.shape
@ -319,13 +318,13 @@ class MutableBufferNDStructure<T>(
} }
} }
override fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value) override operator fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value)
} }
inline fun <reified T : Any> NDStructure<T>.combine( inline fun <reified T : Any> NDStructure<T>.combine(
struct: NDStructure<T>, struct: NDStructure<T>,
crossinline block: (T, T) -> T crossinline block: (T, T) -> T
): NDStructure<T> { ): NDStructure<T> {
if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination") require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" }
return NDStructure.auto(shape) { block(this[it], struct[it]) } return NDStructure.auto(shape) { block(this[it], struct[it]) }
} }

View File

@ -1,5 +1,8 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
/** /**
* Specialized [MutableBuffer] implementation over [DoubleArray]. * Specialized [MutableBuffer] implementation over [DoubleArray].
* *
@ -8,13 +11,13 @@ package scientifik.kmath.structures
inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> { inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
override val size: Int get() = array.size override val size: Int get() = array.size
override fun get(index: Int): Double = array[index] override operator fun get(index: Int): Double = array[index]
override fun set(index: Int, value: Double) { override operator fun set(index: Int, value: Double) {
array[index] = value array[index] = value
} }
override fun iterator(): DoubleIterator = array.iterator() override operator fun iterator(): DoubleIterator = array.iterator()
override fun copy(): MutableBuffer<Double> = override fun copy(): MutableBuffer<Double> =
RealBuffer(array.copyOf()) RealBuffer(array.copyOf())
@ -27,7 +30,10 @@ inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index. * It should return the value for an buffer element given its index.
*/ */
inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) }) inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer {
contract { callsInPlace(init) }
return RealBuffer(DoubleArray(size) { init(it) })
}
/** /**
* Returns a new [RealBuffer] of given elements. * Returns a new [RealBuffer] of given elements.

View File

@ -10,14 +10,15 @@ import kotlin.math.*
*/ */
object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> { object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } require(b.size == a.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
}
return if (a is RealBuffer && b is RealBuffer) { return if (a is RealBuffer && b is RealBuffer) {
val aArray = a.array val aArray = a.array
val bArray = b.array val bArray = b.array
RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] })
} else } else RealBuffer(DoubleArray(a.size) { a[it] + b[it] })
RealBuffer(DoubleArray(a.size) { a[it] + b[it] })
} }
override fun multiply(a: Buffer<Double>, k: Number): RealBuffer { override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
@ -26,12 +27,13 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
return if (a is RealBuffer) { return if (a is RealBuffer) {
val aArray = a.array val aArray = a.array
RealBuffer(DoubleArray(a.size) { aArray[it] * kValue }) RealBuffer(DoubleArray(a.size) { aArray[it] * kValue })
} else } else RealBuffer(DoubleArray(a.size) { a[it] * kValue })
RealBuffer(DoubleArray(a.size) { a[it] * kValue })
} }
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } require(b.size == a.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
}
return if (a is RealBuffer && b is RealBuffer) { return if (a is RealBuffer && b is RealBuffer) {
val aArray = a.array val aArray = a.array
@ -42,34 +44,31 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
} }
override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } require(b.size == a.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
}
return if (a is RealBuffer && b is RealBuffer) { return if (a is RealBuffer && b is RealBuffer) {
val aArray = a.array val aArray = a.array
val bArray = b.array val bArray = b.array
RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] })
} else } else RealBuffer(DoubleArray(a.size) { a[it] / b[it] })
RealBuffer(DoubleArray(a.size) { a[it] / b[it] })
} }
override fun sin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun sin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { sin(array[it]) }) RealBuffer(DoubleArray(arg.size) { sin(array[it]) })
} else { } else RealBuffer(DoubleArray(arg.size) { sin(arg[it]) })
RealBuffer(DoubleArray(arg.size) { sin(arg[it]) })
}
override fun cos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun cos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { cos(array[it]) }) RealBuffer(DoubleArray(arg.size) { cos(array[it]) })
} else } else RealBuffer(DoubleArray(arg.size) { cos(arg[it]) })
RealBuffer(DoubleArray(arg.size) { cos(arg[it]) })
override fun tan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun tan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { tan(array[it]) }) RealBuffer(DoubleArray(arg.size) { tan(array[it]) })
} else } else RealBuffer(DoubleArray(arg.size) { tan(arg[it]) })
RealBuffer(DoubleArray(arg.size) { tan(arg[it]) })
override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
@ -90,23 +89,50 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
} else } else
RealBuffer(DoubleArray(arg.size) { atan(arg[it]) }) RealBuffer(DoubleArray(arg.size) { atan(arg[it]) })
override fun sinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { sinh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { sinh(arg[it]) })
override fun cosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { cosh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { cosh(arg[it]) })
override fun tanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { tanh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { tanh(arg[it]) })
override fun asinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { asinh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { asinh(arg[it]) })
override fun acosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { acosh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { acosh(arg[it]) })
override fun atanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { atanh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { atanh(arg[it]) })
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) { override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
} else } else RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { exp(array[it]) }) RealBuffer(DoubleArray(arg.size) { exp(array[it]) })
} else } else RealBuffer(DoubleArray(arg.size) { exp(arg[it]) })
RealBuffer(DoubleArray(arg.size) { exp(arg[it]) })
override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { ln(array[it]) }) RealBuffer(DoubleArray(arg.size) { ln(array[it]) })
} else } else RealBuffer(DoubleArray(arg.size) { ln(arg[it]) })
RealBuffer(DoubleArray(arg.size) { ln(arg[it]) })
} }
/** /**
@ -168,6 +194,36 @@ class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
return RealBufferFieldOperations.atan(arg) return RealBufferFieldOperations.atan(arg)
} }
override fun sinh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.sinh(arg)
}
override fun cosh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.cosh(arg)
}
override fun tanh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.tanh(arg)
}
override fun asinh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.asinh(arg)
}
override fun acosh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.acosh(arg)
}
override fun atanh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.atanh(arg)
}
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer { override fun power(arg: Buffer<Double>, pow: Number): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.power(arg, pow) return RealBufferFieldOperations.power(arg, pow)

View File

@ -40,6 +40,7 @@ class RealNDField(override val shape: IntArray) :
transform: RealField.(index: IntArray, Double) -> Double transform: RealField.(index: IntArray, Double) -> Double
): RealNDElement { ): RealNDElement {
check(arg) check(arg)
return BufferedNDFieldElement( return BufferedNDFieldElement(
this, this,
buildBuffer(arg.strides.linearSize) { offset -> buildBuffer(arg.strides.linearSize) { offset ->
@ -71,16 +72,18 @@ class RealNDField(override val shape: IntArray) :
override fun ln(arg: NDBuffer<Double>): RealNDElement = map(arg) { ln(it) } override fun ln(arg: NDBuffer<Double>): RealNDElement = map(arg) { ln(it) }
override fun sin(arg: NDBuffer<Double>): RealNDElement = map(arg) { sin(it) } override fun sin(arg: NDBuffer<Double>): RealNDElement = map(arg) { sin(it) }
override fun cos(arg: NDBuffer<Double>): RealNDElement = map(arg) { cos(it) } override fun cos(arg: NDBuffer<Double>): RealNDElement = map(arg) { cos(it) }
override fun tan(arg: NDBuffer<Double>): RealNDElement = map(arg) { tan(it) }
override fun asin(arg: NDBuffer<Double>): RealNDElement = map(arg) { asin(it) }
override fun acos(arg: NDBuffer<Double>): RealNDElement = map(arg) { acos(it) }
override fun atan(arg: NDBuffer<Double>): RealNDElement = map(arg) { atan(it) }
override fun tan(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { tan(it) } override fun sinh(arg: NDBuffer<Double>): RealNDElement = map(arg) { sinh(it) }
override fun cosh(arg: NDBuffer<Double>): RealNDElement = map(arg) { cosh(it) }
override fun asin(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { asin(it) } override fun tanh(arg: NDBuffer<Double>): RealNDElement = map(arg) { tanh(it) }
override fun asinh(arg: NDBuffer<Double>): RealNDElement = map(arg) { asinh(it) }
override fun acos(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { acos(it) } override fun acosh(arg: NDBuffer<Double>): RealNDElement = map(arg) { acosh(it) }
override fun atanh(arg: NDBuffer<Double>): RealNDElement = map(arg) { atanh(it) }
override fun atan(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { atan(it) }
} }
@ -130,6 +133,5 @@ operator fun RealNDElement.minus(arg: Double): RealNDElement =
/** /**
* Produce a context for n-dimensional operations inside this real field * Produce a context for n-dimensional operations inside this real field
*/ */
inline fun <R> RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R {
return NDField.real(*shape).run(action) inline fun <R> RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R = NDField.real(*shape).run(action)
}

View File

@ -1,5 +1,8 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
/** /**
* Specialized [MutableBuffer] implementation over [ShortArray]. * Specialized [MutableBuffer] implementation over [ShortArray].
* *
@ -8,17 +11,16 @@ package scientifik.kmath.structures
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> { inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
override val size: Int get() = array.size override val size: Int get() = array.size
override fun get(index: Int): Short = array[index] override operator fun get(index: Int): Short = array[index]
override fun set(index: Int, value: Short) { override operator fun set(index: Int, value: Short) {
array[index] = value array[index] = value
} }
override fun iterator(): ShortIterator = array.iterator() override operator fun iterator(): ShortIterator = array.iterator()
override fun copy(): MutableBuffer<Short> = override fun copy(): MutableBuffer<Short> =
ShortBuffer(array.copyOf()) ShortBuffer(array.copyOf())
} }
/** /**
@ -28,7 +30,10 @@ inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index. * It should return the value for an buffer element given its index.
*/ */
inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer = ShortBuffer(ShortArray(size) { init(it) }) inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer {
contract { callsInPlace(init) }
return ShortBuffer(ShortArray(size) { init(it) })
}
/** /**
* Returns a new [ShortBuffer] of given elements. * Returns a new [ShortBuffer] of given elements.

View File

@ -6,12 +6,12 @@ package scientifik.kmath.structures
interface Structure1D<T> : NDStructure<T>, Buffer<T> { interface Structure1D<T> : NDStructure<T>, Buffer<T> {
override val dimension: Int get() = 1 override val dimension: Int get() = 1
override fun get(index: IntArray): T { override operator fun get(index: IntArray): T {
if (index.size != 1) error("Index dimension mismatch. Expected 1 but found ${index.size}") require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" }
return get(index[0]) return get(index[0])
} }
override fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator() override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
} }
/** /**
@ -22,7 +22,7 @@ private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Stru
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
override val size: Int get() = structure.shape[0] override val size: Int get() = structure.shape[0]
override fun get(index: Int): T = structure[index] override operator fun get(index: Int): T = structure[index]
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements() override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
} }
@ -39,7 +39,7 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
override fun elements(): Sequence<Pair<IntArray, T>> = override fun elements(): Sequence<Pair<IntArray, T>> =
asSequence().mapIndexed { index, value -> intArrayOf(index) to value } asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
override fun get(index: Int): T = buffer[index] override operator fun get(index: Int): T = buffer[index]
} }
/** /**

View File

@ -9,8 +9,8 @@ interface Structure2D<T> : NDStructure<T> {
operator fun get(i: Int, j: Int): T operator fun get(i: Int, j: Int): T
override fun get(index: IntArray): T { override operator fun get(index: IntArray): T {
if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}") require(index.size == 2) { "Index dimension mismatch. Expected 2 but found ${index.size}" }
return get(index[0], index[1]) return get(index[0], index[1])
} }
@ -39,10 +39,10 @@ interface Structure2D<T> : NDStructure<T> {
* A 2D wrapper for nd-structure * A 2D wrapper for nd-structure
*/ */
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> { private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
override fun get(i: Int, j: Int): T = structure[i, j]
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
override operator fun get(i: Int, j: Int): T = structure[i, j]
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements() override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
} }

View File

@ -3,6 +3,7 @@ package scientifik.kmath.expressions
import scientifik.kmath.operations.Complex import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.ComplexField
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -10,10 +11,12 @@ class ExpressionFieldTest {
@Test @Test
fun testExpression() { fun testExpression() {
val context = FunctionalExpressionField(RealField) val context = FunctionalExpressionField(RealField)
val expression = with(context) {
val expression = context {
val x = variable("x", 2.0) val x = variable("x", 2.0)
x * x + 2 * x + one x * x + 2 * x + one
} }
assertEquals(expression("x" to 1.0), 4.0) assertEquals(expression("x" to 1.0), 4.0)
assertEquals(expression(), 9.0) assertEquals(expression(), 9.0)
} }
@ -21,10 +24,12 @@ class ExpressionFieldTest {
@Test @Test
fun testComplex() { fun testComplex() {
val context = FunctionalExpressionField(ComplexField) val context = FunctionalExpressionField(ComplexField)
val expression = with(context) {
val expression = context {
val x = variable("x", Complex(2.0, 0.0)) val x = variable("x", Complex(2.0, 0.0))
x * x + 2 * x + one x * x + 2 * x + one
} }
assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0)) assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0))
assertEquals(expression(), Complex(9.0, 0.0)) assertEquals(expression(), Complex(9.0, 0.0))
} }

View File

@ -7,7 +7,6 @@ import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class MatrixTest { class MatrixTest {
@Test @Test
fun testTranspose() { fun testTranspose() {
val matrix = MatrixContext.real.one(3, 3) val matrix = MatrixContext.real.one(3, 3)
@ -51,6 +50,7 @@ class MatrixTest {
fun test2DDot() { fun test2DDot() {
val firstMatrix = NDStructure.auto(2, 3) { (i, j) -> (i + j).toDouble() }.as2D() val firstMatrix = NDStructure.auto(2, 3) { (i, j) -> (i + j).toDouble() }.as2D()
val secondMatrix = NDStructure.auto(3, 2) { (i, j) -> (i + j).toDouble() }.as2D() val secondMatrix = NDStructure.auto(3, 2) { (i, j) -> (i + j).toDouble() }.as2D()
MatrixContext.real.run { MatrixContext.real.run {
// val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() } // val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() }
// val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() } // val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() }

View File

@ -1,9 +1,13 @@
package scientifik.kmath.operations package scientifik.kmath.operations
import scientifik.kmath.operations.internal.RingVerifier
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class BigIntAlgebraTest { internal class BigIntAlgebraTest {
@Test
fun verify() = BigIntField { RingVerifier(this, +"42", +"10", +"-12", 10).verify() }
@Test @Test
fun testKBigIntegerRingSum() { fun testKBigIntegerRingSum() {
val res = BigIntField { val res = BigIntField {

View File

@ -0,0 +1,77 @@
package scientifik.kmath.operations
import scientifik.kmath.operations.internal.FieldVerifier
import kotlin.math.PI
import kotlin.math.abs
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal class ComplexFieldTest {
@Test
fun verify() = ComplexField { FieldVerifier(this, 42.0 * i, 66.0 + 28 * i, 2.0 + 0 * i, 5).verify() }
@Test
fun testAddition() {
assertEquals(Complex(42, 42), ComplexField { Complex(16, 16) + Complex(26, 26) })
assertEquals(Complex(42, 16), ComplexField { Complex(16, 16) + 26 })
assertEquals(Complex(42, 16), ComplexField { 26 + Complex(16, 16) })
}
@Test
fun testSubtraction() {
assertEquals(Complex(42, 42), ComplexField { Complex(86, 55) - Complex(44, 13) })
assertEquals(Complex(42, 56), ComplexField { Complex(86, 56) - 44 })
assertEquals(Complex(42, 56), ComplexField { 86 - Complex(44, -56) })
}
@Test
fun testMultiplication() {
assertEquals(Complex(42, 42), ComplexField { Complex(4.2, 0) * Complex(10, 10) })
assertEquals(Complex(42, 21), ComplexField { Complex(4.2, 2.1) * 10 })
assertEquals(Complex(42, 21), ComplexField { 10 * Complex(4.2, 2.1) })
}
@Test
fun testDivision() {
assertEquals(Complex(42, 42), ComplexField { Complex(0, 168) / Complex(2, 2) })
assertEquals(Complex(42, 56), ComplexField { Complex(86, 56) - 44 })
assertEquals(Complex(42, 56), ComplexField { 86 - Complex(44, -56) })
assertEquals(Complex(Double.NaN, Double.NaN), ComplexField { Complex(1, 1) / Complex(Double.NaN, Double.NaN) })
assertEquals(Complex(Double.NaN, Double.NaN), ComplexField { Complex(1, 1) / Complex(0, 0) })
}
@Test
fun testSine() {
assertEquals(ComplexField { i * sinh(one) }, ComplexField { sin(i) })
assertEquals(ComplexField { i * sinh(PI.toComplex()) }, ComplexField { sin(i * PI.toComplex()) })
}
@Test
fun testInverseSine() {
assertEquals(Complex(0, -0.0), ComplexField { asin(zero) })
assertTrue(abs(ComplexField { i * asinh(one) }.r - ComplexField { asin(i) }.r) < 0.000000000000001)
}
@Test
fun testInverseHyperbolicSine() {
assertEquals(
ComplexField { i * PI.toComplex() / 2 },
ComplexField { asinh(i) })
}
@Test
fun testPower() {
assertEquals(ComplexField.zero, ComplexField { zero pow 2 })
assertEquals(ComplexField.zero, ComplexField { zero pow 2 })
assertEquals(
ComplexField { i * 8 }.let { it.im.toInt() to it.re.toInt() },
ComplexField { Complex(2, 2) pow 2 }.let { it.im.toInt() to it.re.toInt() })
}
@Test
fun testNorm() {
assertEquals(2.toComplex(), ComplexField { norm(2 * i) })
}
}

View File

@ -0,0 +1,38 @@
package scientifik.kmath.operations
import kotlin.test.Test
import kotlin.test.assertEquals
internal class ComplexTest {
@Test
fun conjugate() {
assertEquals(
Complex(0, -42), (ComplexField.i * 42).conjugate
)
}
@Test
fun reciprocal() {
assertEquals(Complex(0.5, -0.0), 2.toComplex().reciprocal)
}
@Test
fun r() {
assertEquals(kotlin.math.sqrt(2.0), (ComplexField.i + 1.0.toComplex()).r)
}
@Test
fun theta() {
assertEquals(0.0, 1.toComplex().theta)
}
@Test
fun toComplex() {
assertEquals(Complex(42, 0), 42.toComplex())
assertEquals(Complex(42.0, 0), 42.0.toComplex())
assertEquals(Complex(42f, 0), 42f.toComplex())
assertEquals(Complex(42.0, 0), 42.0.toComplex())
assertEquals(Complex(42.toByte(), 0), 42.toByte().toComplex())
assertEquals(Complex(42.toShort(), 0), 42.toShort().toComplex())
}
}

View File

@ -1,14 +1,16 @@
package scientifik.kmath.operations package scientifik.kmath.operations
import scientifik.kmath.operations.internal.FieldVerifier
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class RealFieldTest { internal class RealFieldTest {
@Test
fun verify() = FieldVerifier(RealField, 42.0, 66.0, 2.0, 5).verify()
@Test @Test
fun testSqrt() { fun testSqrt() {
val sqrt = RealField { val sqrt = RealField { sqrt(25 * one) }
sqrt(25 * one)
}
assertEquals(5.0, sqrt) assertEquals(5.0, sqrt)
} }
} }

View File

@ -0,0 +1,9 @@
package scientifik.kmath.operations.internal
import scientifik.kmath.operations.Algebra
internal interface AlgebraicVerifier<T, out A> where A : Algebra<T> {
val algebra: A
fun verify()
}

View File

@ -0,0 +1,24 @@
package scientifik.kmath.operations.internal
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.invoke
import kotlin.test.assertEquals
import kotlin.test.assertNotEquals
internal class FieldVerifier<T>(override val algebra: Field<T>, a: T, b: T, c: T, x: Number) :
RingVerifier<T>(algebra, a, b, c, x) {
override fun verify() {
super.verify()
algebra {
assertNotEquals(a / b, b / a, "Division in $algebra is not anti-commutative.")
assertNotEquals((a / b) / c, a / (b / c), "Division in $algebra is associative.")
assertEquals((a + b) / c, (a / c) + (b / c), "Division in $algebra is not right-distributive.")
assertEquals(a, a / one, "$one in $algebra is not neutral division element.")
assertEquals(one, one / a * a, "$algebra does not provide single reciprocal element.")
assertEquals(zero / a, zero, "$zero in $algebra is not left neutral element for division.")
assertEquals(-one, a / (-a), "Division by sign reversal element in $algebra does not give ${-one}.")
}
}
}

View File

@ -0,0 +1,28 @@
package scientifik.kmath.operations.internal
import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.invoke
import kotlin.test.assertEquals
internal open class RingVerifier<T>(override val algebra: Ring<T>, a: T, b: T, c: T, x: Number) :
SpaceVerifier<T>(algebra, a, b, c, x) {
override fun verify() {
super.verify()
algebra {
assertEquals(a * b, a * b, "Multiplication in $algebra is not commutative.")
assertEquals(a * b * c, a * (b * c), "Multiplication in $algebra is not associative.")
assertEquals(c * (a + b), (c * a) + (c * b), "Multiplication in $algebra is not distributive.")
assertEquals(a * one, one * a, "$one in $algebra is not a neutral multiplication element.")
assertEquals(a, one * a, "$one in $algebra is not a neutral multiplication element.")
assertEquals(a, a * one, "$one in $algebra is not a neutral multiplication element.")
assertEquals(a, one * a, "$one in $algebra is not a neutral multiplication element.")
assertEquals(a, a * one * one, "Multiplication by $one in $algebra is not idempotent.")
assertEquals(a, a * one * one * one, "Multiplication by $one in $algebra is not idempotent.")
assertEquals(a, a * one * one * one * one, "Multiplication by $one in $algebra is not idempotent.")
assertEquals(zero, a * zero, "Multiplication by $zero in $algebra doesn't give $zero.")
assertEquals(zero, zero * a, "Multiplication by $zero in $algebra doesn't give $zero.")
assertEquals(a * zero, a * zero, "Multiplication by $zero in $algebra doesn't give $zero.")
}
}
}

View File

@ -0,0 +1,33 @@
package scientifik.kmath.operations.internal
import scientifik.kmath.operations.Space
import scientifik.kmath.operations.invoke
import kotlin.test.assertEquals
import kotlin.test.assertNotEquals
internal open class SpaceVerifier<T>(
override val algebra: Space<T>,
val a: T,
val b: T,
val c: T,
val x: Number
) :
AlgebraicVerifier<T, Space<T>> {
override fun verify() {
algebra {
assertEquals(a + b, b + a, "Addition in $algebra is not commutative.")
assertEquals(a + b + c, a + (b + c), "Addition in $algebra is not associative.")
assertEquals(x * (a + b), x * a + x * b, "Addition in $algebra is not distributive.")
assertEquals((a + b) * x, a * x + b * x, "Addition in $algebra is not distributive.")
assertEquals(a + zero, zero + a, "$zero in $algebra is not a neutral addition element.")
assertEquals(a, a + zero, "$zero in $algebra is not a neutral addition element.")
assertEquals(a, zero + a, "$zero in $algebra is not a neutral addition element.")
assertEquals(a - b, -(b - a), "Subtraction in $algebra is not anti-commutative.")
assertNotEquals(a - b - c, a - (b - c), "Subtraction in $algebra is associative.")
assertEquals(x * (a - b), x * a - x * b, "Subtraction in $algebra is not distributive.")
assertEquals(a, a - zero, "$zero in $algebra is not a neutral addition element.")
assertEquals(a * x, x * a, "Multiplication by scalar in $algebra is not commutative.")
assertEquals(x * (a + b), (x * a) + (x * b), "Multiplication by scalar in $algebra is not distributive.")
}
}
}

View File

@ -1,6 +1,7 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import scientifik.kmath.operations.Norm import scientifik.kmath.operations.Norm
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.NDElement.Companion.real2D import scientifik.kmath.structures.NDElement.Companion.real2D
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.pow import kotlin.math.pow
@ -56,17 +57,12 @@ class NumberNDFieldTest {
} }
object L2Norm : Norm<NDStructure<out Number>, Double> { object L2Norm : Norm<NDStructure<out Number>, Double> {
override fun norm(arg: NDStructure<out Number>): Double { override fun norm(arg: NDStructure<out Number>): Double =
return kotlin.math.sqrt(arg.elements().sumByDouble { it.second.toDouble() }) kotlin.math.sqrt(arg.elements().sumByDouble { it.second.toDouble() })
}
} }
@Test @Test
fun testInternalContext() { fun testInternalContext() {
NDField.real(*array1.shape).run { (NDField.real(*array1.shape)) { with(L2Norm) { 1 + norm(array1) + exp(array2) } }
with(L2Norm) {
1 + norm(array1) + exp(array2)
}
}
} }
} }

View File

@ -17,10 +17,10 @@ object JBigIntegerField : Field<BigInteger> {
override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong()) override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong())
override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b) override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b)
override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b) override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b)
override fun BigInteger.minus(b: BigInteger): BigInteger = this.subtract(b) override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b)
override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger()) override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger())
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b) override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
override fun BigInteger.unaryMinus(): BigInteger = negate() override operator fun BigInteger.unaryMinus(): BigInteger = negate()
} }
/** /**
@ -38,7 +38,7 @@ abstract class JBigDecimalFieldBase internal constructor(val mathContext: MathCo
get() = BigDecimal.ONE get() = BigDecimal.ONE
override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b) override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b)
override fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b) override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b)
override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble()) override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble())
override fun multiply(a: BigDecimal, k: Number): BigDecimal = override fun multiply(a: BigDecimal, k: Number): BigDecimal =
@ -48,8 +48,7 @@ abstract class JBigDecimalFieldBase internal constructor(val mathContext: MathCo
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext) override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext)
override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext) override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext)
override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext) override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext)
override fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext)
} }
/** /**

View File

@ -4,20 +4,27 @@ plugins {
} }
kotlin.sourceSets { kotlin.sourceSets {
all {
with(languageSettings) {
useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts")
useExperimentalAnnotation("kotlinx.coroutines.InternalCoroutinesApi")
useExperimentalAnnotation("kotlinx.coroutines.ExperimentalCoroutinesApi")
useExperimentalAnnotation("kotlinx.coroutines.FlowPreview")
}
}
commonMain { commonMain {
dependencies { dependencies {
api(project(":kmath-core")) api(project(":kmath-core"))
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:${Scientifik.coroutinesVersion}") api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:${Scientifik.coroutinesVersion}")
} }
} }
jvmMain { jvmMain {
dependencies { dependencies { api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Scientifik.coroutinesVersion}") }
api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Scientifik.coroutinesVersion}")
}
} }
jsMain { jsMain {
dependencies { dependencies { api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:${Scientifik.coroutinesVersion}") }
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:${Scientifik.coroutinesVersion}")
}
} }
} }

View File

@ -16,9 +16,9 @@
package scientifik.kmath.chains package scientifik.kmath.chains
import kotlinx.coroutines.InternalCoroutinesApi
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.sync.withLock
@ -37,14 +37,8 @@ interface Chain<out R> : Flow<R> {
*/ */
fun fork(): Chain<R> fun fork(): Chain<R>
@OptIn(InternalCoroutinesApi::class) override suspend fun collect(collector: FlowCollector<R>): Unit =
override suspend fun collect(collector: FlowCollector<R>) { flow { while (true) emit(next()) }.collect(collector)
kotlinx.coroutines.flow.flow {
while (true) {
emit(next())
}
}.collect(collector)
}
companion object companion object
} }
@ -139,9 +133,10 @@ fun <T, R> Chain<T>.map(func: suspend (T) -> R): Chain<R> = object : Chain<R> {
fun <T> Chain<T>.filter(block: (T) -> Boolean): Chain<T> = object : Chain<T> { fun <T> Chain<T>.filter(block: (T) -> Boolean): Chain<T> = object : Chain<T> {
override suspend fun next(): T { override suspend fun next(): T {
var next: T var next: T
do {
next = this@filter.next() do next = this@filter.next()
} while (!block(next)) while (!block(next))
return next return next
} }
@ -159,6 +154,7 @@ fun <T, R> Chain<T>.collect(mapper: suspend (Chain<T>) -> R): Chain<R> = object
fun <T, S, R> Chain<T>.collectWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain<T>) -> R): Chain<R> = fun <T, S, R> Chain<T>.collectWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain<T>) -> R): Chain<R> =
object : Chain<R> { object : Chain<R> {
override suspend fun next(): R = state.mapper(this@collectWithState) override suspend fun next(): R = state.mapper(this@collectWithState)
override fun fork(): Chain<R> = override fun fork(): Chain<R> =
this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper) this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper)
} }
@ -168,6 +164,5 @@ fun <T, S, R> Chain<T>.collectWithState(state: S, stateFork: (S) -> S, mapper: s
*/ */
fun <T, U, R> Chain<T>.zip(other: Chain<U>, block: suspend (T, U) -> R): Chain<R> = object : Chain<R> { fun <T, U, R> Chain<T>.zip(other: Chain<U>, block: suspend (T, U) -> R): Chain<R> = object : Chain<R> {
override suspend fun next(): R = block(this@zip.next(), other.next()) override suspend fun next(): R = block(this@zip.next(), other.next())
override fun fork(): Chain<R> = this@zip.fork().zip(other.fork(), block) override fun fork(): Chain<R> = this@zip.fork().zip(other.fork(), block)
} }

View File

@ -7,15 +7,15 @@ import kotlinx.coroutines.flow.scan
import kotlinx.coroutines.flow.scanReduce import kotlinx.coroutines.flow.scanReduce
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.operations.SpaceOperations import scientifik.kmath.operations.SpaceOperations
import scientifik.kmath.operations.invoke
@ExperimentalCoroutinesApi @ExperimentalCoroutinesApi
fun <T> Flow<T>.cumulativeSum(space: SpaceOperations<T>): Flow<T> = with(space) { fun <T> Flow<T>.cumulativeSum(space: SpaceOperations<T>): Flow<T> = space {
scanReduce { sum: T, element: T -> sum + element } scanReduce { sum: T, element: T -> sum + element }
} }
@ExperimentalCoroutinesApi @ExperimentalCoroutinesApi
fun <T> Flow<T>.mean(space: Space<T>): Flow<T> = with(space) { fun <T> Flow<T>.mean(space: Space<T>): Flow<T> = space {
class Accumulator(var sum: T, var num: Int) class Accumulator(var sum: T, var num: Int)
scan(Accumulator(zero, 0)) { sum, element -> scan(Accumulator(zero, 0)) { sum, element ->

View File

@ -3,6 +3,7 @@ package scientifik.kmath.coroutines
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlinx.coroutines.channels.produce import kotlinx.coroutines.channels.produce
import kotlinx.coroutines.flow.* import kotlinx.coroutines.flow.*
import kotlin.contracts.contract
val Dispatchers.Math: CoroutineDispatcher val Dispatchers.Math: CoroutineDispatcher
get() = Default get() = Default
@ -23,15 +24,11 @@ internal class LazyDeferred<T>(val dispatcher: CoroutineDispatcher, val block: s
} }
class AsyncFlow<T> internal constructor(internal val deferredFlow: Flow<LazyDeferred<T>>) : Flow<T> { class AsyncFlow<T> internal constructor(internal val deferredFlow: Flow<LazyDeferred<T>>) : Flow<T> {
@InternalCoroutinesApi
override suspend fun collect(collector: FlowCollector<T>) { override suspend fun collect(collector: FlowCollector<T>) {
deferredFlow.collect { deferredFlow.collect { collector.emit((it.await())) }
collector.emit((it.await()))
}
} }
} }
@FlowPreview
fun <T, R> Flow<T>.async( fun <T, R> Flow<T>.async(
dispatcher: CoroutineDispatcher = Dispatchers.Default, dispatcher: CoroutineDispatcher = Dispatchers.Default,
block: suspend CoroutineScope.(T) -> R block: suspend CoroutineScope.(T) -> R
@ -42,7 +39,6 @@ fun <T, R> Flow<T>.async(
return AsyncFlow(flow) return AsyncFlow(flow)
} }
@FlowPreview
fun <T, R> AsyncFlow<T>.map(action: (T) -> R): AsyncFlow<R> = fun <T, R> AsyncFlow<T>.map(action: (T) -> R): AsyncFlow<R> =
AsyncFlow(deferredFlow.map { input -> AsyncFlow(deferredFlow.map { input ->
//TODO add function composition //TODO add function composition
@ -52,10 +48,9 @@ fun <T, R> AsyncFlow<T>.map(action: (T) -> R): AsyncFlow<R> =
} }
}) })
@ExperimentalCoroutinesApi
@FlowPreview
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<T>) { suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<T>) {
require(concurrency >= 1) { "Buffer size should be more than 1, but was $concurrency" } require(concurrency >= 1) { "Buffer size should be more than 1, but was $concurrency" }
coroutineScope { coroutineScope {
//Starting up to N deferred coroutines ahead of time //Starting up to N deferred coroutines ahead of time
val channel = produce(capacity = concurrency - 1) { val channel = produce(capacity = concurrency - 1) {
@ -81,21 +76,18 @@ suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<
} }
} }
@ExperimentalCoroutinesApi suspend inline fun <T> AsyncFlow<T>.collect(concurrency: Int, crossinline action: suspend (value: T) -> Unit) {
@FlowPreview contract { callsInPlace(action) }
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, action: suspend (value: T) -> Unit) {
collect(concurrency, object : FlowCollector<T> { collect(concurrency, object : FlowCollector<T> {
override suspend fun emit(value: T): Unit = action(value) override suspend fun emit(value: T): Unit = action(value)
}) })
} }
@ExperimentalCoroutinesApi inline fun <T, R> Flow<T>.mapParallel(
@FlowPreview
fun <T, R> Flow<T>.mapParallel(
dispatcher: CoroutineDispatcher = Dispatchers.Default, dispatcher: CoroutineDispatcher = Dispatchers.Default,
transform: suspend (T) -> R crossinline transform: suspend (T) -> R
): Flow<R> { ): Flow<R> {
return flatMapMerge { value -> contract { callsInPlace(transform) }
flow { emit(transform(value)) } return flatMapMerge { value -> flow { emit(transform(value)) } }.flowOn(dispatcher)
}.flowOn(dispatcher)
} }

View File

@ -20,7 +20,7 @@ class RingBuffer<T>(
override var size: Int = size override var size: Int = size
private set private set
override fun get(index: Int): T { override operator fun get(index: Int): T {
require(index >= 0) { "Index must be positive" } require(index >= 0) { "Index must be positive" }
require(index < size) { "Index $index is out of circular buffer size $size" } require(index < size) { "Index $index is out of circular buffer size $size" }
return buffer[startIndex.forward(index)] as T return buffer[startIndex.forward(index)] as T
@ -31,15 +31,13 @@ class RingBuffer<T>(
/** /**
* Iterator could provide wrong results if buffer is changed in initialization (iteration is safe) * Iterator could provide wrong results if buffer is changed in initialization (iteration is safe)
*/ */
override fun iterator(): Iterator<T> = object : AbstractIterator<T>() { override operator fun iterator(): Iterator<T> = object : AbstractIterator<T>() {
private var count = size private var count = size
private var index = startIndex private var index = startIndex
val copy = buffer.copy() val copy = buffer.copy()
override fun computeNext() { override fun computeNext() {
if (count == 0) { if (count == 0) done() else {
done()
} else {
setNext(copy[index] as T) setNext(copy[index] as T)
index = index.forward(1) index = index.forward(1)
count-- count--

View File

@ -1,7 +1,6 @@
package scientifik.kmath.chains package scientifik.kmath.chains
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import kotlin.sequences.Sequence
/** /**
* Represent a chain as regular iterator (uses blocking calls) * Represent a chain as regular iterator (uses blocking calls)
@ -15,6 +14,4 @@ operator fun <R> Chain<R>.iterator(): Iterator<R> = object : Iterator<R> {
/** /**
* Represent a chain as a sequence * Represent a chain as a sequence
*/ */
fun <R> Chain<R>.asSequence(): Sequence<R> = object : Sequence<R> { fun <R> Chain<R>.asSequence(): Sequence<R> = Sequence { this@asSequence.iterator() }
override fun iterator(): Iterator<R> = this@asSequence.iterator()
}

View File

@ -18,7 +18,7 @@ class LazyNDStructure<T>(
suspend fun await(index: IntArray): T = deferred(index).await() suspend fun await(index: IntArray): T = deferred(index).await()
override fun get(index: IntArray): T = runBlocking { override operator fun get(index: IntArray): T = runBlocking {
deferred(index).await() deferred(index).await()
} }
@ -52,10 +52,12 @@ suspend fun <T> NDStructure<T>.await(index: IntArray): T =
/** /**
* PENDING would benefit from KEEP-176 * PENDING would benefit from KEEP-176
*/ */
fun <T, R> NDStructure<T>.mapAsyncIndexed( inline fun <T, R> NDStructure<T>.mapAsyncIndexed(
scope: CoroutineScope, scope: CoroutineScope,
function: suspend (T, index: IntArray) -> R crossinline function: suspend (T, index: IntArray) -> R
): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index), index) } ): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index), index) }
fun <T, R> NDStructure<T>.mapAsync(scope: CoroutineScope, function: suspend (T) -> R): LazyNDStructure<R> = inline fun <T, R> NDStructure<T>.mapAsync(
LazyNDStructure(scope, shape) { index -> function(get(index)) } scope: CoroutineScope,
crossinline function: suspend (T) -> R
): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index)) }

View File

@ -4,7 +4,9 @@ import scientifik.kmath.linear.GenericMatrixContext
import scientifik.kmath.linear.MatrixContext import scientifik.kmath.linear.MatrixContext
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
import scientifik.kmath.linear.transpose import scientifik.kmath.linear.transpose
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.Structure2D import scientifik.kmath.structures.Structure2D
@ -42,7 +44,7 @@ inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
val structure: Structure2D<T> val structure: Structure2D<T>
) : DMatrix<T, R, C> { ) : DMatrix<T, R, C> {
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
override fun get(i: Int, j: Int): T = structure[i, j] override operator fun get(i: Int, j: Int): T = structure[i, j]
} }
/** /**
@ -70,9 +72,9 @@ inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) :
DPoint<T, D> { DPoint<T, D> {
override val size: Int get() = point.size override val size: Int get() = point.size
override fun get(index: Int): T = point[index] override operator fun get(index: Int): T = point[index]
override fun iterator(): Iterator<T> = point.iterator() override operator fun iterator(): Iterator<T> = point.iterator()
} }
@ -82,12 +84,14 @@ inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) :
inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixContext<T, Ri>) { inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixContext<T, Ri>) {
inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> { inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
if (rowNum != Dimension.dim<R>().toInt()) { check(
error("Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum") rowNum == Dimension.dim<R>().toInt()
} ) { "Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum" }
if (colNum != Dimension.dim<C>().toInt()) {
error("Column number mismatch: expected ${Dimension.dim<C>()} but found $colNum") check(
} colNum == Dimension.dim<C>().toInt()
) { "Column number mismatch: expected ${Dimension.dim<C>()} but found $colNum" }
return DMatrix.coerceUnsafe(this) return DMatrix.coerceUnsafe(this)
} }
@ -97,11 +101,12 @@ inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixCon
inline fun <reified R : Dimension, reified C : Dimension> produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix<T, R, C> { inline fun <reified R : Dimension, reified C : Dimension> produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix<T, R, C> {
val rows = Dimension.dim<R>() val rows = Dimension.dim<R>()
val cols = Dimension.dim<C>() val cols = Dimension.dim<C>()
return context.produce(rows.toInt(), cols.toInt(), initializer).coerce<R,C>() return context.produce(rows.toInt(), cols.toInt(), initializer).coerce<R, C>()
} }
inline fun <reified D : Dimension> point(noinline initializer: (Int) -> T): DPoint<T, D> { inline fun <reified D : Dimension> point(noinline initializer: (Int) -> T): DPoint<T, D> {
val size = Dimension.dim<D>() val size = Dimension.dim<D>()
return DPoint.coerceUnsafe( return DPoint.coerceUnsafe(
context.point( context.point(
size.toInt(), size.toInt(),
@ -112,37 +117,28 @@ inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixCon
inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot( inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
other: DMatrix<T, C1, C2> other: DMatrix<T, C1, C2>
): DMatrix<T, R1, C2> { ): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce()
return context.run { this@dot dot other }.coerce()
}
inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> { inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> =
return DPoint.coerceUnsafe(context.run { this@dot dot vector }) DPoint.coerceUnsafe(context { this@dot dot vector })
}
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.times(value: T): DMatrix<T, R, C> { inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.times(value: T): DMatrix<T, R, C> =
return context.run { this@times.times(value) }.coerce() context { this@times.times(value) }.coerce()
}
inline operator fun <reified R : Dimension, reified C : Dimension> T.times(m: DMatrix<T, R, C>): DMatrix<T, R, C> = inline operator fun <reified R : Dimension, reified C : Dimension> T.times(m: DMatrix<T, R, C>): DMatrix<T, R, C> =
m * this m * this
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.plus(other: DMatrix<T, C, R>): DMatrix<T, C, R> =
context { this@plus + other }.coerce()
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.plus(other: DMatrix<T, C, R>): DMatrix<T, C, R> { inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.minus(other: DMatrix<T, C, R>): DMatrix<T, C, R> =
return context.run { this@plus + other }.coerce() context { this@minus + other }.coerce()
}
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.minus(other: DMatrix<T, C, R>): DMatrix<T, C, R> { inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.unaryMinus(): DMatrix<T, C, R> =
return context.run { this@minus + other }.coerce() context { this@unaryMinus.unaryMinus() }.coerce()
}
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.unaryMinus(): DMatrix<T, C, R> { inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> =
return context.run { this@unaryMinus.unaryMinus() }.coerce() context { (this@transpose as Matrix<T>).transpose() }.coerce()
}
inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> {
return context.run { (this@transpose as Matrix<T>).transpose() }.coerce()
}
/** /**
* A square unit matrix * A square unit matrix
@ -156,6 +152,6 @@ inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixCon
} }
companion object { companion object {
val real = DMatrixContext(MatrixContext.real) val real: DMatrixContext<Double, RealField> = DMatrixContext(MatrixContext.real)
} }
} }

View File

@ -5,11 +5,10 @@ import scientifik.kmath.dimensions.D3
import scientifik.kmath.dimensions.DMatrixContext import scientifik.kmath.dimensions.DMatrixContext
import kotlin.test.Test import kotlin.test.Test
class DMatrixContextTest { class DMatrixContextTest {
@Test @Test
fun testDimensionSafeMatrix() { fun testDimensionSafeMatrix() {
val res = DMatrixContext.real.run { val res = with(DMatrixContext.real) {
val m = produce<D2, D2> { i, j -> (i + j).toDouble() } val m = produce<D2, D2> { i, j -> (i + j).toDouble() }
//The dimension of `one()` is inferred from type //The dimension of `one()` is inferred from type
@ -19,7 +18,7 @@ class DMatrixContextTest {
@Test @Test
fun testTypeCheck() { fun testTypeCheck() {
val res = DMatrixContext.real.run { val res = with(DMatrixContext.real) {
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() } val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() } val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }

View File

@ -1,11 +1,6 @@
plugins { plugins { id("scientifik.mpp") }
id("scientifik.mpp")
}
kotlin.sourceSets { kotlin.sourceSets {
commonMain { all { languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") }
dependencies { commonMain { dependencies { api(project(":kmath-core")) } }
api(project(":kmath-core"))
}
}
} }

View File

@ -14,8 +14,8 @@ import kotlin.math.sqrt
typealias RealPoint = Point<Double> typealias RealPoint = Point<Double>
fun DoubleArray.asVector() = RealVector(this.asBuffer()) fun DoubleArray.asVector(): RealVector = RealVector(this.asBuffer())
fun List<Double>.asVector() = RealVector(this.asBuffer()) fun List<Double>.asVector(): RealVector = RealVector(this.asBuffer())
object VectorL2Norm : Norm<Point<out Number>, Double> { object VectorL2Norm : Norm<Point<out Number>, Double> {
override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() }) override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() })
@ -32,15 +32,14 @@ inline class RealVector(private val point: Point<Double>) :
override val size: Int get() = point.size override val size: Int get() = point.size
override fun get(index: Int): Double = point[index] override operator fun get(index: Int): Double = point[index]
override fun iterator(): Iterator<Double> = point.iterator() override operator fun iterator(): Iterator<Double> = point.iterator()
companion object { companion object {
private val spaceCache: MutableMap<Int, BufferVectorSpace<Double, RealField>> = hashMapOf()
private val spaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>() inline operator fun invoke(dim: Int, initializer: (Int) -> Double): RealVector =
inline operator fun invoke(dim: Int, initializer: (Int) -> Double) =
RealVector(RealBuffer(dim, initializer)) RealVector(RealBuffer(dim, initializer))
operator fun invoke(vararg values: Double): RealVector = values.asVector() operator fun invoke(vararg values: Double): RealVector = values.asVector()

View File

@ -3,11 +3,14 @@ package scientifik.kmath.real
import scientifik.kmath.linear.MatrixContext import scientifik.kmath.linear.MatrixContext
import scientifik.kmath.linear.RealMatrixContext.elementContext import scientifik.kmath.linear.RealMatrixContext.elementContext
import scientifik.kmath.linear.VirtualMatrix import scientifik.kmath.linear.VirtualMatrix
import scientifik.kmath.operations.invoke
import scientifik.kmath.operations.sum import scientifik.kmath.operations.sum
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.RealBuffer
import scientifik.kmath.structures.asIterable import scientifik.kmath.structures.asIterable
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.math.pow import kotlin.math.pow
/* /*
@ -27,7 +30,7 @@ typealias RealMatrix = Matrix<Double>
fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix =
MatrixContext.real.produce(rowNum, colNum, initializer) MatrixContext.real.produce(rowNum, colNum, initializer)
fun Array<DoubleArray>.toMatrix(): RealMatrix{ fun Array<DoubleArray>.toMatrix(): RealMatrix {
return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] } return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] }
} }
@ -117,13 +120,16 @@ operator fun Matrix<Double>.minus(other: Matrix<Double>): RealMatrix =
* Operations on columns * Operations on columns
*/ */
inline fun Matrix<Double>.appendColumn(crossinline mapper: (Buffer<Double>) -> Double) = inline fun Matrix<Double>.appendColumn(crossinline mapper: (Buffer<Double>) -> Double): Matrix<Double> {
MatrixContext.real.produce(rowNum, colNum + 1) { row, col -> contract { callsInPlace(mapper) }
return MatrixContext.real.produce(rowNum, colNum + 1) { row, col ->
if (col < colNum) if (col < colNum)
this[row, col] this[row, col]
else else
mapper(rows[row]) mapper(rows[row])
} }
}
fun Matrix<Double>.extractColumns(columnRange: IntRange): RealMatrix = fun Matrix<Double>.extractColumns(columnRange: IntRange): RealMatrix =
MatrixContext.real.produce(rowNum, columnRange.count()) { row, col -> MatrixContext.real.produce(rowNum, columnRange.count()) { row, col ->
@ -135,17 +141,15 @@ fun Matrix<Double>.extractColumn(columnIndex: Int): RealMatrix =
fun Matrix<Double>.sumByColumn(): RealBuffer = RealBuffer(colNum) { j -> fun Matrix<Double>.sumByColumn(): RealBuffer = RealBuffer(colNum) { j ->
val column = columns[j] val column = columns[j]
with(elementContext) { elementContext { sum(column.asIterable()) }
sum(column.asIterable())
}
} }
fun Matrix<Double>.minByColumn(): RealBuffer = RealBuffer(colNum) { j -> fun Matrix<Double>.minByColumn(): RealBuffer = RealBuffer(colNum) { j ->
columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column") columns[j].asIterable().min() ?: error("Cannot produce min on empty column")
} }
fun Matrix<Double>.maxByColumn(): RealBuffer = RealBuffer(colNum) { j -> fun Matrix<Double>.maxByColumn(): RealBuffer = RealBuffer(colNum) { j ->
columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column") columns[j].asIterable().max() ?: error("Cannot produce min on empty column")
} }
fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j -> fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j ->
@ -156,10 +160,7 @@ fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j ->
* Operations processing all elements * Operations processing all elements
*/ */
fun Matrix<Double>.sum() = elements().map { (_, value) -> value }.sum() fun Matrix<Double>.sum(): Double = elements().map { (_, value) -> value }.sum()
fun Matrix<Double>.min(): Double? = elements().map { (_, value) -> value }.min()
fun Matrix<Double>.min() = elements().map { (_, value) -> value }.min() fun Matrix<Double>.max(): Double? = elements().map { (_, value) -> value }.max()
fun Matrix<Double>.average(): Double = elements().map { (_, value) -> value }.average()
fun Matrix<Double>.max() = elements().map { (_, value) -> value }.max()
fun Matrix<Double>.average() = elements().map { (_, value) -> value }.average()

View File

@ -1,5 +1,6 @@
package scientifik.kmath.linear package scientifik.kmath.linear
import scientifik.kmath.operations.invoke
import scientifik.kmath.real.RealVector import scientifik.kmath.real.RealVector
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -24,14 +25,10 @@ class VectorTest {
fun testDot() { fun testDot() {
val vector1 = RealVector(5) { it.toDouble() } val vector1 = RealVector(5) { it.toDouble() }
val vector2 = RealVector(5) { 5 - it.toDouble() } val vector2 = RealVector(5) { 5 - it.toDouble() }
val matrix1 = vector1.asMatrix() val matrix1 = vector1.asMatrix()
val matrix2 = vector2.asMatrix().transpose() val matrix2 = vector2.asMatrix().transpose()
val product = MatrixContext.real.run { matrix1 dot matrix2 } val product = MatrixContext.real { matrix1 dot matrix2 }
assertEquals(5.0, product[1, 0]) assertEquals(5.0, product[1, 0])
assertEquals(6.0, product[2, 2]) assertEquals(6.0, product[2, 2])
} }
} }

View File

@ -1,11 +1,6 @@
plugins { plugins { id("scientifik.mpp") }
id("scientifik.mpp")
}
kotlin.sourceSets { kotlin.sourceSets {
commonMain { all { languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") }
dependencies { commonMain { dependencies { api(project(":kmath-core")) } }
api(project(":kmath-core"))
}
}
} }

View File

@ -2,6 +2,10 @@ package scientifik.kmath.functions
import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.operations.invoke
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.math.max import kotlin.math.max
import kotlin.math.pow import kotlin.math.pow
@ -13,20 +17,21 @@ inline class Polynomial<T : Any>(val coefficients: List<T>) {
constructor(vararg coefficients: T) : this(coefficients.toList()) constructor(vararg coefficients: T) : this(coefficients.toList())
} }
fun Polynomial<Double>.value() = fun Polynomial<Double>.value(): Double =
coefficients.reduceIndexed { index: Int, acc: Double, d: Double -> acc + d.pow(index) } coefficients.reduceIndexed { index: Int, acc: Double, d: Double -> acc + d.pow(index) }
fun <T : Any, C : Ring<T>> Polynomial<T>.value(ring: C, arg: T): T = ring {
fun <T : Any, C : Ring<T>> Polynomial<T>.value(ring: C, arg: T): T = ring.run { if (coefficients.isEmpty()) return@ring zero
if (coefficients.isEmpty()) return@run zero
var res = coefficients.first() var res = coefficients.first()
var powerArg = arg var powerArg = arg
for (index in 1 until coefficients.size) { for (index in 1 until coefficients.size) {
res += coefficients[index] * powerArg res += coefficients[index] * powerArg
//recalculating power on each step to avoid power costs on long polynomials //recalculating power on each step to avoid power costs on long polynomials
powerArg *= arg powerArg *= arg
} }
return@run res
res
} }
/** /**
@ -34,7 +39,7 @@ fun <T : Any, C : Ring<T>> Polynomial<T>.value(ring: C, arg: T): T = ring.run {
*/ */
fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, out C, T> = object : fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, out C, T> = object :
MathFunction<T, C, T> { MathFunction<T, C, T> {
override fun C.invoke(arg: T): T = value(this, arg) override operator fun C.invoke(arg: T): T = value(this, arg)
} }
/** /**
@ -49,18 +54,16 @@ class PolynomialSpace<T : Any, C : Ring<T>>(val ring: C) : Space<Polynomial<T>>
override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> { override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> {
val dim = max(a.coefficients.size, b.coefficients.size) val dim = max(a.coefficients.size, b.coefficients.size)
ring.run {
return Polynomial(List(dim) { index -> return ring {
Polynomial(List(dim) { index ->
a.coefficients.getOrElse(index) { zero } + b.coefficients.getOrElse(index) { zero } a.coefficients.getOrElse(index) { zero } + b.coefficients.getOrElse(index) { zero }
}) })
} }
} }
override fun multiply(a: Polynomial<T>, k: Number): Polynomial<T> { override fun multiply(a: Polynomial<T>, k: Number): Polynomial<T> =
ring.run { ring { Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * k }) }
return Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * k })
}
}
override val zero: Polynomial<T> = override val zero: Polynomial<T> =
Polynomial(emptyList()) Polynomial(emptyList())
@ -68,6 +71,7 @@ class PolynomialSpace<T : Any, C : Ring<T>>(val ring: C) : Space<Polynomial<T>>
operator fun Polynomial<T>.invoke(arg: T): T = value(ring, arg) operator fun Polynomial<T>.invoke(arg: T): T = value(ring, arg)
} }
fun <T : Any, C : Ring<T>, R> C.polynomial(block: PolynomialSpace<T, C>.() -> R): R { inline fun <T : Any, C : Ring<T>, R> C.polynomial(block: PolynomialSpace<T, C>.() -> R): R {
return PolynomialSpace(this).run(block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return PolynomialSpace(this).block()
} }

View File

@ -4,13 +4,13 @@ import scientifik.kmath.functions.OrderedPiecewisePolynomial
import scientifik.kmath.functions.PiecewisePolynomial import scientifik.kmath.functions.PiecewisePolynomial
import scientifik.kmath.functions.Polynomial import scientifik.kmath.functions.Polynomial
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.invoke
/** /**
* Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java * Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java
*/ */
class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T>) : PolynomialInterpolator<T> { class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T>) : PolynomialInterpolator<T> {
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra {
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra.run {
require(points.size > 0) { "Point array should not be empty" } require(points.size > 0) { "Point array should not be empty" }
insureSorted(points) insureSorted(points)

View File

@ -4,6 +4,7 @@ import scientifik.kmath.functions.OrderedPiecewisePolynomial
import scientifik.kmath.functions.PiecewisePolynomial import scientifik.kmath.functions.PiecewisePolynomial
import scientifik.kmath.functions.Polynomial import scientifik.kmath.functions.Polynomial
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.MutableBufferFactory import scientifik.kmath.structures.MutableBufferFactory
/** /**
@ -17,7 +18,7 @@ class SplineInterpolator<T : Comparable<T>>(
//TODO possibly optimize zeroed buffers //TODO possibly optimize zeroed buffers
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra.run { override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra {
if (points.size < 3) { if (points.size < 3) {
error("Can't use spline interpolator with less than 3 points") error("Can't use spline interpolator with less than 3 points")
} }

View File

@ -14,9 +14,7 @@ interface XYZPointSet<X, Y, Z> : XYPointSet<X, Y> {
} }
internal fun <T : Comparable<T>> insureSorted(points: XYPointSet<T, *>) { internal fun <T : Comparable<T>> insureSorted(points: XYPointSet<T, *>) {
for (i in 0 until points.size - 1) { for (i in 0 until points.size - 1) require(points.x[i + 1] > points.x[i]) { "Input data is not sorted at index $i" }
if (points.x[i + 1] <= points.x[i]) error("Input data is not sorted at index $i")
}
} }
class NDStructureColumn<T>(val structure: Structure2D<T>, val column: Int) : Buffer<T> { class NDStructureColumn<T>(val structure: Structure2D<T>, val column: Int) : Buffer<T> {
@ -26,9 +24,9 @@ class NDStructureColumn<T>(val structure: Structure2D<T>, val column: Int) : Buf
override val size: Int get() = structure.rowNum override val size: Int get() = structure.rowNum
override fun get(index: Int): T = structure[index, column] override operator fun get(index: Int): T = structure[index, column]
override fun iterator(): Iterator<T> = sequence { override operator fun iterator(): Iterator<T> = sequence {
repeat(size) { repeat(size) {
yield(get(it)) yield(get(it))
} }

View File

@ -9,25 +9,21 @@ import kotlin.math.sqrt
interface Vector2D : Point<Double>, Vector, SpaceElement<Vector2D, Vector2D, Euclidean2DSpace> { interface Vector2D : Point<Double>, Vector, SpaceElement<Vector2D, Vector2D, Euclidean2DSpace> {
val x: Double val x: Double
val y: Double val y: Double
override val context: Euclidean2DSpace get() = Euclidean2DSpace
override val size: Int get() = 2 override val size: Int get() = 2
override fun get(index: Int): Double = when (index) { override operator fun get(index: Int): Double = when (index) {
1 -> x 1 -> x
2 -> y 2 -> y
else -> error("Accessing outside of point bounds") else -> error("Accessing outside of point bounds")
} }
override fun iterator(): Iterator<Double> = listOf(x, y).iterator() override operator fun iterator(): Iterator<Double> = listOf(x, y).iterator()
override val context: Euclidean2DSpace get() = Euclidean2DSpace
override fun unwrap(): Vector2D = this override fun unwrap(): Vector2D = this
override fun Vector2D.wrap(): Vector2D = this override fun Vector2D.wrap(): Vector2D = this
} }
val Vector2D.r: Double get() = Euclidean2DSpace.run { sqrt(norm()) } val Vector2D.r: Double get() = Euclidean2DSpace { sqrt(norm()) }
@Suppress("FunctionName") @Suppress("FunctionName")
fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y) fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y)

View File

@ -2,6 +2,7 @@ package scientifik.kmath.geometry
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
import scientifik.kmath.operations.SpaceElement import scientifik.kmath.operations.SpaceElement
import scientifik.kmath.operations.invoke
import kotlin.math.sqrt import kotlin.math.sqrt
@ -9,19 +10,17 @@ interface Vector3D : Point<Double>, Vector, SpaceElement<Vector3D, Vector3D, Euc
val x: Double val x: Double
val y: Double val y: Double
val z: Double val z: Double
override val context: Euclidean3DSpace get() = Euclidean3DSpace
override val size: Int get() = 3 override val size: Int get() = 3
override fun get(index: Int): Double = when (index) { override operator fun get(index: Int): Double = when (index) {
1 -> x 1 -> x
2 -> y 2 -> y
3 -> z 3 -> z
else -> error("Accessing outside of point bounds") else -> error("Accessing outside of point bounds")
} }
override fun iterator(): Iterator<Double> = listOf(x, y, z).iterator() override operator fun iterator(): Iterator<Double> = listOf(x, y, z).iterator()
override val context: Euclidean3DSpace get() = Euclidean3DSpace
override fun unwrap(): Vector3D = this override fun unwrap(): Vector3D = this
@ -31,7 +30,7 @@ interface Vector3D : Point<Double>, Vector, SpaceElement<Vector3D, Vector3D, Euc
@Suppress("FunctionName") @Suppress("FunctionName")
fun Vector3D(x: Double, y: Double, z: Double): Vector3D = Vector3DImpl(x, y, z) fun Vector3D(x: Double, y: Double, z: Double): Vector3D = Vector3DImpl(x, y, z)
val Vector3D.r: Double get() = Euclidean3DSpace.run { sqrt(norm()) } val Vector3D.r: Double get() = Euclidean3DSpace { sqrt(norm()) }
private data class Vector3DImpl( private data class Vector3DImpl(
override val x: Double, override val x: Double,

View File

@ -4,6 +4,9 @@ import scientifik.kmath.domains.Domain
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
import scientifik.kmath.structures.ArrayBuffer import scientifik.kmath.structures.ArrayBuffer
import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.RealBuffer
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/** /**
* The bin in the histogram. The histogram is by definition always done in the real space * The bin in the histogram. The histogram is by definition always done in the real space
@ -37,20 +40,20 @@ interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
*/ */
fun putWithWeight(point: Point<out T>, weight: Double) fun putWithWeight(point: Point<out T>, weight: Double)
fun put(point: Point<out T>) = putWithWeight(point, 1.0) fun put(point: Point<out T>): Unit = putWithWeight(point, 1.0)
} }
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point)) fun <T : Any> MutableHistogram<T, *>.put(vararg point: T): Unit = put(ArrayBuffer(point))
fun MutableHistogram<Double, *>.put(vararg point: Number) = fun MutableHistogram<Double, *>.put(vararg point: Number): Unit =
put(RealBuffer(point.map { it.toDouble() }.toDoubleArray())) put(RealBuffer(point.map { it.toDouble() }.toDoubleArray()))
fun MutableHistogram<Double, *>.put(vararg point: Double) = put(RealBuffer(point)) fun MutableHistogram<Double, *>.put(vararg point: Double): Unit = put(RealBuffer(point))
fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) } fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>): Unit = sequence.forEach { put(it) }
/** /**
* Pass a sequence builder into histogram * Pass a sequence builder into histogram
*/ */
fun <T : Any> MutableHistogram<T, *>.fill(buider: suspend SequenceScope<Point<T>>.() -> Unit) = fun <T : Any> MutableHistogram<T, *>.fill(block: suspend SequenceScope<Point<T>>.() -> Unit): Unit =
fill(sequence(buider).asIterable()) fill(sequence(block).asIterable())

View File

@ -2,6 +2,7 @@ package scientifik.kmath.histogram
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
import scientifik.kmath.operations.SpaceOperations import scientifik.kmath.operations.SpaceOperations
import scientifik.kmath.operations.invoke
import scientifik.kmath.real.asVector import scientifik.kmath.real.asVector
import scientifik.kmath.structures.* import scientifik.kmath.structures.*
import kotlin.math.floor import kotlin.math.floor
@ -9,19 +10,16 @@ import kotlin.math.floor
data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val center: Point<T>, val sizes: Point<T>) { data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val center: Point<T>, val sizes: Point<T>) {
fun contains(vector: Point<out T>): Boolean { fun contains(vector: Point<out T>): Boolean {
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}") require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" }
val upper = space.run { center + sizes / 2.0 } val upper = space { center + sizes / 2.0 }
val lower = space.run { center - sizes / 2.0 } val lower = space { center - sizes / 2.0 }
return vector.asSequence().mapIndexed { i, value -> return vector.asSequence().mapIndexed { i, value -> value in lower[i]..upper[i] }.all { it }
value in lower[i]..upper[i]
}.all { it }
} }
} }
class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> { class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> {
override operator fun contains(point: Point<T>): Boolean = def.contains(point)
override fun contains(point: Point<T>): Boolean = def.contains(point)
override val dimension: Int override val dimension: Int
get() = def.center.size get() = def.center.size
@ -39,47 +37,34 @@ class RealHistogram(
private val upper: Buffer<Double>, private val upper: Buffer<Double>,
private val binNums: IntArray = IntArray(lower.size) { 20 } private val binNums: IntArray = IntArray(lower.size) { 20 }
) : MutableHistogram<Double, MultivariateBin<Double>> { ) : MutableHistogram<Double, MultivariateBin<Double>> {
private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 }) private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 })
private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() } private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() }
private val weights: NDStructure<DoubleCounter> = NDStructure.auto(strides) { DoubleCounter() } private val weights: NDStructure<DoubleCounter> = NDStructure.auto(strides) { DoubleCounter() }
override val dimension: Int get() = lower.size override val dimension: Int get() = lower.size
private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
init { init {
// argument checks // argument checks
if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.") require(lower.size == upper.size) { "Dimension mismatch in histogram lower and upper limits." }
if (lower.size != binNums.size) error("Dimension mismatch in bin count.") require(lower.size == binNums.size) { "Dimension mismatch in bin count." }
if ((0 until dimension).any { upper[it] - lower[it] < 0 }) error("Range for one of axis is not strictly positive") require(!(0 until dimension).any { upper[it] - lower[it] < 0 }) { "Range for one of axis is not strictly positive" }
} }
/** /**
* Get internal [NDStructure] bin index for given axis * Get internal [NDStructure] bin index for given axis
*/ */
private fun getIndex(axis: Int, value: Double): Int { private fun getIndex(axis: Int, value: Double): Int = when {
return when { value >= upper[axis] -> binNums[axis] + 1 // overflow
value >= upper[axis] -> binNums[axis] + 1 // overflow value < lower[axis] -> 0 // underflow
value < lower[axis] -> 0 // underflow else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1
else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1
}
} }
private fun getIndex(point: Buffer<out Double>): IntArray = IntArray(dimension) { getIndex(it, point[it]) } private fun getIndex(point: Buffer<out Double>): IntArray = IntArray(dimension) { getIndex(it, point[it]) }
private fun getValue(index: IntArray): Long { private fun getValue(index: IntArray): Long = values[index].sum()
return values[index].sum()
}
fun getValue(point: Buffer<out Double>): Long { fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point))
return getValue(getIndex(point))
}
private fun getDef(index: IntArray): BinDef<Double> { private fun getDef(index: IntArray): BinDef<Double> {
val center = index.mapIndexed { axis, i -> val center = index.mapIndexed { axis, i ->
@ -89,14 +74,13 @@ class RealHistogram(
else -> lower[axis] + (i.toDouble() - 0.5) * binSize[axis] else -> lower[axis] + (i.toDouble() - 0.5) * binSize[axis]
} }
}.asBuffer() }.asBuffer()
return BinDef(RealBufferFieldOperations, center, binSize) return BinDef(RealBufferFieldOperations, center, binSize)
} }
fun getDef(point: Buffer<out Double>): BinDef<Double> { fun getDef(point: Buffer<out Double>): BinDef<Double> = getDef(getIndex(point))
return getDef(getIndex(point))
}
override fun get(point: Buffer<out Double>): MultivariateBin<Double>? { override operator fun get(point: Buffer<out Double>): MultivariateBin<Double>? {
val index = getIndex(point) val index = getIndex(point)
return MultivariateBin(getDef(index), getValue(index)) return MultivariateBin(getDef(index), getValue(index))
} }
@ -112,26 +96,21 @@ class RealHistogram(
weights[index].add(weight) weights[index].add(weight)
} }
override fun iterator(): Iterator<MultivariateBin<Double>> = weights.elements().map { (index, value) -> override operator fun iterator(): Iterator<MultivariateBin<Double>> = weights.elements().map { (index, value) ->
MultivariateBin(getDef(index), value.sum()) MultivariateBin(getDef(index), value.sum())
}.iterator() }.iterator()
/** /**
* Convert this histogram into NDStructure containing bin values but not bin descriptions * Convert this histogram into NDStructure containing bin values but not bin descriptions
*/ */
fun values(): NDStructure<Number> { fun values(): NDStructure<Number> = NDStructure.auto(values.shape) { values[it].sum() }
return NDStructure.auto(values.shape) { values[it].sum() }
}
/** /**
* Sum of weights * Sum of weights
*/ */
fun weights():NDStructure<Double>{ fun weights(): NDStructure<Double> = NDStructure.auto(weights.shape) { weights[it].sum() }
return NDStructure.auto(weights.shape) { weights[it].sum() }
}
companion object { companion object {
/** /**
* Use it like * Use it like
* ``` * ```
@ -141,12 +120,10 @@ class RealHistogram(
*) *)
*``` *```
*/ */
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): RealHistogram { fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): RealHistogram = RealHistogram(
return RealHistogram( ranges.map { it.start }.asVector(),
ranges.map { it.start }.asVector(), ranges.map { it.endInclusive }.asVector()
ranges.map { it.endInclusive }.asVector() )
)
}
/** /**
* Use it like * Use it like
@ -157,13 +134,10 @@ class RealHistogram(
*) *)
*``` *```
*/ */
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): RealHistogram { fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): RealHistogram = RealHistogram(
return RealHistogram( ListBuffer(ranges.map { it.first.start }),
ListBuffer(ranges.map { it.first.start }), ListBuffer(ranges.map { it.first.endInclusive }),
ListBuffer(ranges.map { it.first.endInclusive }), ranges.map { it.second }.toIntArray()
ranges.map { it.second }.toIntArray() )
)
}
} }
} }

View File

@ -46,11 +46,11 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
synchronized(this) { bins.put(it.position, it) } synchronized(this) { bins.put(it.position, it) }
} }
override fun get(point: Buffer<out Double>): UnivariateBin? = get(point[0]) override operator fun get(point: Buffer<out Double>): UnivariateBin? = get(point[0])
override val dimension: Int get() = 1 override val dimension: Int get() = 1
override fun iterator(): Iterator<UnivariateBin> = bins.values.iterator() override operator fun iterator(): Iterator<UnivariateBin> = bins.values.iterator()
/** /**
* Thread safe put operation * Thread safe put operation
@ -65,15 +65,14 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
} }
companion object { companion object {
fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram { fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram = UnivariateHistogram { value ->
return UnivariateHistogram { value -> val center = start + binSize * floor((value - start) / binSize + 0.5)
val center = start + binSize * floor((value - start) / binSize + 0.5) UnivariateBin(center, binSize)
UnivariateBin(center, binSize)
}
} }
fun custom(borders: DoubleArray): UnivariateHistogram { fun custom(borders: DoubleArray): UnivariateHistogram {
val sorted = borders.sortedArray() val sorted = borders.sortedArray()
return UnivariateHistogram { value -> return UnivariateHistogram { value ->
when { when {
value < sorted.first() -> UnivariateBin( value < sorted.first() -> UnivariateBin(

View File

@ -3,16 +3,16 @@ package scientifik.kmath.linear
import koma.extensions.fill import koma.extensions.fill
import koma.matrix.MatrixFactory import koma.matrix.MatrixFactory
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.NDStructure import scientifik.kmath.structures.NDStructure
class KomaMatrixContext<T : Any>( class KomaMatrixContext<T : Any>(
private val factory: MatrixFactory<koma.matrix.Matrix<T>>, private val factory: MatrixFactory<koma.matrix.Matrix<T>>,
private val space: Space<T> private val space: Space<T>
) : ) : MatrixContext<T> {
MatrixContext<T> {
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T) = override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): KomaMatrix<T> =
KomaMatrix(factory.zeros(rows, columns).fill(initializer)) KomaMatrix(factory.zeros(rows, columns).fill(initializer))
fun Matrix<T>.toKoma(): KomaMatrix<T> = if (this is KomaMatrix) { fun Matrix<T>.toKoma(): KomaMatrix<T> = if (this is KomaMatrix) {
@ -28,31 +28,28 @@ class KomaMatrixContext<T : Any>(
} }
override fun Matrix<T>.dot(other: Matrix<T>) = override fun Matrix<T>.dot(other: Matrix<T>): KomaMatrix<T> =
KomaMatrix(this.toKoma().origin * other.toKoma().origin) KomaMatrix(toKoma().origin * other.toKoma().origin)
override fun Matrix<T>.dot(vector: Point<T>) = override fun Matrix<T>.dot(vector: Point<T>): KomaVector<T> =
KomaVector(this.toKoma().origin * vector.toKoma().origin) KomaVector(toKoma().origin * vector.toKoma().origin)
override fun Matrix<T>.unaryMinus() = override operator fun Matrix<T>.unaryMinus(): KomaMatrix<T> =
KomaMatrix(this.toKoma().origin.unaryMinus()) KomaMatrix(toKoma().origin.unaryMinus())
override fun add(a: Matrix<T>, b: Matrix<T>) = override fun add(a: Matrix<T>, b: Matrix<T>): KomaMatrix<T> =
KomaMatrix(a.toKoma().origin + b.toKoma().origin) KomaMatrix(a.toKoma().origin + b.toKoma().origin)
override fun Matrix<T>.minus(b: Matrix<T>) = override operator fun Matrix<T>.minus(b: Matrix<T>): KomaMatrix<T> =
KomaMatrix(this.toKoma().origin - b.toKoma().origin) KomaMatrix(toKoma().origin - b.toKoma().origin)
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> = override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
produce(a.rowNum, a.colNum) { i, j -> space.run { a[i, j] * k } } produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } }
override fun Matrix<T>.times(value: T) = override operator fun Matrix<T>.times(value: T): KomaMatrix<T> =
KomaMatrix(this.toKoma().origin * value) KomaMatrix(toKoma().origin * value)
companion object {
}
companion object
} }
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Matrix<T>) = fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Matrix<T>) =
@ -70,10 +67,11 @@ class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<Matri
override val shape: IntArray get() = intArrayOf(origin.numRows(), origin.numCols()) override val shape: IntArray get() = intArrayOf(origin.numRows(), origin.numCols())
override val features: Set<MatrixFeature> = features ?: setOf( override val features: Set<MatrixFeature> = features ?: hashSetOf(
object : DeterminantFeature<T> { object : DeterminantFeature<T> {
override val determinant: T get() = origin.det() override val determinant: T get() = origin.det()
}, },
object : LUPDecompositionFeature<T> { object : LUPDecompositionFeature<T> {
private val lup by lazy { origin.LU() } private val lup by lazy { origin.LU() }
override val l: FeaturedMatrix<T> get() = KomaMatrix(lup.second) override val l: FeaturedMatrix<T> get() = KomaMatrix(lup.second)
@ -85,7 +83,7 @@ class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<Matri
override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> = override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> =
KomaMatrix(this.origin, this.features + features) KomaMatrix(this.origin, this.features + features)
override fun get(i: Int, j: Int): T = origin.getGeneric(i, j) override operator fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false) return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
@ -101,14 +99,12 @@ class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<Matri
} }
class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> { class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> {
init {
if (origin.numCols() != 1) error("Only single column matrices are allowed")
}
override val size: Int get() = origin.numRows() override val size: Int get() = origin.numRows()
override fun get(index: Int): T = origin.getGeneric(index) init {
require(origin.numCols() == 1) { error("Only single column matrices are allowed") }
}
override fun iterator(): Iterator<T> = origin.toIterable().iterator() override operator fun get(index: Int): T = origin.getGeneric(index)
override operator fun iterator(): Iterator<T> = origin.toIterable().iterator()
} }

View File

@ -1,5 +1,8 @@
package scientifik.memory package scientifik.memory
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/** /**
* Represents a display of certain memory structure. * Represents a display of certain memory structure.
*/ */
@ -80,8 +83,12 @@ interface MemoryReader {
/** /**
* Uses the memory for read then releases the reader. * Uses the memory for read then releases the reader.
*/ */
inline fun Memory.read(block: MemoryReader.() -> Unit) { inline fun <R> Memory.read(block: MemoryReader.() -> R): R {
reader().apply(block).release() contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
val reader = reader()
val result = reader.block()
reader.release()
return result
} }
/** /**
@ -133,6 +140,7 @@ interface MemoryWriter {
* Uses the memory for write then releases the writer. * Uses the memory for write then releases the writer.
*/ */
inline fun Memory.write(block: MemoryWriter.() -> Unit) { inline fun Memory.write(block: MemoryWriter.() -> Unit) {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
writer().apply(block).release() writer().apply(block).release()
} }

View File

@ -38,11 +38,7 @@ fun <T : Any> MemoryWriter.write(spec: MemorySpec<T>, offset: Int, value: T): Un
* Reads array of [size] objects mapped by [spec] at certain [offset]. * Reads array of [size] objects mapped by [spec] at certain [offset].
*/ */
inline fun <reified T : Any> MemoryReader.readArray(spec: MemorySpec<T>, offset: Int, size: Int): Array<T> = inline fun <reified T : Any> MemoryReader.readArray(spec: MemorySpec<T>, offset: Int, size: Int): Array<T> =
Array(size) { i -> Array(size) { i -> with(spec) { read(offset + i * objectSize) } }
spec.run {
read(offset + i * objectSize)
}
}
/** /**
* Writes [array] of objects mapped by [spec] at certain [offset]. * Writes [array] of objects mapped by [spec] at certain [offset].

View File

@ -1,12 +1,17 @@
package scientifik.memory package scientifik.memory
import java.io.IOException
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.channels.FileChannel import java.nio.channels.FileChannel
import java.nio.file.Files import java.nio.file.Files
import java.nio.file.Path import java.nio.file.Path
import java.nio.file.StandardOpenOption import java.nio.file.StandardOpenOption
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
private class ByteBufferMemory( @PublishedApi
internal class ByteBufferMemory(
val buffer: ByteBuffer, val buffer: ByteBuffer,
val startOffset: Int = 0, val startOffset: Int = 0,
override val size: Int = buffer.limit() override val size: Int = buffer.limit()
@ -112,7 +117,11 @@ fun ByteBuffer.asMemory(startOffset: Int = 0, size: Int = limit()): Memory =
/** /**
* Uses direct memory-mapped buffer from file to read something and close it afterwards. * Uses direct memory-mapped buffer from file to read something and close it afterwards.
*/ */
fun <R> Path.readAsMemory(position: Long = 0, size: Long = Files.size(this), block: Memory.() -> R): R = @Throws(IOException::class)
FileChannel.open(this, StandardOpenOption.READ).use { inline fun <R> Path.readAsMemory(position: Long = 0, size: Long = Files.size(this), block: Memory.() -> R): R {
ByteBufferMemory(it.map(FileChannel.MapMode.READ_ONLY, position, size)).block() contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
}
return FileChannel
.open(this, StandardOpenOption.READ)
.use { ByteBufferMemory(it.map(FileChannel.MapMode.READ_ONLY, position, size)).block() }
}

View File

@ -5,6 +5,7 @@ import scientifik.kmath.chains.ConstantChain
import scientifik.kmath.chains.map import scientifik.kmath.chains.map
import scientifik.kmath.chains.zip import scientifik.kmath.chains.zip
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.operations.invoke
class BasicSampler<T : Any>(val chainBuilder: (RandomGenerator) -> Chain<T>) : Sampler<T> { class BasicSampler<T : Any>(val chainBuilder: (RandomGenerator) -> Chain<T>) : Sampler<T> {
override fun sample(generator: RandomGenerator): Chain<T> = chainBuilder(generator) override fun sample(generator: RandomGenerator): Chain<T> = chainBuilder(generator)
@ -22,10 +23,10 @@ class SamplerSpace<T : Any>(val space: Space<T>) : Space<Sampler<T>> {
override val zero: Sampler<T> = ConstantSampler(space.zero) override val zero: Sampler<T> = ConstantSampler(space.zero)
override fun add(a: Sampler<T>, b: Sampler<T>): Sampler<T> = BasicSampler { generator -> override fun add(a: Sampler<T>, b: Sampler<T>): Sampler<T> = BasicSampler { generator ->
a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space.run { aValue + bValue } } a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space { aValue + bValue } }
} }
override fun multiply(a: Sampler<T>, k: Number): Sampler<T> = BasicSampler { generator -> override fun multiply(a: Sampler<T>, k: Number): Sampler<T> = BasicSampler { generator ->
a.sample(generator).map { space.run { it * k.toDouble() } } a.sample(generator).map { space { it * k.toDouble() } }
} }
} }

Some files were not shown because too many files have changed in this diff Show More