forked from kscience/kmath
Provide contracts for many functions, make inline several functions, replace <algebra>.run { and with(<algebra>> { with <algebra> {, add newlines at EOFs, specify operator modifier explicitly at many places, reformat files, replace if (...) error guards with require and check
This commit is contained in:
parent
3d3791b6cb
commit
1d18832aa6
@ -20,11 +20,14 @@
|
|||||||
- Norm support for `Complex`
|
- Norm support for `Complex`
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
- `readAsMemory` now has `throws IOException` in JVM signature.
|
||||||
|
- Several functions taking functional types were made `inline`.
|
||||||
|
- Several functions taking functional types now have `callsInPlace` contracts.
|
||||||
- BigInteger and BigDecimal algebra: JBigDecimalField has companion object with default math context; minor optimizations
|
- BigInteger and BigDecimal algebra: JBigDecimalField has companion object with default math context; minor optimizations
|
||||||
- `power(T, Int)` extension function has preconditions and supports `Field<T>`
|
- `power(T, Int)` extension function has preconditions and supports `Field<T>`
|
||||||
- Memory objects have more preconditions (overflow checking)
|
- Memory objects have more preconditions (overflow checking)
|
||||||
- `tg` function is renamed to `tan` (https://github.com/mipt-npm/kmath/pull/114)
|
- `tg` function is renamed to `tan` (https://github.com/mipt-npm/kmath/pull/114)
|
||||||
- Gradle version: 6.3 -> 6.5.1
|
- Gradle version: 6.3 -> 6.6
|
||||||
- Moved probability distributions to commons-rng and to `kmath-prob`
|
- Moved probability distributions to commons-rng and to `kmath-prob`
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
@ -60,5 +60,6 @@ benchmark {
|
|||||||
tasks.withType<KotlinCompile> {
|
tasks.withType<KotlinCompile> {
|
||||||
kotlinOptions {
|
kotlinOptions {
|
||||||
jvmTarget = Scientifik.JVM_TARGET.toString()
|
jvmTarget = Scientifik.JVM_TARGET.toString()
|
||||||
|
freeCompilerArgs = freeCompilerArgs + "-Xopt-in=kotlin.RequiresOptIn"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,46 +4,38 @@ import org.openjdk.jmh.annotations.Benchmark
|
|||||||
import org.openjdk.jmh.annotations.Scope
|
import org.openjdk.jmh.annotations.Scope
|
||||||
import org.openjdk.jmh.annotations.State
|
import org.openjdk.jmh.annotations.State
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
class NDFieldBenchmark {
|
class NDFieldBenchmark {
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun autoFieldAdd() {
|
fun autoFieldAdd() {
|
||||||
bufferedField.run {
|
bufferedField {
|
||||||
var res: NDBuffer<Double> = one
|
var res: NDBuffer<Double> = one
|
||||||
repeat(n) {
|
repeat(n) { res += one }
|
||||||
res += one
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun autoElementAdd() {
|
fun autoElementAdd() {
|
||||||
var res = genericField.one
|
var res = genericField.one
|
||||||
repeat(n) {
|
repeat(n) { res += 1.0 }
|
||||||
res += 1.0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun specializedFieldAdd() {
|
fun specializedFieldAdd() {
|
||||||
specializedField.run {
|
specializedField {
|
||||||
var res: NDBuffer<Double> = one
|
var res: NDBuffer<Double> = one
|
||||||
repeat(n) {
|
repeat(n) { res += 1.0 }
|
||||||
res += 1.0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun boxingFieldAdd() {
|
fun boxingFieldAdd() {
|
||||||
genericField.run {
|
genericField {
|
||||||
var res: NDBuffer<Double> = one
|
var res: NDBuffer<Double> = one
|
||||||
repeat(n) {
|
repeat(n) { res += one }
|
||||||
res += one
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,23 +5,22 @@ import org.openjdk.jmh.annotations.Benchmark
|
|||||||
import org.openjdk.jmh.annotations.Scope
|
import org.openjdk.jmh.annotations.Scope
|
||||||
import org.openjdk.jmh.annotations.State
|
import org.openjdk.jmh.annotations.State
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.viktor.ViktorNDField
|
import scientifik.kmath.viktor.ViktorNDField
|
||||||
|
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
class ViktorBenchmark {
|
class ViktorBenchmark {
|
||||||
final val dim = 1000
|
final val dim = 1000
|
||||||
final val n = 100
|
final val n = 100
|
||||||
|
|
||||||
// automatically build context most suited for given type.
|
// automatically build context most suited for given type.
|
||||||
final val autoField = NDField.auto(RealField, dim, dim)
|
final val autoField: BufferedNDField<Double, RealField> = NDField.auto(RealField, dim, dim)
|
||||||
final val realField = NDField.real(dim, dim)
|
final val realField: RealNDField = NDField.real(dim, dim)
|
||||||
|
final val viktorField: ViktorNDField = ViktorNDField(intArrayOf(dim, dim))
|
||||||
final val viktorField = ViktorNDField(intArrayOf(dim, dim))
|
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun automaticFieldAddition() {
|
fun automaticFieldAddition() {
|
||||||
autoField.run {
|
autoField {
|
||||||
var res = one
|
var res = one
|
||||||
repeat(n) { res += one }
|
repeat(n) { res += one }
|
||||||
}
|
}
|
||||||
@ -29,7 +28,7 @@ class ViktorBenchmark {
|
|||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun viktorFieldAddition() {
|
fun viktorFieldAddition() {
|
||||||
viktorField.run {
|
viktorField {
|
||||||
var res = one
|
var res = one
|
||||||
repeat(n) { res += one }
|
repeat(n) { res += one }
|
||||||
}
|
}
|
||||||
@ -44,7 +43,7 @@ class ViktorBenchmark {
|
|||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun realdFieldLog() {
|
fun realdFieldLog() {
|
||||||
realField.run {
|
realField {
|
||||||
val fortyTwo = produce { 42.0 }
|
val fortyTwo = produce { 42.0 }
|
||||||
var res = one
|
var res = one
|
||||||
repeat(n) { res = ln(fortyTwo) }
|
repeat(n) { res = ln(fortyTwo) }
|
||||||
|
@ -1,8 +1,13 @@
|
|||||||
package scientifik.kmath.utils
|
package scientifik.kmath.utils
|
||||||
|
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
internal inline fun measureAndPrint(title: String, block: () -> Unit) {
|
internal inline fun measureAndPrint(title: String, block: () -> Unit) {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
val time = measureTimeMillis(block)
|
val time = measureTimeMillis(block)
|
||||||
println("$title completed in $time millis")
|
println("$title completed in $time millis")
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import scientifik.kmath.commons.linear.CMMatrixContext
|
|||||||
import scientifik.kmath.commons.linear.inverse
|
import scientifik.kmath.commons.linear.inverse
|
||||||
import scientifik.kmath.commons.linear.toCM
|
import scientifik.kmath.commons.linear.toCM
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
import kotlin.contracts.ExperimentalContracts
|
import kotlin.contracts.ExperimentalContracts
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
@ -21,29 +22,18 @@ fun main() {
|
|||||||
|
|
||||||
val n = 5000 // iterations
|
val n = 5000 // iterations
|
||||||
|
|
||||||
MatrixContext.real.run {
|
MatrixContext.real {
|
||||||
|
repeat(50) { val res = inverse(matrix) }
|
||||||
repeat(50) {
|
val inverseTime = measureTimeMillis { repeat(n) { val res = inverse(matrix) } }
|
||||||
val res = inverse(matrix)
|
|
||||||
}
|
|
||||||
|
|
||||||
val inverseTime = measureTimeMillis {
|
|
||||||
repeat(n) {
|
|
||||||
val res = inverse(matrix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
|
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
|
||||||
}
|
}
|
||||||
|
|
||||||
//commons-math
|
//commons-math
|
||||||
|
|
||||||
val commonsTime = measureTimeMillis {
|
val commonsTime = measureTimeMillis {
|
||||||
CMMatrixContext.run {
|
CMMatrixContext {
|
||||||
val cm = matrix.toCM() //avoid overhead on conversion
|
val cm = matrix.toCM() //avoid overhead on conversion
|
||||||
repeat(n) {
|
repeat(n) { val res = inverse(cm) }
|
||||||
val res = inverse(cm)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,7 +43,7 @@ fun main() {
|
|||||||
//koma-ejml
|
//koma-ejml
|
||||||
|
|
||||||
val komaTime = measureTimeMillis {
|
val komaTime = measureTimeMillis {
|
||||||
KomaMatrixContext(EJMLMatrixFactory(), RealField).run {
|
(KomaMatrixContext(EJMLMatrixFactory(), RealField)) {
|
||||||
val km = matrix.toKoma() //avoid overhead on conversion
|
val km = matrix.toKoma() //avoid overhead on conversion
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
val res = inverse(km)
|
val res = inverse(km)
|
||||||
|
@ -4,6 +4,7 @@ import koma.matrix.ejml.EJMLMatrixFactory
|
|||||||
import scientifik.kmath.commons.linear.CMMatrixContext
|
import scientifik.kmath.commons.linear.CMMatrixContext
|
||||||
import scientifik.kmath.commons.linear.toCM
|
import scientifik.kmath.commons.linear.toCM
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
@ -18,7 +19,7 @@ fun main() {
|
|||||||
// //warmup
|
// //warmup
|
||||||
// matrix1 dot matrix2
|
// matrix1 dot matrix2
|
||||||
|
|
||||||
CMMatrixContext.run {
|
CMMatrixContext {
|
||||||
val cmMatrix1 = matrix1.toCM()
|
val cmMatrix1 = matrix1.toCM()
|
||||||
val cmMatrix2 = matrix2.toCM()
|
val cmMatrix2 = matrix2.toCM()
|
||||||
|
|
||||||
@ -29,8 +30,7 @@ fun main() {
|
|||||||
println("CM implementation time: $cmTime")
|
println("CM implementation time: $cmTime")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
(KomaMatrixContext(EJMLMatrixFactory(), RealField)) {
|
||||||
KomaMatrixContext(EJMLMatrixFactory(), RealField).run {
|
|
||||||
val komaMatrix1 = matrix1.toKoma()
|
val komaMatrix1 = matrix1.toKoma()
|
||||||
val komaMatrix2 = matrix2.toKoma()
|
val komaMatrix2 = matrix2.toKoma()
|
||||||
|
|
||||||
|
@ -9,13 +9,11 @@ fun main() {
|
|||||||
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
|
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val compute = (NDField.complex(8)) {
|
||||||
val compute = NDField.complex(8).run {
|
|
||||||
val a = produce { (it) -> i * it - it.toDouble() }
|
val a = produce { (it) -> i * it - it.toDouble() }
|
||||||
val b = 3
|
val b = 3
|
||||||
val c = Complex(1.0, 1.0)
|
val c = Complex(1.0, 1.0)
|
||||||
|
|
||||||
(a pow b) + c
|
(a pow b) + c
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
|
||||||
|
@ -13,9 +13,8 @@ fun main() {
|
|||||||
val realField = NDField.real(dim, dim)
|
val realField = NDField.real(dim, dim)
|
||||||
val complexField = NDField.complex(dim, dim)
|
val complexField = NDField.complex(dim, dim)
|
||||||
|
|
||||||
|
|
||||||
val realTime = measureTimeMillis {
|
val realTime = measureTimeMillis {
|
||||||
realField.run {
|
realField {
|
||||||
var res: NDBuffer<Double> = one
|
var res: NDBuffer<Double> = one
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
res += 1.0
|
res += 1.0
|
||||||
@ -26,18 +25,15 @@ fun main() {
|
|||||||
println("Real addition completed in $realTime millis")
|
println("Real addition completed in $realTime millis")
|
||||||
|
|
||||||
val complexTime = measureTimeMillis {
|
val complexTime = measureTimeMillis {
|
||||||
complexField.run {
|
complexField {
|
||||||
var res: NDBuffer<Complex> = one
|
var res: NDBuffer<Complex> = one
|
||||||
repeat(n) {
|
repeat(n) { res += 1.0 }
|
||||||
res += 1.0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println("Complex addition completed in $complexTime millis")
|
println("Complex addition completed in $complexTime millis")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fun complexExample() {
|
fun complexExample() {
|
||||||
//Create a context for 2-d structure with complex values
|
//Create a context for 2-d structure with complex values
|
||||||
ComplexField {
|
ComplexField {
|
||||||
@ -46,10 +42,7 @@ fun complexExample() {
|
|||||||
val x = one * 2.5
|
val x = one * 2.5
|
||||||
operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im)
|
operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im)
|
||||||
//a structure generator specific to this context
|
//a structure generator specific to this context
|
||||||
val matrix = produce { (k, l) ->
|
val matrix = produce { (k, l) -> k + l * i }
|
||||||
k + l * i
|
|
||||||
}
|
|
||||||
|
|
||||||
//Perform sum
|
//Perform sum
|
||||||
val sum = matrix + x + 1.0
|
val sum = matrix + x + 1.0
|
||||||
|
|
||||||
|
@ -2,14 +2,19 @@ package scientifik.kmath.structures
|
|||||||
|
|
||||||
import kotlinx.coroutines.GlobalScope
|
import kotlinx.coroutines.GlobalScope
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
internal inline fun measureAndPrint(title: String, block: () -> Unit) {
|
internal inline fun measureAndPrint(title: String, block: () -> Unit) {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
val time = measureTimeMillis(block)
|
val time = measureTimeMillis(block)
|
||||||
println("$title completed in $time millis")
|
println("$title completed in $time millis")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
val dim = 1000
|
val dim = 1000
|
||||||
val n = 1000
|
val n = 1000
|
||||||
@ -22,27 +27,21 @@ fun main() {
|
|||||||
val genericField = NDField.boxing(RealField, dim, dim)
|
val genericField = NDField.boxing(RealField, dim, dim)
|
||||||
|
|
||||||
measureAndPrint("Automatic field addition") {
|
measureAndPrint("Automatic field addition") {
|
||||||
autoField.run {
|
autoField {
|
||||||
var res: NDBuffer<Double> = one
|
var res: NDBuffer<Double> = one
|
||||||
repeat(n) {
|
repeat(n) { res += number(1.0) }
|
||||||
res += number(1.0)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
measureAndPrint("Element addition") {
|
measureAndPrint("Element addition") {
|
||||||
var res = genericField.one
|
var res = genericField.one
|
||||||
repeat(n) {
|
repeat(n) { res += 1.0 }
|
||||||
res += 1.0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
measureAndPrint("Specialized addition") {
|
measureAndPrint("Specialized addition") {
|
||||||
specializedField.run {
|
specializedField {
|
||||||
var res: NDBuffer<Double> = one
|
var res: NDBuffer<Double> = one
|
||||||
repeat(n) {
|
repeat(n) { res += 1.0 }
|
||||||
res += 1.0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,12 +59,11 @@ fun main() {
|
|||||||
|
|
||||||
measureAndPrint("Generic addition") {
|
measureAndPrint("Generic addition") {
|
||||||
//genericField.run(action)
|
//genericField.run(action)
|
||||||
genericField.run {
|
genericField {
|
||||||
var res: NDBuffer<Double> = one
|
var res: NDBuffer<Double> = one
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
res += one // con't avoid using `one` due to resolution ambiguity
|
res += one // couldn't avoid using `one` due to resolution ambiguity }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
|
||||||
|
@ -16,6 +16,7 @@ fun DMatrixContext<Double, RealField>.simple() {
|
|||||||
|
|
||||||
|
|
||||||
object D5 : Dimension {
|
object D5 : Dimension {
|
||||||
|
@OptIn(ExperimentalUnsignedTypes::class)
|
||||||
override val dim: UInt = 5u
|
override val dim: UInt = 5u
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -23,13 +24,10 @@ fun DMatrixContext<Double, RealField>.custom() {
|
|||||||
val m1 = produce<D2, D5> { i, j -> (i + j).toDouble() }
|
val m1 = produce<D2, D5> { i, j -> (i + j).toDouble() }
|
||||||
val m2 = produce<D5, D2> { i, j -> (i - j).toDouble() }
|
val m2 = produce<D5, D2> { i, j -> (i - j).toDouble() }
|
||||||
val m3 = produce<D2, D2> { i, j -> (i - j).toDouble() }
|
val m3 = produce<D2, D2> { i, j -> (i - j).toDouble() }
|
||||||
|
|
||||||
(m1 dot m2) + m3
|
(m1 dot m2) + m3
|
||||||
}
|
}
|
||||||
|
|
||||||
fun main() {
|
fun main(): Unit = with(DMatrixContext.real) {
|
||||||
DMatrixContext.real.run {
|
simple()
|
||||||
simple()
|
custom()
|
||||||
custom()
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -2,6 +2,9 @@ package scientifik.kmath.ast
|
|||||||
|
|
||||||
import scientifik.kmath.expressions.*
|
import scientifik.kmath.expressions.*
|
||||||
import scientifik.kmath.operations.*
|
import scientifik.kmath.operations.*
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than
|
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than
|
||||||
@ -24,7 +27,7 @@ class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
|
|||||||
error("Numeric nodes are not supported by $this")
|
error("Numeric nodes are not supported by $this")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
override operator fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -38,51 +41,72 @@ inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
|||||||
/**
|
/**
|
||||||
* Builds [MstExpression] over [Space].
|
* Builds [MstExpression] over [Space].
|
||||||
*/
|
*/
|
||||||
inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> =
|
@OptIn(ExperimentalContracts::class)
|
||||||
MstExpression(this, MstSpace.block())
|
inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return MstExpression(this, MstSpace.block())
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds [MstExpression] over [Ring].
|
* Builds [MstExpression] over [Ring].
|
||||||
*/
|
*/
|
||||||
inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> =
|
@OptIn(ExperimentalContracts::class)
|
||||||
MstExpression(this, MstRing.block())
|
inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return MstExpression(this, MstRing.block())
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds [MstExpression] over [Field].
|
* Builds [MstExpression] over [Field].
|
||||||
*/
|
*/
|
||||||
inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> =
|
@OptIn(ExperimentalContracts::class)
|
||||||
MstExpression(this, MstField.block())
|
inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return MstExpression(this, MstField.block())
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds [MstExpression] over [ExtendedField].
|
* Builds [MstExpression] over [ExtendedField].
|
||||||
*/
|
*/
|
||||||
inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> =
|
@OptIn(ExperimentalContracts::class)
|
||||||
MstExpression(this, MstExtendedField.block())
|
inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return MstExpression(this, MstExtendedField.block())
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds [MstExpression] over [FunctionalExpressionSpace].
|
* Builds [MstExpression] over [FunctionalExpressionSpace].
|
||||||
*/
|
*/
|
||||||
inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(
|
@OptIn(ExperimentalContracts::class)
|
||||||
block: MstSpace.() -> MST
|
inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
|
||||||
): MstExpression<T> = algebra.mstInSpace(block)
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return algebra.mstInSpace(block)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds [MstExpression] over [FunctionalExpressionRing].
|
* Builds [MstExpression] over [FunctionalExpressionRing].
|
||||||
*/
|
*/
|
||||||
inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(
|
|
||||||
block: MstRing.() -> MST
|
@OptIn(ExperimentalContracts::class)
|
||||||
): MstExpression<T> = algebra.mstInRing(block)
|
inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return algebra.mstInRing(block)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds [MstExpression] over [FunctionalExpressionField].
|
* Builds [MstExpression] over [FunctionalExpressionField].
|
||||||
*/
|
*/
|
||||||
inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(
|
@OptIn(ExperimentalContracts::class)
|
||||||
block: MstField.() -> MST
|
inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> {
|
||||||
): MstExpression<T> = algebra.mstInField(block)
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return algebra.mstInField(block)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds [MstExpression] over [FunctionalExpressionExtendedField].
|
* Builds [MstExpression] over [FunctionalExpressionExtendedField].
|
||||||
*/
|
*/
|
||||||
inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
@OptIn(ExperimentalContracts::class)
|
||||||
block: MstExtendedField.() -> MST
|
inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> {
|
||||||
): MstExpression<T> = algebra.mstInExtendedField(block)
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return algebra.mstInExtendedField(block)
|
||||||
|
}
|
||||||
|
@ -7,6 +7,9 @@ import scientifik.kmath.ast.MST
|
|||||||
import scientifik.kmath.expressions.Expression
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
import java.lang.reflect.Method
|
import java.lang.reflect.Method
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
||||||
@ -26,8 +29,11 @@ internal val KClass<*>.asm: Type
|
|||||||
/**
|
/**
|
||||||
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
|
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
|
||||||
*/
|
*/
|
||||||
internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> =
|
@OptIn(ExperimentalContracts::class)
|
||||||
if (predicate(this)) arrayOf(this) else emptyArray()
|
internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> {
|
||||||
|
contract { callsInPlace(predicate, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return if (predicate(this)) arrayOf(this) else emptyArray()
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates an [InstructionAdapter] from this [MethodVisitor].
|
* Creates an [InstructionAdapter] from this [MethodVisitor].
|
||||||
@ -37,8 +43,11 @@ private fun MethodVisitor.instructionAdapter(): InstructionAdapter = Instruction
|
|||||||
/**
|
/**
|
||||||
* Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it.
|
* Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it.
|
||||||
*/
|
*/
|
||||||
internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter =
|
@OptIn(ExperimentalContracts::class)
|
||||||
instructionAdapter().apply(block)
|
internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return instructionAdapter().apply(block)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructs a [Label], then applies it to this visitor.
|
* Constructs a [Label], then applies it to this visitor.
|
||||||
@ -63,10 +72,14 @@ internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
|
|||||||
return buildName(mst, collision + 1)
|
return buildName(mst, collision + 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
@Suppress("FunctionName")
|
@Suppress("FunctionName")
|
||||||
internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter =
|
internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter {
|
||||||
ClassWriter(flags).apply(block)
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return ClassWriter(flags).apply(block)
|
||||||
|
}
|
||||||
|
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
internal inline fun ClassWriter.visitField(
|
internal inline fun ClassWriter.visitField(
|
||||||
access: Int,
|
access: Int,
|
||||||
name: String,
|
name: String,
|
||||||
@ -74,7 +87,10 @@ internal inline fun ClassWriter.visitField(
|
|||||||
signature: String?,
|
signature: String?,
|
||||||
value: Any?,
|
value: Any?,
|
||||||
block: FieldVisitor.() -> Unit
|
block: FieldVisitor.() -> Unit
|
||||||
): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block)
|
): FieldVisitor {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return visitField(access, name, descriptor, signature, value).apply(block)
|
||||||
|
}
|
||||||
|
|
||||||
private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, parameterTypes: Array<MstType>): Method? =
|
private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, parameterTypes: Array<MstType>): Method? =
|
||||||
context.javaClass.methods.find { method ->
|
context.javaClass.methods.find { method ->
|
||||||
@ -151,6 +167,7 @@ private fun <T> AsmBuilder<T>.tryInvokeSpecific(
|
|||||||
/**
|
/**
|
||||||
* Builds specialized algebra call with option to fallback to generic algebra operation accepting String.
|
* Builds specialized algebra call with option to fallback to generic algebra operation accepting String.
|
||||||
*/
|
*/
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
||||||
context: Algebra<T>,
|
context: Algebra<T>,
|
||||||
name: String,
|
name: String,
|
||||||
@ -158,6 +175,7 @@ internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
|||||||
parameterTypes: Array<MstType>,
|
parameterTypes: Array<MstType>,
|
||||||
parameters: AsmBuilder<T>.() -> Unit
|
parameters: AsmBuilder<T>.() -> Unit
|
||||||
) {
|
) {
|
||||||
|
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
|
||||||
val arity = parameterTypes.size
|
val arity = parameterTypes.size
|
||||||
loadAlgebra()
|
loadAlgebra()
|
||||||
if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name)
|
if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name)
|
||||||
|
@ -5,6 +5,7 @@ import scientifik.kmath.expressions.Expression
|
|||||||
import scientifik.kmath.expressions.ExpressionAlgebra
|
import scientifik.kmath.expressions.ExpressionAlgebra
|
||||||
import scientifik.kmath.operations.ExtendedField
|
import scientifik.kmath.operations.ExtendedField
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import kotlin.properties.ReadOnlyProperty
|
import kotlin.properties.ReadOnlyProperty
|
||||||
import kotlin.reflect.KProperty
|
import kotlin.reflect.KProperty
|
||||||
|
|
||||||
@ -15,7 +16,6 @@ class DerivativeStructureField(
|
|||||||
val order: Int,
|
val order: Int,
|
||||||
val parameters: Map<String, Double>
|
val parameters: Map<String, Double>
|
||||||
) : ExtendedField<DerivativeStructure> {
|
) : ExtendedField<DerivativeStructure> {
|
||||||
|
|
||||||
override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) }
|
override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) }
|
||||||
override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) }
|
override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) }
|
||||||
|
|
||||||
@ -23,17 +23,15 @@ class DerivativeStructureField(
|
|||||||
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
|
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
|
||||||
}
|
}
|
||||||
|
|
||||||
val variable = object : ReadOnlyProperty<Any?, DerivativeStructure> {
|
val variable: ReadOnlyProperty<Any?, DerivativeStructure> = object : ReadOnlyProperty<Any?, DerivativeStructure> {
|
||||||
override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure {
|
override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure =
|
||||||
return variables[property.name] ?: error("A variable with name ${property.name} does not exist")
|
variables[property.name] ?: error("A variable with name ${property.name} does not exist")
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure =
|
fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure =
|
||||||
variables[name] ?: default ?: error("A variable with name $name does not exist")
|
variables[name] ?: default ?: error("A variable with name $name does not exist")
|
||||||
|
|
||||||
|
fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
|
||||||
fun Number.const() = DerivativeStructure(order, parameters.size, toDouble())
|
|
||||||
|
|
||||||
fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
|
fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
|
||||||
return deriv(mapOf(parName to order))
|
return deriv(mapOf(parName to order))
|
||||||
@ -83,16 +81,15 @@ class DerivativeStructureField(
|
|||||||
|
|
||||||
override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
|
override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
|
||||||
override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
||||||
override operator fun Number.plus(b: DerivativeStructure) = b + this
|
override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
||||||
override operator fun Number.minus(b: DerivativeStructure) = b - this
|
override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A constructs that creates a derivative structure with required order on-demand
|
* A constructs that creates a derivative structure with required order on-demand
|
||||||
*/
|
*/
|
||||||
class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression<Double> {
|
class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression<Double> {
|
||||||
|
override operator fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
|
||||||
override fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
|
|
||||||
0,
|
0,
|
||||||
arguments
|
arguments
|
||||||
).run(function).value
|
).run(function).value
|
||||||
@ -101,45 +98,40 @@ class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStru
|
|||||||
* Get the derivative expression with given orders
|
* Get the derivative expression with given orders
|
||||||
* TODO make result [DiffExpression]
|
* TODO make result [DiffExpression]
|
||||||
*/
|
*/
|
||||||
fun derivative(orders: Map<String, Int>): Expression<Double> {
|
fun derivative(orders: Map<String, Int>): Expression<Double> = object : Expression<Double> {
|
||||||
return object : Expression<Double> {
|
override operator fun invoke(arguments: Map<String, Double>): Double =
|
||||||
override fun invoke(arguments: Map<String, Double>): Double =
|
(DerivativeStructureField(orders.values.max() ?: 0, arguments)) { function().deriv(orders) }
|
||||||
DerivativeStructureField(orders.values.max() ?: 0, arguments)
|
|
||||||
.run {
|
|
||||||
function().deriv(orders)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO add gradient and maybe other vector operators
|
//TODO add gradient and maybe other vector operators
|
||||||
}
|
}
|
||||||
|
|
||||||
fun DiffExpression.derivative(vararg orders: Pair<String, Int>) = derivative(mapOf(*orders))
|
fun DiffExpression.derivative(vararg orders: Pair<String, Int>): Expression<Double> = derivative(mapOf(*orders))
|
||||||
fun DiffExpression.derivative(name: String) = derivative(name to 1)
|
fun DiffExpression.derivative(name: String): Expression<Double> = derivative(name to 1)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
||||||
*/
|
*/
|
||||||
object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
|
object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
|
||||||
override fun variable(name: String, default: Double?) =
|
override fun variable(name: String, default: Double?): DiffExpression =
|
||||||
DiffExpression { variable(name, default?.const()) }
|
DiffExpression { variable(name, default?.const()) }
|
||||||
|
|
||||||
override fun const(value: Double): DiffExpression =
|
override fun const(value: Double): DiffExpression =
|
||||||
DiffExpression { value.const() }
|
DiffExpression { value.const() }
|
||||||
|
|
||||||
override fun add(a: DiffExpression, b: DiffExpression) =
|
override fun add(a: DiffExpression, b: DiffExpression): DiffExpression =
|
||||||
DiffExpression { a.function(this) + b.function(this) }
|
DiffExpression { a.function(this) + b.function(this) }
|
||||||
|
|
||||||
override val zero = DiffExpression { 0.0.const() }
|
override val zero: DiffExpression = DiffExpression { 0.0.const() }
|
||||||
|
|
||||||
override fun multiply(a: DiffExpression, k: Number) =
|
override fun multiply(a: DiffExpression, k: Number): DiffExpression =
|
||||||
DiffExpression { a.function(this) * k }
|
DiffExpression { a.function(this) * k }
|
||||||
|
|
||||||
override val one = DiffExpression { 1.0.const() }
|
override val one: DiffExpression = DiffExpression { 1.0.const() }
|
||||||
|
|
||||||
override fun multiply(a: DiffExpression, b: DiffExpression) =
|
override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression =
|
||||||
DiffExpression { a.function(this) * b.function(this) }
|
DiffExpression { a.function(this) * b.function(this) }
|
||||||
|
|
||||||
override fun divide(a: DiffExpression, b: DiffExpression) =
|
override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression =
|
||||||
DiffExpression { a.function(this) / b.function(this) }
|
DiffExpression { a.function(this) / b.function(this) }
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
package scientifik.kmath.commons.linear
|
package scientifik.kmath.commons.linear
|
||||||
|
|
||||||
import org.apache.commons.math3.linear.*
|
import org.apache.commons.math3.linear.*
|
||||||
import org.apache.commons.math3.linear.RealMatrix
|
|
||||||
import org.apache.commons.math3.linear.RealVector
|
|
||||||
import scientifik.kmath.linear.*
|
import scientifik.kmath.linear.*
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
import scientifik.kmath.structures.NDStructure
|
import scientifik.kmath.structures.NDStructure
|
||||||
@ -14,12 +12,12 @@ class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) :
|
|||||||
|
|
||||||
override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> {
|
override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> {
|
||||||
if (origin is DiagonalMatrix) yield(DiagonalFeature)
|
if (origin is DiagonalMatrix) yield(DiagonalFeature)
|
||||||
}.toSet()
|
}.toHashSet()
|
||||||
|
|
||||||
override fun suggestFeature(vararg features: MatrixFeature) =
|
override fun suggestFeature(vararg features: MatrixFeature): CMMatrix =
|
||||||
CMMatrix(origin, this.features + features)
|
CMMatrix(origin, this.features + features)
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
|
override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||||
@ -40,24 +38,22 @@ fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
|
|||||||
CMMatrix(Array2DRowRealMatrix(array))
|
CMMatrix(Array2DRowRealMatrix(array))
|
||||||
}
|
}
|
||||||
|
|
||||||
fun RealMatrix.asMatrix() = CMMatrix(this)
|
fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this)
|
||||||
|
|
||||||
class CMVector(val origin: RealVector) : Point<Double> {
|
class CMVector(val origin: RealVector) : Point<Double> {
|
||||||
override val size: Int get() = origin.dimension
|
override val size: Int get() = origin.dimension
|
||||||
|
|
||||||
override fun get(index: Int): Double = origin.getEntry(index)
|
override operator fun get(index: Int): Double = origin.getEntry(index)
|
||||||
|
|
||||||
override fun iterator(): Iterator<Double> = origin.toArray().iterator()
|
override operator fun iterator(): Iterator<Double> = origin.toArray().iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
fun Point<Double>.toCM(): CMVector = if (this is CMVector) {
|
fun Point<Double>.toCM(): CMVector = if (this is CMVector) this else {
|
||||||
this
|
|
||||||
} else {
|
|
||||||
val array = DoubleArray(size) { this[it] }
|
val array = DoubleArray(size) { this[it] }
|
||||||
CMVector(ArrayRealVector(array))
|
CMVector(ArrayRealVector(array))
|
||||||
}
|
}
|
||||||
|
|
||||||
fun RealVector.toPoint() = CMVector(this)
|
fun RealVector.toPoint(): CMVector = CMVector(this)
|
||||||
|
|
||||||
object CMMatrixContext : MatrixContext<Double> {
|
object CMMatrixContext : MatrixContext<Double> {
|
||||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix {
|
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix {
|
||||||
@ -65,32 +61,33 @@ object CMMatrixContext : MatrixContext<Double> {
|
|||||||
return CMMatrix(Array2DRowRealMatrix(array))
|
return CMMatrix(Array2DRowRealMatrix(array))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Matrix<Double>.dot(other: Matrix<Double>) =
|
override fun Matrix<Double>.dot(other: Matrix<Double>): CMMatrix =
|
||||||
CMMatrix(this.toCM().origin.multiply(other.toCM().origin))
|
CMMatrix(this.toCM().origin.multiply(other.toCM().origin))
|
||||||
|
|
||||||
override fun Matrix<Double>.dot(vector: Point<Double>): CMVector =
|
override fun Matrix<Double>.dot(vector: Point<Double>): CMVector =
|
||||||
CMVector(this.toCM().origin.preMultiply(vector.toCM().origin))
|
CMVector(this.toCM().origin.preMultiply(vector.toCM().origin))
|
||||||
|
|
||||||
override fun Matrix<Double>.unaryMinus(): CMMatrix =
|
override operator fun Matrix<Double>.unaryMinus(): CMMatrix =
|
||||||
produce(rowNum, colNum) { i, j -> -get(i, j) }
|
produce(rowNum, colNum) { i, j -> -get(i, j) }
|
||||||
|
|
||||||
override fun add(a: Matrix<Double>, b: Matrix<Double>) =
|
override fun add(a: Matrix<Double>, b: Matrix<Double>): CMMatrix =
|
||||||
CMMatrix(a.toCM().origin.multiply(b.toCM().origin))
|
CMMatrix(a.toCM().origin.multiply(b.toCM().origin))
|
||||||
|
|
||||||
override fun Matrix<Double>.minus(b: Matrix<Double>) =
|
override operator fun Matrix<Double>.minus(b: Matrix<Double>): CMMatrix =
|
||||||
CMMatrix(this.toCM().origin.subtract(b.toCM().origin))
|
CMMatrix(this.toCM().origin.subtract(b.toCM().origin))
|
||||||
|
|
||||||
override fun multiply(a: Matrix<Double>, k: Number) =
|
override fun multiply(a: Matrix<Double>, k: Number): CMMatrix =
|
||||||
CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble()))
|
CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble()))
|
||||||
|
|
||||||
override fun Matrix<Double>.times(value: Double): Matrix<Double> =
|
override operator fun Matrix<Double>.times(value: Double): Matrix<Double> =
|
||||||
produce(rowNum, colNum) { i, j -> get(i, j) * value }
|
produce(rowNum, colNum) { i, j -> get(i, j) * value }
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix =
|
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix =
|
||||||
CMMatrix(this.origin.add(other.origin))
|
CMMatrix(this.origin.add(other.origin))
|
||||||
|
|
||||||
operator fun CMMatrix.minus(other: CMMatrix): CMMatrix =
|
operator fun CMMatrix.minus(other: CMMatrix): CMMatrix =
|
||||||
CMMatrix(this.origin.subtract(other.origin))
|
CMMatrix(this.origin.subtract(other.origin))
|
||||||
|
|
||||||
infix fun CMMatrix.dot(other: CMMatrix): CMMatrix =
|
infix fun CMMatrix.dot(other: CMMatrix): CMMatrix =
|
||||||
CMMatrix(this.origin.multiply(other.origin))
|
CMMatrix(this.origin.multiply(other.origin))
|
||||||
|
@ -4,10 +4,9 @@ import scientifik.kmath.prob.RandomGenerator
|
|||||||
|
|
||||||
class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) :
|
class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) :
|
||||||
org.apache.commons.math3.random.RandomGenerator {
|
org.apache.commons.math3.random.RandomGenerator {
|
||||||
private var generator = factory(intArrayOf())
|
private var generator: RandomGenerator = factory(intArrayOf())
|
||||||
|
|
||||||
override fun nextBoolean(): Boolean = generator.nextBoolean()
|
override fun nextBoolean(): Boolean = generator.nextBoolean()
|
||||||
|
|
||||||
override fun nextFloat(): Float = generator.nextDouble().toFloat()
|
override fun nextFloat(): Float = generator.nextDouble().toFloat()
|
||||||
|
|
||||||
override fun setSeed(seed: Int) {
|
override fun setSeed(seed: Int) {
|
||||||
@ -27,12 +26,8 @@ class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) :
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun nextInt(): Int = generator.nextInt()
|
override fun nextInt(): Int = generator.nextInt()
|
||||||
|
|
||||||
override fun nextInt(n: Int): Int = generator.nextInt(n)
|
override fun nextInt(n: Int): Int = generator.nextInt(n)
|
||||||
|
|
||||||
override fun nextGaussian(): Double = TODO()
|
override fun nextGaussian(): Double = TODO()
|
||||||
|
|
||||||
override fun nextDouble(): Double = generator.nextDouble()
|
override fun nextDouble(): Double = generator.nextDouble()
|
||||||
|
|
||||||
override fun nextLong(): Long = generator.nextLong()
|
override fun nextLong(): Long = generator.nextLong()
|
||||||
}
|
}
|
||||||
|
@ -1,11 +1,17 @@
|
|||||||
package scientifik.kmath.commons.expressions
|
package scientifik.kmath.commons.expressions
|
||||||
|
|
||||||
import scientifik.kmath.expressions.invoke
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
inline fun <R> diff(order: Int, vararg parameters: Pair<String, Double>, block: DerivativeStructureField.() -> R) =
|
@OptIn(ExperimentalContracts::class)
|
||||||
DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
inline fun <R> diff(order: Int, vararg parameters: Pair<String, Double>, block: DerivativeStructureField.() -> R): R {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
||||||
|
}
|
||||||
|
|
||||||
class AutoDiffTest {
|
class AutoDiffTest {
|
||||||
@Test
|
@Test
|
||||||
|
@ -4,28 +4,42 @@ import scientifik.kmath.operations.ExtendedField
|
|||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a functional expression with this [Space].
|
* Creates a functional expression with this [Space].
|
||||||
*/
|
*/
|
||||||
fun <T> Space<T>.spaceExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> =
|
@OptIn(ExperimentalContracts::class)
|
||||||
FunctionalExpressionSpace(this).run(block)
|
inline fun <T> Space<T>.spaceExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return FunctionalExpressionSpace(this).block()
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a functional expression with this [Ring].
|
* Creates a functional expression with this [Ring].
|
||||||
*/
|
*/
|
||||||
fun <T> Ring<T>.ringExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> =
|
@OptIn(ExperimentalContracts::class)
|
||||||
FunctionalExpressionRing(this).run(block)
|
inline fun <T> Ring<T>.ringExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return FunctionalExpressionRing(this).block()
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a functional expression with this [Field].
|
* Creates a functional expression with this [Field].
|
||||||
*/
|
*/
|
||||||
fun <T> Field<T>.fieldExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> =
|
@OptIn(ExperimentalContracts::class)
|
||||||
FunctionalExpressionField(this).run(block)
|
inline fun <T> Field<T>.fieldExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return FunctionalExpressionField(this).block()
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a functional expression with this [ExtendedField].
|
* Creates a functional expression with this [ExtendedField].
|
||||||
*/
|
*/
|
||||||
fun <T> ExtendedField<T>.fieldExpression(
|
@OptIn(ExperimentalContracts::class)
|
||||||
block: FunctionalExpressionExtendedField<T, ExtendedField<T>>.() -> Expression<T>
|
inline fun <T> ExtendedField<T>.extendedFieldExpression(block: FunctionalExpressionExtendedField<T, ExtendedField<T>>.() -> Expression<T>): Expression<T> {
|
||||||
): Expression<T> = FunctionalExpressionExtendedField(this).run(block)
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return FunctionalExpressionExtendedField(this).block()
|
||||||
|
}
|
||||||
|
@ -22,7 +22,7 @@ interface Expression<T> {
|
|||||||
*/
|
*/
|
||||||
fun <T> Algebra<T>.expression(block: Algebra<T>.(arguments: Map<String, T>) -> T): Expression<T> =
|
fun <T> Algebra<T>.expression(block: Algebra<T>.(arguments: Map<String, T>) -> T): Expression<T> =
|
||||||
object : Expression<T> {
|
object : Expression<T> {
|
||||||
override fun invoke(arguments: Map<String, T>): T = block(arguments)
|
override operator fun invoke(arguments: Map<String, T>): T = block(arguments)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -4,7 +4,7 @@ import scientifik.kmath.operations.*
|
|||||||
|
|
||||||
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
|
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
|
||||||
Expression<T> {
|
Expression<T> {
|
||||||
override fun invoke(arguments: Map<String, T>): T = context.unaryOperation(name, expr.invoke(arguments))
|
override operator fun invoke(arguments: Map<String, T>): T = context.unaryOperation(name, expr.invoke(arguments))
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class FunctionalBinaryOperation<T>(
|
internal class FunctionalBinaryOperation<T>(
|
||||||
@ -13,17 +13,17 @@ internal class FunctionalBinaryOperation<T>(
|
|||||||
val first: Expression<T>,
|
val first: Expression<T>,
|
||||||
val second: Expression<T>
|
val second: Expression<T>
|
||||||
) : Expression<T> {
|
) : Expression<T> {
|
||||||
override fun invoke(arguments: Map<String, T>): T =
|
override operator fun invoke(arguments: Map<String, T>): T =
|
||||||
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
|
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
||||||
override fun invoke(arguments: Map<String, T>): T =
|
override operator fun invoke(arguments: Map<String, T>): T =
|
||||||
arguments[name] ?: default ?: error("Parameter not found: $name")
|
arguments[name] ?: default ?: error("Parameter not found: $name")
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
|
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
|
||||||
override fun invoke(arguments: Map<String, T>): T = value
|
override operator fun invoke(arguments: Map<String, T>): T = value
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class FunctionalConstProductExpression<T>(
|
internal class FunctionalConstProductExpression<T>(
|
||||||
@ -31,7 +31,7 @@ internal class FunctionalConstProductExpression<T>(
|
|||||||
private val expr: Expression<T>,
|
private val expr: Expression<T>,
|
||||||
val const: Number
|
val const: Number
|
||||||
) : Expression<T> {
|
) : Expression<T> {
|
||||||
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -53,16 +53,12 @@ class BufferMatrix<T : Any>(
|
|||||||
override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
|
override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
|
||||||
BufferMatrix(rowNum, colNum, buffer, this.features + features)
|
BufferMatrix(rowNum, colNum, buffer, this.features + features)
|
||||||
|
|
||||||
override fun get(index: IntArray): T = get(index[0], index[1])
|
override operator fun get(index: IntArray): T = get(index[0], index[1])
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): T = buffer[i * colNum + j]
|
override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j]
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
||||||
for (i in 0 until rowNum) {
|
for (i in 0 until rowNum) for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j))
|
||||||
for (j in 0 until colNum) {
|
|
||||||
yield(intArrayOf(i, j) to get(i, j))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
@ -95,7 +91,7 @@ class BufferMatrix<T : Any>(
|
|||||||
* Optimized dot product for real matrices
|
* Optimized dot product for real matrices
|
||||||
*/
|
*/
|
||||||
infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Double> {
|
infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Double> {
|
||||||
if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})")
|
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
||||||
|
|
||||||
val array = DoubleArray(this.rowNum * other.colNum)
|
val array = DoubleArray(this.rowNum * other.colNum)
|
||||||
|
|
||||||
|
@ -4,6 +4,8 @@ import scientifik.kmath.operations.Ring
|
|||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
import scientifik.kmath.structures.Structure2D
|
import scientifik.kmath.structures.Structure2D
|
||||||
import scientifik.kmath.structures.asBuffer
|
import scientifik.kmath.structures.asBuffer
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.contract
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -26,15 +28,18 @@ interface FeaturedMatrix<T : Any> : Matrix<T> {
|
|||||||
companion object
|
companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> =
|
@OptIn(ExperimentalContracts::class)
|
||||||
MatrixContext.real.produce(rows, columns, initializer)
|
inline fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> {
|
||||||
|
contract { callsInPlace(initializer) }
|
||||||
|
return MatrixContext.real.produce(rows, columns, initializer)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Build a square matrix from given elements.
|
* Build a square matrix from given elements.
|
||||||
*/
|
*/
|
||||||
fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> {
|
fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> {
|
||||||
val size: Int = sqrt(elements.size.toDouble()).toInt()
|
val size: Int = sqrt(elements.size.toDouble()).toInt()
|
||||||
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square")
|
require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" }
|
||||||
val buffer = elements.asBuffer()
|
val buffer = elements.asBuffer()
|
||||||
return BufferMatrix(size, size, buffer)
|
return BufferMatrix(size, size, buffer)
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package scientifik.kmath.linear
|
|||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.structures.BufferAccessor2D
|
import scientifik.kmath.structures.BufferAccessor2D
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
import scientifik.kmath.structures.Structure2D
|
import scientifik.kmath.structures.Structure2D
|
||||||
@ -60,15 +61,13 @@ class LUPDecomposition<T : Any>(
|
|||||||
* @return determinant of the matrix
|
* @return determinant of the matrix
|
||||||
*/
|
*/
|
||||||
override val determinant: T by lazy {
|
override val determinant: T by lazy {
|
||||||
with(elementContext) {
|
elementContext { (0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] } }
|
||||||
(0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T): T =
|
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T): T =
|
||||||
if (value > elementContext.zero) value else with(elementContext) { -value }
|
if (value > elementContext.zero) value else elementContext { -value }
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -88,43 +87,34 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
|||||||
|
|
||||||
//TODO just waits for KEEP-176
|
//TODO just waits for KEEP-176
|
||||||
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
|
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
|
||||||
elementContext.run {
|
elementContext {
|
||||||
|
|
||||||
val lu = create(matrix)
|
val lu = create(matrix)
|
||||||
|
|
||||||
// Initialize permutation array and parity
|
// Initialize permutation array and parity
|
||||||
for (row in 0 until m) {
|
for (row in 0 until m) pivot[row] = row
|
||||||
pivot[row] = row
|
|
||||||
}
|
|
||||||
var even = true
|
var even = true
|
||||||
|
|
||||||
// Initialize permutation array and parity
|
// Initialize permutation array and parity
|
||||||
for (row in 0 until m) {
|
for (row in 0 until m) pivot[row] = row
|
||||||
pivot[row] = row
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loop over columns
|
// Loop over columns
|
||||||
for (col in 0 until m) {
|
for (col in 0 until m) {
|
||||||
|
|
||||||
// upper
|
// upper
|
||||||
for (row in 0 until col) {
|
for (row in 0 until col) {
|
||||||
val luRow = lu.row(row)
|
val luRow = lu.row(row)
|
||||||
var sum = luRow[col]
|
var sum = luRow[col]
|
||||||
for (i in 0 until row) {
|
for (i in 0 until row) sum -= luRow[i] * lu[i, col]
|
||||||
sum -= luRow[i] * lu[i, col]
|
|
||||||
}
|
|
||||||
luRow[col] = sum
|
luRow[col] = sum
|
||||||
}
|
}
|
||||||
|
|
||||||
// lower
|
// lower
|
||||||
var max = col // permutation row
|
var max = col // permutation row
|
||||||
var largest = -one
|
var largest = -one
|
||||||
|
|
||||||
for (row in col until m) {
|
for (row in col until m) {
|
||||||
val luRow = lu.row(row)
|
val luRow = lu.row(row)
|
||||||
var sum = luRow[col]
|
var sum = luRow[col]
|
||||||
for (i in 0 until col) {
|
for (i in 0 until col) sum -= luRow[i] * lu[i, col]
|
||||||
sum -= luRow[i] * lu[i, col]
|
|
||||||
}
|
|
||||||
luRow[col] = sum
|
luRow[col] = sum
|
||||||
|
|
||||||
// maintain best permutation choice
|
// maintain best permutation choice
|
||||||
@ -135,19 +125,19 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Singularity check
|
// Singularity check
|
||||||
if (checkSingular(this@lup.abs(lu[max, col]))) {
|
check(!checkSingular(this@lup.abs(lu[max, col]))) { "The matrix is singular" }
|
||||||
error("The matrix is singular")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pivot if necessary
|
// Pivot if necessary
|
||||||
if (max != col) {
|
if (max != col) {
|
||||||
val luMax = lu.row(max)
|
val luMax = lu.row(max)
|
||||||
val luCol = lu.row(col)
|
val luCol = lu.row(col)
|
||||||
|
|
||||||
for (i in 0 until m) {
|
for (i in 0 until m) {
|
||||||
val tmp = luMax[i]
|
val tmp = luMax[i]
|
||||||
luMax[i] = luCol[i]
|
luMax[i] = luCol[i]
|
||||||
luCol[i] = tmp
|
luCol[i] = tmp
|
||||||
}
|
}
|
||||||
|
|
||||||
val temp = pivot[max]
|
val temp = pivot[max]
|
||||||
pivot[max] = pivot[col]
|
pivot[max] = pivot[col]
|
||||||
pivot[col] = temp
|
pivot[col] = temp
|
||||||
@ -156,9 +146,7 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
|||||||
|
|
||||||
// Divide the lower elements by the "winning" diagonal elt.
|
// Divide the lower elements by the "winning" diagonal elt.
|
||||||
val luDiag = lu[col, col]
|
val luDiag = lu[col, col]
|
||||||
for (row in col + 1 until m) {
|
for (row in col + 1 until m) lu[row, col] /= luDiag
|
||||||
lu[row, col] /= luDiag
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return LUPDecomposition(this@lup, lu.collect(), pivot, even)
|
return LUPDecomposition(this@lup, lu.collect(), pivot, even)
|
||||||
@ -175,28 +163,23 @@ fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>): LUPDeco
|
|||||||
lup(Double::class, matrix) { it < 1e-11 }
|
lup(Double::class, matrix) { it < 1e-11 }
|
||||||
|
|
||||||
fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> {
|
fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> {
|
||||||
|
require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" }
|
||||||
if (matrix.rowNum != pivot.size) {
|
|
||||||
error("Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}")
|
|
||||||
}
|
|
||||||
|
|
||||||
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
|
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
|
||||||
elementContext.run {
|
elementContext {
|
||||||
|
|
||||||
// Apply permutations to b
|
// Apply permutations to b
|
||||||
val bp = create { _, _ -> zero }
|
val bp = create { _, _ -> zero }
|
||||||
|
|
||||||
for (row in pivot.indices) {
|
for (row in pivot.indices) {
|
||||||
val bpRow = bp.row(row)
|
val bpRow = bp.row(row)
|
||||||
val pRow = pivot[row]
|
val pRow = pivot[row]
|
||||||
for (col in 0 until matrix.colNum) {
|
for (col in 0 until matrix.colNum) bpRow[col] = matrix[pRow, col]
|
||||||
bpRow[col] = matrix[pRow, col]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Solve LY = b
|
// Solve LY = b
|
||||||
for (col in pivot.indices) {
|
for (col in pivot.indices) {
|
||||||
val bpCol = bp.row(col)
|
val bpCol = bp.row(col)
|
||||||
|
|
||||||
for (i in col + 1 until pivot.size) {
|
for (i in col + 1 until pivot.size) {
|
||||||
val bpI = bp.row(i)
|
val bpI = bp.row(i)
|
||||||
val luICol = lu[i, col]
|
val luICol = lu[i, col]
|
||||||
@ -210,17 +193,15 @@ fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Mat
|
|||||||
for (col in pivot.size - 1 downTo 0) {
|
for (col in pivot.size - 1 downTo 0) {
|
||||||
val bpCol = bp.row(col)
|
val bpCol = bp.row(col)
|
||||||
val luDiag = lu[col, col]
|
val luDiag = lu[col, col]
|
||||||
for (j in 0 until matrix.colNum) {
|
for (j in 0 until matrix.colNum) bpCol[j] /= luDiag
|
||||||
bpCol[j] /= luDiag
|
|
||||||
}
|
|
||||||
for (i in 0 until col) {
|
for (i in 0 until col) {
|
||||||
val bpI = bp.row(i)
|
val bpI = bp.row(i)
|
||||||
val luICol = lu[i, col]
|
val luICol = lu[i, col]
|
||||||
for (j in 0 until matrix.colNum) {
|
for (j in 0 until matrix.colNum) bpI[j] -= bpCol[j] * luICol
|
||||||
bpI[j] -= bpCol[j] * luICol
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return context.produce(pivot.size, matrix.colNum) { i, j -> bp[i, j] }
|
return context.produce(pivot.size, matrix.colNum) { i, j -> bp[i, j] }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@ import scientifik.kmath.structures.asBuffer
|
|||||||
|
|
||||||
class MatrixBuilder(val rows: Int, val columns: Int) {
|
class MatrixBuilder(val rows: Int, val columns: Int) {
|
||||||
operator fun <T : Any> invoke(vararg elements: T): FeaturedMatrix<T> {
|
operator fun <T : Any> invoke(vararg elements: T): FeaturedMatrix<T> {
|
||||||
if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns")
|
require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" }
|
||||||
val buffer = elements.asBuffer()
|
val buffer = elements.asBuffer()
|
||||||
return BufferMatrix(rows, columns, buffer)
|
return BufferMatrix(rows, columns, buffer)
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package scientifik.kmath.linear
|
|||||||
|
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.operations.SpaceOperations
|
import scientifik.kmath.operations.SpaceOperations
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.operations.sum
|
import scientifik.kmath.operations.sum
|
||||||
import scientifik.kmath.structures.Buffer
|
import scientifik.kmath.structures.Buffer
|
||||||
import scientifik.kmath.structures.BufferFactory
|
import scientifik.kmath.structures.BufferFactory
|
||||||
@ -37,8 +38,7 @@ interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
|
|||||||
fun <T : Any, R : Ring<T>> buffered(
|
fun <T : Any, R : Ring<T>> buffered(
|
||||||
ring: R,
|
ring: R,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||||
): GenericMatrixContext<T, R> =
|
): GenericMatrixContext<T, R> = BufferMatrixContext(ring, bufferFactory)
|
||||||
BufferMatrixContext(ring, bufferFactory)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Automatic buffered matrix, unboxed if it is possible
|
* Automatic buffered matrix, unboxed if it is possible
|
||||||
@ -61,45 +61,49 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|||||||
|
|
||||||
override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
||||||
//TODO add typed error
|
//TODO add typed error
|
||||||
if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})")
|
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
||||||
|
|
||||||
return produce(rowNum, other.colNum) { i, j ->
|
return produce(rowNum, other.colNum) { i, j ->
|
||||||
val row = rows[i]
|
val row = rows[i]
|
||||||
val column = other.columns[j]
|
val column = other.columns[j]
|
||||||
with(elementContext) {
|
elementContext { sum(row.asSequence().zip(column.asSequence(), ::multiply)) }
|
||||||
sum(row.asSequence().zip(column.asSequence(), ::multiply))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
|
override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
|
||||||
//TODO add typed error
|
//TODO add typed error
|
||||||
if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})")
|
require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
|
||||||
|
|
||||||
return point(rowNum) { i ->
|
return point(rowNum) { i ->
|
||||||
val row = rows[i]
|
val row = rows[i]
|
||||||
with(elementContext) {
|
elementContext { sum(row.asSequence().zip(vector.asSequence(), ::multiply)) }
|
||||||
sum(row.asSequence().zip(vector.asSequence(), ::multiply))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Matrix<T>.unaryMinus(): Matrix<T> =
|
override operator fun Matrix<T>.unaryMinus(): Matrix<T> =
|
||||||
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } }
|
produce(rowNum, colNum) { i, j -> elementContext { -get(i, j) } }
|
||||||
|
|
||||||
override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
||||||
if (a.rowNum != b.rowNum || a.colNum != b.colNum) error("Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]")
|
require(a.rowNum == b.rowNum && a.colNum == b.colNum) {
|
||||||
return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] + b[i, j] } }
|
"Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]"
|
||||||
|
}
|
||||||
|
|
||||||
|
return produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] + b[i, j] } }
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
||||||
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]")
|
require(rowNum == b.rowNum && colNum == b.colNum) {
|
||||||
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
"Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]"
|
||||||
|
}
|
||||||
|
|
||||||
|
return produce(rowNum, colNum) { i, j -> elementContext { get(i, j) + b[i, j] } }
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
||||||
produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] * k } }
|
produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } }
|
||||||
|
|
||||||
operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
|
operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
|
||||||
|
|
||||||
override fun Matrix<T>.times(value: T): Matrix<T> =
|
override operator fun Matrix<T>.times(value: T): Matrix<T> =
|
||||||
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } }
|
produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } }
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package scientifik.kmath.linear
|
|||||||
|
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.structures.Buffer
|
import scientifik.kmath.structures.Buffer
|
||||||
import scientifik.kmath.structures.BufferFactory
|
import scientifik.kmath.structures.BufferFactory
|
||||||
|
|
||||||
@ -10,10 +11,9 @@ import scientifik.kmath.structures.BufferFactory
|
|||||||
* Could be used on any point-like structure
|
* Could be used on any point-like structure
|
||||||
*/
|
*/
|
||||||
interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
|
interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
|
||||||
|
|
||||||
val size: Int
|
val size: Int
|
||||||
|
|
||||||
val space: S
|
val space: S
|
||||||
|
override val zero: Point<T> get() = produce { space.zero }
|
||||||
|
|
||||||
fun produce(initializer: (Int) -> T): Point<T>
|
fun produce(initializer: (Int) -> T): Point<T>
|
||||||
|
|
||||||
@ -22,29 +22,24 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
|
|||||||
*/
|
*/
|
||||||
//fun produceElement(initializer: (Int) -> T): Vector<T, S>
|
//fun produceElement(initializer: (Int) -> T): Vector<T, S>
|
||||||
|
|
||||||
override val zero: Point<T> get() = produce { space.zero }
|
override fun add(a: Point<T>, b: Point<T>): Point<T> = produce { space { a[it] + b[it] } }
|
||||||
|
|
||||||
override fun add(a: Point<T>, b: Point<T>): Point<T> = produce { with(space) { a[it] + b[it] } }
|
override fun multiply(a: Point<T>, k: Number): Point<T> = produce { space { a[it] * k } }
|
||||||
|
|
||||||
override fun multiply(a: Point<T>, k: Number): Point<T> = produce { with(space) { a[it] * k } }
|
|
||||||
|
|
||||||
//TODO add basis
|
//TODO add basis
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
private val realSpaceCache: MutableMap<Int, BufferVectorSpace<Double, RealField>> = hashMapOf()
|
||||||
private val realSpaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Non-boxing double vector space
|
* Non-boxing double vector space
|
||||||
*/
|
*/
|
||||||
fun real(size: Int): BufferVectorSpace<Double, RealField> {
|
fun real(size: Int): BufferVectorSpace<Double, RealField> = realSpaceCache.getOrPut(size) {
|
||||||
return realSpaceCache.getOrPut(size) {
|
BufferVectorSpace(
|
||||||
BufferVectorSpace(
|
size,
|
||||||
size,
|
RealField,
|
||||||
RealField,
|
Buffer.Companion::auto
|
||||||
Buffer.Companion::auto
|
)
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -18,7 +18,7 @@ class VirtualMatrix<T : Any>(
|
|||||||
|
|
||||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): T = generator(i, j)
|
override operator fun get(i: Int, j: Int): T = generator(i, j)
|
||||||
|
|
||||||
override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix<T> =
|
override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix<T> =
|
||||||
VirtualMatrix(rowNum, colNum, this.features + features, generator)
|
VirtualMatrix(rowNum, colNum, this.features + features, generator)
|
||||||
|
@ -3,8 +3,12 @@ package scientifik.kmath.misc
|
|||||||
import scientifik.kmath.linear.Point
|
import scientifik.kmath.linear.Point
|
||||||
import scientifik.kmath.operations.ExtendedField
|
import scientifik.kmath.operations.ExtendedField
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.operations.sum
|
import scientifik.kmath.operations.sum
|
||||||
import scientifik.kmath.structures.asBuffer
|
import scientifik.kmath.structures.asBuffer
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Implementation of backward-mode automatic differentiation.
|
* Implementation of backward-mode automatic differentiation.
|
||||||
@ -27,15 +31,14 @@ class DerivationResult<T : Any>(
|
|||||||
/**
|
/**
|
||||||
* compute divergence
|
* compute divergence
|
||||||
*/
|
*/
|
||||||
fun div(): T = context.run { sum(deriv.values) }
|
fun div(): T = context { sum(deriv.values) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute a gradient for variables in given order
|
* Compute a gradient for variables in given order
|
||||||
*/
|
*/
|
||||||
fun grad(vararg variables: Variable<T>): Point<T> = if (variables.isEmpty()) {
|
fun grad(vararg variables: Variable<T>): Point<T> {
|
||||||
error("Variable order is not provided for gradient construction")
|
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
|
||||||
} else {
|
return variables.map(::deriv).asBuffer()
|
||||||
variables.map(::deriv).asBuffer()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,19 +55,28 @@ class DerivationResult<T : Any>(
|
|||||||
* assertEquals(9.0, x.d) // dy/dx
|
* assertEquals(9.0, x.d) // dy/dx
|
||||||
* ```
|
* ```
|
||||||
*/
|
*/
|
||||||
fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> =
|
@OptIn(ExperimentalContracts::class)
|
||||||
AutoDiffContext(this).run {
|
inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> {
|
||||||
|
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
|
||||||
|
return (AutoDiffContext(this)) {
|
||||||
val result = body()
|
val result = body()
|
||||||
result.d = context.one// computing derivative w.r.t result
|
result.d = context.one // computing derivative w.r.t result
|
||||||
runBackwardPass()
|
runBackwardPass()
|
||||||
DerivationResult(result.value, derivatives, this@deriv)
|
DerivationResult(result.value, derivatives, this@deriv)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
||||||
|
|
||||||
abstract val context: F
|
abstract val context: F
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A variable accessing inner state of derivatives.
|
||||||
|
* Use this function in inner builders to avoid creating additional derivative bindings
|
||||||
|
*/
|
||||||
|
abstract var Variable<T>.d: T
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Performs update of derivative after the rest of the formula in the back-pass.
|
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||||
*
|
*
|
||||||
@ -78,12 +90,6 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
|||||||
*/
|
*/
|
||||||
abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
||||||
|
|
||||||
/**
|
|
||||||
* A variable accessing inner state of derivatives.
|
|
||||||
* Use this function in inner builders to avoid creating additional derivative bindings
|
|
||||||
*/
|
|
||||||
abstract var Variable<T>.d: T
|
|
||||||
|
|
||||||
abstract fun variable(value: T): Variable<T>
|
abstract fun variable(value: T): Variable<T>
|
||||||
|
|
||||||
inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
|
inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
|
||||||
@ -98,46 +104,35 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
|||||||
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
|
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
|
||||||
|
|
||||||
override operator fun Number.minus(b: Variable<T>): Variable<T> =
|
override operator fun Number.minus(b: Variable<T>): Variable<T> =
|
||||||
derive(variable { this@minus.toDouble() * one - b.value }) { z ->
|
derive(variable { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
||||||
b.d -= z.d
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Variable<T>.minus(b: Number): Variable<T> =
|
override operator fun Variable<T>.minus(b: Number): Variable<T> =
|
||||||
derive(variable { this@minus.value - one * b.toDouble() }) { z ->
|
derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
||||||
this@minus.d += z.d
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Automatic Differentiation context class.
|
* Automatic Differentiation context class.
|
||||||
*/
|
*/
|
||||||
private class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() {
|
@PublishedApi
|
||||||
|
internal class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() {
|
||||||
// this stack contains pairs of blocks and values to apply them to
|
// this stack contains pairs of blocks and values to apply them to
|
||||||
private var stack = arrayOfNulls<Any?>(8)
|
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||||
private var sp = 0
|
private var sp: Int = 0
|
||||||
|
val derivatives: MutableMap<Variable<T>, T> = hashMapOf()
|
||||||
internal val derivatives = HashMap<Variable<T>, T>()
|
override val zero: Variable<T> get() = Variable(context.zero)
|
||||||
|
override val one: Variable<T> get() = Variable(context.one)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A variable coupled with its derivative. For internal use only
|
* A variable coupled with its derivative. For internal use only
|
||||||
*/
|
*/
|
||||||
private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x)
|
private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x)
|
||||||
|
|
||||||
|
|
||||||
override fun variable(value: T): Variable<T> =
|
override fun variable(value: T): Variable<T> =
|
||||||
VariableWithDeriv(value, context.zero)
|
VariableWithDeriv(value, context.zero)
|
||||||
|
|
||||||
override var Variable<T>.d: T
|
override var Variable<T>.d: T
|
||||||
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
|
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
|
||||||
set(value) {
|
set(value) = if (this is VariableWithDeriv) d = value else derivatives[this] = value
|
||||||
if (this is VariableWithDeriv) {
|
|
||||||
d = value
|
|
||||||
} else {
|
|
||||||
derivatives[this] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
||||||
@ -160,67 +155,49 @@ private class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) :
|
|||||||
// Basic math (+, -, *, /)
|
// Basic math (+, -, *, /)
|
||||||
|
|
||||||
|
|
||||||
override fun add(a: Variable<T>, b: Variable<T>): Variable<T> =
|
override fun add(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value + b.value }) { z ->
|
||||||
derive(variable { a.value + b.value }) { z ->
|
a.d += z.d
|
||||||
a.d += z.d
|
b.d += z.d
|
||||||
b.d += z.d
|
}
|
||||||
}
|
|
||||||
|
|
||||||
override fun multiply(a: Variable<T>, b: Variable<T>): Variable<T> =
|
override fun multiply(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value * b.value }) { z ->
|
||||||
derive(variable { a.value * b.value }) { z ->
|
a.d += z.d * b.value
|
||||||
a.d += z.d * b.value
|
b.d += z.d * a.value
|
||||||
b.d += z.d * a.value
|
}
|
||||||
}
|
|
||||||
|
|
||||||
override fun divide(a: Variable<T>, b: Variable<T>): Variable<T> =
|
override fun divide(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value / b.value }) { z ->
|
||||||
derive(variable { a.value / b.value }) { z ->
|
a.d += z.d / b.value
|
||||||
a.d += z.d / b.value
|
b.d -= z.d * a.value / (b.value * b.value)
|
||||||
b.d -= z.d * a.value / (b.value * b.value)
|
}
|
||||||
}
|
|
||||||
|
|
||||||
override fun multiply(a: Variable<T>, k: Number): Variable<T> =
|
override fun multiply(a: Variable<T>, k: Number): Variable<T> = derive(variable { k.toDouble() * a.value }) { z ->
|
||||||
derive(variable { k.toDouble() * a.value }) { z ->
|
a.d += z.d * k.toDouble()
|
||||||
a.d += z.d * k.toDouble()
|
}
|
||||||
}
|
|
||||||
|
|
||||||
override val zero: Variable<T> get() = Variable(context.zero)
|
|
||||||
override val one: Variable<T> get() = Variable(context.one)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extensions for differentiation of various basic mathematical functions
|
// Extensions for differentiation of various basic mathematical functions
|
||||||
|
|
||||||
// x ^ 2
|
// x ^ 2
|
||||||
fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
|
fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
|
||||||
derive(variable { x.value * x.value }) { z ->
|
derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
||||||
x.d += z.d * 2 * x.value
|
|
||||||
}
|
|
||||||
|
|
||||||
// x ^ 1/2
|
// x ^ 1/2
|
||||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
|
||||||
derive(variable { sqrt(x.value) }) { z ->
|
derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
||||||
x.d += z.d * 0.5 / z.value
|
|
||||||
}
|
|
||||||
|
|
||||||
// x ^ y (const)
|
// x ^ y (const)
|
||||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
|
||||||
derive(variable { power(x.value, y) }) { z ->
|
derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
||||||
x.d += z.d * y * power(x.value, y - 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble())
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble())
|
||||||
|
|
||||||
// exp(x)
|
// exp(x)
|
||||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
|
||||||
derive(variable { exp(x.value) }) { z ->
|
derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value }
|
||||||
x.d += z.d * z.value
|
|
||||||
}
|
|
||||||
|
|
||||||
// ln(x)
|
// ln(x)
|
||||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> = derive(
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> =
|
||||||
variable { ln(x.value) }
|
derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value }
|
||||||
) { z ->
|
|
||||||
x.d += z.d / x.value
|
|
||||||
}
|
|
||||||
|
|
||||||
// x ^ y (any)
|
// x ^ y (any)
|
||||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
|
||||||
@ -228,12 +205,8 @@ fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: V
|
|||||||
|
|
||||||
// sin(x)
|
// sin(x)
|
||||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
|
||||||
derive(variable { sin(x.value) }) { z ->
|
derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
||||||
x.d += z.d * cos(x.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// cos(x)
|
// cos(x)
|
||||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
|
||||||
derive(variable { cos(x.value) }) { z ->
|
derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
||||||
x.d -= z.d * sin(x.value)
|
|
||||||
}
|
|
||||||
|
@ -41,6 +41,6 @@ fun ClosedFloatingPointRange<Double>.toSequenceWithPoints(numPoints: Int): Seque
|
|||||||
*/
|
*/
|
||||||
@Deprecated("Replace by 'toSequenceWithPoints'")
|
@Deprecated("Replace by 'toSequenceWithPoints'")
|
||||||
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
|
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
|
||||||
if (numPoints < 2) error("Can't create generic grid with less than two points")
|
require(numPoints >= 2) { "Can't create generic grid with less than two points" }
|
||||||
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
|
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,8 @@ package scientifik.kmath.misc
|
|||||||
|
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
import scientifik.kmath.operations.invoke
|
import scientifik.kmath.operations.invoke
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.contract
|
||||||
import kotlin.jvm.JvmName
|
import kotlin.jvm.JvmName
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -11,67 +13,69 @@ import kotlin.jvm.JvmName
|
|||||||
* @param R the type of resulting iterable.
|
* @param R the type of resulting iterable.
|
||||||
* @param initial lazy evaluated.
|
* @param initial lazy evaluated.
|
||||||
*/
|
*/
|
||||||
fun <T, R> Iterator<T>.cumulative(initial: R, operation: (R, T) -> R): Iterator<R> = object : Iterator<R> {
|
@OptIn(ExperimentalContracts::class)
|
||||||
var state: R = initial
|
inline fun <T, R> Iterator<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterator<R> {
|
||||||
override fun hasNext(): Boolean = this@cumulative.hasNext()
|
contract { callsInPlace(operation) }
|
||||||
|
|
||||||
override fun next(): R {
|
return object : Iterator<R> {
|
||||||
state = operation(state, this@cumulative.next())
|
var state: R = initial
|
||||||
return state
|
|
||||||
|
override fun hasNext(): Boolean = this@cumulative.hasNext()
|
||||||
|
|
||||||
|
override fun next(): R {
|
||||||
|
state = operation(state, this@cumulative.next())
|
||||||
|
return state
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T, R> Iterable<T>.cumulative(initial: R, operation: (R, T) -> R): Iterable<R> = object : Iterable<R> {
|
inline fun <T, R> Iterable<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterable<R> =
|
||||||
override fun iterator(): Iterator<R> = this@cumulative.iterator().cumulative(initial, operation)
|
Iterable { this@cumulative.iterator().cumulative(initial, operation) }
|
||||||
}
|
|
||||||
|
|
||||||
fun <T, R> Sequence<T>.cumulative(initial: R, operation: (R, T) -> R): Sequence<R> = object : Sequence<R> {
|
inline fun <T, R> Sequence<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Sequence<R> = Sequence {
|
||||||
override fun iterator(): Iterator<R> = this@cumulative.iterator().cumulative(initial, operation)
|
this@cumulative.iterator().cumulative(initial, operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T, R> List<T>.cumulative(initial: R, operation: (R, T) -> R): List<R> =
|
fun <T, R> List<T>.cumulative(initial: R, operation: (R, T) -> R): List<R> =
|
||||||
this.iterator().cumulative(initial, operation).asSequence().toList()
|
iterator().cumulative(initial, operation).asSequence().toList()
|
||||||
|
|
||||||
//Cumulative sum
|
//Cumulative sum
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Cumulative sum with custom space
|
* Cumulative sum with custom space
|
||||||
*/
|
*/
|
||||||
fun <T> Iterable<T>.cumulativeSum(space: Space<T>): Iterable<T> = space {
|
fun <T> Iterable<T>.cumulativeSum(space: Space<T>): Iterable<T> =
|
||||||
cumulative(zero) { element: T, sum: T -> sum + element }
|
space { cumulative(zero) { element: T, sum: T -> sum + element } }
|
||||||
}
|
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfDouble")
|
@JvmName("cumulativeSumOfDouble")
|
||||||
fun Iterable<Double>.cumulativeSum(): Iterable<Double> = this.cumulative(0.0) { element, sum -> sum + element }
|
fun Iterable<Double>.cumulativeSum(): Iterable<Double> = cumulative(0.0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfInt")
|
@JvmName("cumulativeSumOfInt")
|
||||||
fun Iterable<Int>.cumulativeSum(): Iterable<Int> = this.cumulative(0) { element, sum -> sum + element }
|
fun Iterable<Int>.cumulativeSum(): Iterable<Int> = cumulative(0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfLong")
|
@JvmName("cumulativeSumOfLong")
|
||||||
fun Iterable<Long>.cumulativeSum(): Iterable<Long> = this.cumulative(0L) { element, sum -> sum + element }
|
fun Iterable<Long>.cumulativeSum(): Iterable<Long> = cumulative(0L) { element, sum -> sum + element }
|
||||||
|
|
||||||
fun <T> Sequence<T>.cumulativeSum(space: Space<T>): Sequence<T> = with(space) {
|
fun <T> Sequence<T>.cumulativeSum(space: Space<T>): Sequence<T> =
|
||||||
cumulative(zero) { element: T, sum: T -> sum + element }
|
space { cumulative(zero) { element: T, sum: T -> sum + element } }
|
||||||
}
|
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfDouble")
|
@JvmName("cumulativeSumOfDouble")
|
||||||
fun Sequence<Double>.cumulativeSum(): Sequence<Double> = this.cumulative(0.0) { element, sum -> sum + element }
|
fun Sequence<Double>.cumulativeSum(): Sequence<Double> = cumulative(0.0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfInt")
|
@JvmName("cumulativeSumOfInt")
|
||||||
fun Sequence<Int>.cumulativeSum(): Sequence<Int> = this.cumulative(0) { element, sum -> sum + element }
|
fun Sequence<Int>.cumulativeSum(): Sequence<Int> = cumulative(0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfLong")
|
@JvmName("cumulativeSumOfLong")
|
||||||
fun Sequence<Long>.cumulativeSum(): Sequence<Long> = this.cumulative(0L) { element, sum -> sum + element }
|
fun Sequence<Long>.cumulativeSum(): Sequence<Long> = cumulative(0L) { element, sum -> sum + element }
|
||||||
|
|
||||||
fun <T> List<T>.cumulativeSum(space: Space<T>): List<T> = with(space) {
|
fun <T> List<T>.cumulativeSum(space: Space<T>): List<T> =
|
||||||
cumulative(zero) { element: T, sum: T -> sum + element }
|
space { cumulative(zero) { element: T, sum: T -> sum + element } }
|
||||||
}
|
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfDouble")
|
@JvmName("cumulativeSumOfDouble")
|
||||||
fun List<Double>.cumulativeSum(): List<Double> = this.cumulative(0.0) { element, sum -> sum + element }
|
fun List<Double>.cumulativeSum(): List<Double> = cumulative(0.0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfInt")
|
@JvmName("cumulativeSumOfInt")
|
||||||
fun List<Int>.cumulativeSum(): List<Int> = this.cumulative(0) { element, sum -> sum + element }
|
fun List<Int>.cumulativeSum(): List<Int> = cumulative(0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfLong")
|
@JvmName("cumulativeSumOfLong")
|
||||||
fun List<Long>.cumulativeSum(): List<Long> = this.cumulative(0L) { element, sum -> sum + element }
|
fun List<Long>.cumulativeSum(): List<Long> = cumulative(0L) { element, sum -> sum + element }
|
||||||
|
@ -3,6 +3,8 @@ package scientifik.kmath.operations
|
|||||||
import scientifik.kmath.operations.BigInt.Companion.BASE
|
import scientifik.kmath.operations.BigInt.Companion.BASE
|
||||||
import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE
|
import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE
|
||||||
import scientifik.kmath.structures.*
|
import scientifik.kmath.structures.*
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.contract
|
||||||
import kotlin.math.log2
|
import kotlin.math.log2
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
@ -431,8 +433,8 @@ fun ULong.toBigInt(): BigInt = BigInt(
|
|||||||
* Create a [BigInt] with this array of magnitudes with protective copy
|
* Create a [BigInt] with this array of magnitudes with protective copy
|
||||||
*/
|
*/
|
||||||
fun UIntArray.toBigInt(sign: Byte): BigInt {
|
fun UIntArray.toBigInt(sign: Byte): BigInt {
|
||||||
if (sign == 0.toByte() && isNotEmpty()) error("")
|
require(sign != 0.toByte() || !isNotEmpty())
|
||||||
return BigInt(sign, this.copyOf())
|
return BigInt(sign, copyOf())
|
||||||
}
|
}
|
||||||
|
|
||||||
val hexChToInt: MutableMap<Char, Int> = hashMapOf(
|
val hexChToInt: MutableMap<Char, Int> = hashMapOf(
|
||||||
@ -485,11 +487,17 @@ fun String.parseBigInteger(): BigInt? {
|
|||||||
return res * sign
|
return res * sign
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
|
@OptIn(ExperimentalContracts::class)
|
||||||
boxing(size, initializer)
|
inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> {
|
||||||
|
contract { callsInPlace(initializer) }
|
||||||
|
return boxing(size, initializer)
|
||||||
|
}
|
||||||
|
|
||||||
inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> =
|
@OptIn(ExperimentalContracts::class)
|
||||||
boxing(size, initializer)
|
inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> {
|
||||||
|
contract { callsInPlace(initializer) }
|
||||||
|
return boxing(size, initializer)
|
||||||
|
}
|
||||||
|
|
||||||
fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> =
|
fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> =
|
||||||
BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt)
|
BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt)
|
||||||
@ -497,5 +505,4 @@ fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntFi
|
|||||||
fun NDElement.Companion.bigInt(
|
fun NDElement.Companion.bigInt(
|
||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
initializer: BigIntField.(IntArray) -> BigInt
|
initializer: BigIntField.(IntArray) -> BigInt
|
||||||
): BufferedNDRingElement<BigInt, BigIntField> =
|
): BufferedNDRingElement<BigInt, BigIntField> = NDAlgebra.bigInt(*shape).produce(initializer)
|
||||||
NDAlgebra.bigInt(*shape).produce(initializer)
|
|
||||||
|
@ -6,6 +6,8 @@ import scientifik.kmath.structures.MutableBuffer
|
|||||||
import scientifik.memory.MemoryReader
|
import scientifik.memory.MemoryReader
|
||||||
import scientifik.memory.MemorySpec
|
import scientifik.memory.MemorySpec
|
||||||
import scientifik.memory.MemoryWriter
|
import scientifik.memory.MemoryWriter
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.contract
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -196,10 +198,14 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Compl
|
|||||||
*/
|
*/
|
||||||
fun Number.toComplex(): Complex = Complex(this, 0.0)
|
fun Number.toComplex(): Complex = Complex(this, 0.0)
|
||||||
|
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
|
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
|
||||||
|
contract { callsInPlace(init) }
|
||||||
return MemoryBuffer.create(Complex, size, init)
|
return MemoryBuffer.create(Complex, size, init)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
|
inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
|
||||||
|
contract { callsInPlace(init) }
|
||||||
return MemoryBuffer.create(Complex, size, init)
|
return MemoryBuffer.create(Complex, size, init)
|
||||||
}
|
}
|
||||||
|
@ -8,19 +8,17 @@ class BoxingNDField<T, F : Field<T>>(
|
|||||||
override val elementContext: F,
|
override val elementContext: F,
|
||||||
val bufferFactory: BufferFactory<T>
|
val bufferFactory: BufferFactory<T>
|
||||||
) : BufferedNDField<T, F> {
|
) : BufferedNDField<T, F> {
|
||||||
|
override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
|
||||||
|
override val one: BufferedNDFieldElement<T, F> by lazy { produce { one } }
|
||||||
override val strides: Strides = DefaultStrides(shape)
|
override val strides: Strides = DefaultStrides(shape)
|
||||||
|
|
||||||
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||||
bufferFactory(size, initializer)
|
bufferFactory(size, initializer)
|
||||||
|
|
||||||
override fun check(vararg elements: NDBuffer<T>) {
|
override fun check(vararg elements: NDBuffer<T>) {
|
||||||
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
|
||||||
}
|
}
|
||||||
|
|
||||||
override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
|
|
||||||
override val one: BufferedNDFieldElement<T, F> by lazy { produce { one } }
|
|
||||||
|
|
||||||
override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
|
override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
|
||||||
BufferedNDFieldElement(
|
BufferedNDFieldElement(
|
||||||
this,
|
this,
|
||||||
@ -28,6 +26,7 @@ class BoxingNDField<T, F : Field<T>>(
|
|||||||
|
|
||||||
override fun map(arg: NDBuffer<T>, transform: F.(T) -> T): BufferedNDFieldElement<T, F> {
|
override fun map(arg: NDBuffer<T>, transform: F.(T) -> T): BufferedNDFieldElement<T, F> {
|
||||||
check(arg)
|
check(arg)
|
||||||
|
|
||||||
return BufferedNDFieldElement(
|
return BufferedNDFieldElement(
|
||||||
this,
|
this,
|
||||||
buildBuffer(arg.strides.linearSize) { offset -> elementContext.transform(arg.buffer[offset]) })
|
buildBuffer(arg.strides.linearSize) { offset -> elementContext.transform(arg.buffer[offset]) })
|
||||||
|
@ -8,19 +8,16 @@ class BoxingNDRing<T, R : Ring<T>>(
|
|||||||
override val elementContext: R,
|
override val elementContext: R,
|
||||||
val bufferFactory: BufferFactory<T>
|
val bufferFactory: BufferFactory<T>
|
||||||
) : BufferedNDRing<T, R> {
|
) : BufferedNDRing<T, R> {
|
||||||
|
|
||||||
override val strides: Strides = DefaultStrides(shape)
|
override val strides: Strides = DefaultStrides(shape)
|
||||||
|
|
||||||
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
|
||||||
bufferFactory(size, initializer)
|
|
||||||
|
|
||||||
override fun check(vararg elements: NDBuffer<T>) {
|
|
||||||
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
|
||||||
}
|
|
||||||
|
|
||||||
override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } }
|
override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } }
|
||||||
override val one: BufferedNDRingElement<T, R> by lazy { produce { one } }
|
override val one: BufferedNDRingElement<T, R> by lazy { produce { one } }
|
||||||
|
|
||||||
|
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = bufferFactory(size, initializer)
|
||||||
|
|
||||||
|
override fun check(vararg elements: NDBuffer<T>) {
|
||||||
|
require(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
|
||||||
|
}
|
||||||
|
|
||||||
override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement<T, R> =
|
override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement<T, R> =
|
||||||
BufferedNDRingElement(
|
BufferedNDRingElement(
|
||||||
this,
|
this,
|
||||||
|
@ -6,7 +6,6 @@ import kotlin.reflect.KClass
|
|||||||
* A context that allows to operate on a [MutableBuffer] as on 2d array
|
* A context that allows to operate on a [MutableBuffer] as on 2d array
|
||||||
*/
|
*/
|
||||||
class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) {
|
class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) {
|
||||||
|
|
||||||
operator fun Buffer<T>.get(i: Int, j: Int): T = get(i + colNum * j)
|
operator fun Buffer<T>.get(i: Int, j: Int): T = get(i + colNum * j)
|
||||||
|
|
||||||
operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
|
operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
|
||||||
@ -26,15 +25,14 @@ class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum
|
|||||||
inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> {
|
inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> {
|
||||||
override val size: Int get() = colNum
|
override val size: Int get() = colNum
|
||||||
|
|
||||||
override fun get(index: Int): T = buffer[rowIndex, index]
|
override operator fun get(index: Int): T = buffer[rowIndex, index]
|
||||||
|
|
||||||
override fun set(index: Int, value: T) {
|
override operator fun set(index: Int, value: T) {
|
||||||
buffer[rowIndex, index] = value
|
buffer[rowIndex, index] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<T> = MutableBuffer.auto(type, colNum) { get(it) }
|
override fun copy(): MutableBuffer<T> = MutableBuffer.auto(type, colNum) { get(it) }
|
||||||
|
override operator fun iterator(): Iterator<T> = (0 until colNum).map(::get).iterator()
|
||||||
override fun iterator(): Iterator<T> = (0 until colNum).map(::get).iterator()
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,9 +5,8 @@ import scientifik.kmath.operations.*
|
|||||||
interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
|
interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
|
||||||
val strides: Strides
|
val strides: Strides
|
||||||
|
|
||||||
override fun check(vararg elements: NDBuffer<T>) {
|
override fun check(vararg elements: NDBuffer<T>): Unit =
|
||||||
if (!elements.all { it.strides == this.strides }) error("Strides mismatch")
|
require(elements.all { it.strides == strides }) { ("Strides mismatch") }
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert any [NDStructure] to buffered structure using strides from this context.
|
* Convert any [NDStructure] to buffered structure using strides from this context.
|
||||||
|
@ -30,7 +30,6 @@ class BufferedNDRingElement<T, R : Ring<T>>(
|
|||||||
override val context: BufferedNDRing<T, R>,
|
override val context: BufferedNDRing<T, R>,
|
||||||
override val buffer: Buffer<T>
|
override val buffer: Buffer<T>
|
||||||
) : BufferedNDElement<T, R>(), RingElement<NDBuffer<T>, BufferedNDRingElement<T, R>, BufferedNDRing<T, R>> {
|
) : BufferedNDElement<T, R>(), RingElement<NDBuffer<T>, BufferedNDRingElement<T, R>, BufferedNDRing<T, R>> {
|
||||||
|
|
||||||
override fun unwrap(): NDBuffer<T> = this
|
override fun unwrap(): NDBuffer<T> = this
|
||||||
|
|
||||||
override fun NDBuffer<T>.wrap(): BufferedNDRingElement<T, R> {
|
override fun NDBuffer<T>.wrap(): BufferedNDRingElement<T, R> {
|
||||||
@ -43,7 +42,6 @@ class BufferedNDFieldElement<T, F : Field<T>>(
|
|||||||
override val context: BufferedNDField<T, F>,
|
override val context: BufferedNDField<T, F>,
|
||||||
override val buffer: Buffer<T>
|
override val buffer: Buffer<T>
|
||||||
) : BufferedNDElement<T, F>(), FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> {
|
) : BufferedNDElement<T, F>(), FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> {
|
||||||
|
|
||||||
override fun unwrap(): NDBuffer<T> = this
|
override fun unwrap(): NDBuffer<T> = this
|
||||||
|
|
||||||
override fun NDBuffer<T>.wrap(): BufferedNDFieldElement<T, F> {
|
override fun NDBuffer<T>.wrap(): BufferedNDFieldElement<T, F> {
|
||||||
|
@ -2,6 +2,8 @@ package scientifik.kmath.structures
|
|||||||
|
|
||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.operations.complex
|
import scientifik.kmath.operations.complex
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.contract
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -117,15 +119,14 @@ interface MutableBuffer<T> : Buffer<T> {
|
|||||||
MutableListBuffer(MutableList(size, initializer))
|
MutableListBuffer(MutableList(size, initializer))
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||||
return when (type) {
|
when (type) {
|
||||||
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
||||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
||||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
||||||
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
|
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
|
||||||
else -> boxing(size, initializer)
|
else -> boxing(size, initializer)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible
|
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible
|
||||||
@ -150,9 +151,8 @@ inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
|
|||||||
override val size: Int
|
override val size: Int
|
||||||
get() = list.size
|
get() = list.size
|
||||||
|
|
||||||
override fun get(index: Int): T = list[index]
|
override operator fun get(index: Int): T = list[index]
|
||||||
|
override operator fun iterator(): Iterator<T> = list.iterator()
|
||||||
override fun iterator(): Iterator<T> = list.iterator()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -167,7 +167,11 @@ fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
|
|||||||
* The function [init] is called for each array element sequentially starting from the first one.
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
* It should return the value for an array element given its index.
|
* It should return the value for an array element given its index.
|
||||||
*/
|
*/
|
||||||
inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> = List(size, init).asBuffer()
|
@OptIn(ExperimentalContracts::class)
|
||||||
|
inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> {
|
||||||
|
contract { callsInPlace(init) }
|
||||||
|
return List(size, init).asBuffer()
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [MutableBuffer] implementation over [MutableList].
|
* [MutableBuffer] implementation over [MutableList].
|
||||||
@ -176,17 +180,16 @@ inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> = List(siz
|
|||||||
* @property list The underlying list.
|
* @property list The underlying list.
|
||||||
*/
|
*/
|
||||||
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
|
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
|
||||||
|
|
||||||
override val size: Int
|
override val size: Int
|
||||||
get() = list.size
|
get() = list.size
|
||||||
|
|
||||||
override fun get(index: Int): T = list[index]
|
override operator fun get(index: Int): T = list[index]
|
||||||
|
|
||||||
override fun set(index: Int, value: T) {
|
override operator fun set(index: Int, value: T) {
|
||||||
list[index] = value
|
list[index] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): Iterator<T> = list.iterator()
|
override operator fun iterator(): Iterator<T> = list.iterator()
|
||||||
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
|
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -201,14 +204,13 @@ class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
|
|||||||
override val size: Int
|
override val size: Int
|
||||||
get() = array.size
|
get() = array.size
|
||||||
|
|
||||||
override fun get(index: Int): T = array[index]
|
override operator fun get(index: Int): T = array[index]
|
||||||
|
|
||||||
override fun set(index: Int, value: T) {
|
override operator fun set(index: Int, value: T) {
|
||||||
array[index] = value
|
array[index] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): Iterator<T> = array.iterator()
|
override operator fun iterator(): Iterator<T> = array.iterator()
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<T> = ArrayBuffer(array.copyOf())
|
override fun copy(): MutableBuffer<T> = ArrayBuffer(array.copyOf())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -226,9 +228,9 @@ fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this)
|
|||||||
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
|
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
|
||||||
override val size: Int get() = buffer.size
|
override val size: Int get() = buffer.size
|
||||||
|
|
||||||
override fun get(index: Int): T = buffer[index]
|
override operator fun get(index: Int): T = buffer[index]
|
||||||
|
|
||||||
override fun iterator(): Iterator<T> = buffer.iterator()
|
override operator fun iterator(): Iterator<T> = buffer.iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -238,12 +240,12 @@ inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
|
|||||||
* @param T the type of elements provided by the buffer.
|
* @param T the type of elements provided by the buffer.
|
||||||
*/
|
*/
|
||||||
class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> {
|
class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> {
|
||||||
override fun get(index: Int): T {
|
override operator fun get(index: Int): T {
|
||||||
if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index")
|
if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index")
|
||||||
return generator(index)
|
return generator(index)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): Iterator<T> = (0 until size).asSequence().map(generator).iterator()
|
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(generator).iterator()
|
||||||
|
|
||||||
override fun contentEquals(other: Buffer<*>): Boolean {
|
override fun contentEquals(other: Buffer<*>): Boolean {
|
||||||
return if (other is VirtualBuffer) {
|
return if (other is VirtualBuffer) {
|
||||||
|
@ -4,6 +4,9 @@ import scientifik.kmath.operations.Complex
|
|||||||
import scientifik.kmath.operations.ComplexField
|
import scientifik.kmath.operations.ComplexField
|
||||||
import scientifik.kmath.operations.FieldElement
|
import scientifik.kmath.operations.FieldElement
|
||||||
import scientifik.kmath.operations.complex
|
import scientifik.kmath.operations.complex
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
typealias ComplexNDElement = BufferedNDFieldElement<Complex, ComplexField>
|
typealias ComplexNDElement = BufferedNDFieldElement<Complex, ComplexField>
|
||||||
|
|
||||||
@ -109,7 +112,9 @@ inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(inde
|
|||||||
/**
|
/**
|
||||||
* Map one [ComplexNDElement] using function without indices.
|
* Map one [ComplexNDElement] using function without indices.
|
||||||
*/
|
*/
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement {
|
inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement {
|
||||||
|
contract { callsInPlace(transform) }
|
||||||
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) }
|
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) }
|
||||||
return BufferedNDFieldElement(context, buffer)
|
return BufferedNDFieldElement(context, buffer)
|
||||||
}
|
}
|
||||||
@ -148,6 +153,8 @@ fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(In
|
|||||||
/**
|
/**
|
||||||
* Produce a context for n-dimensional operations inside this real field
|
* Produce a context for n-dimensional operations inside this real field
|
||||||
*/
|
*/
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
inline fun <R> ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R {
|
inline fun <R> ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R {
|
||||||
return NDField.complex(*shape).run(action)
|
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return NDField.complex(*shape).action()
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.contract
|
||||||
import kotlin.experimental.and
|
import kotlin.experimental.and
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -57,17 +59,19 @@ class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : Flagged
|
|||||||
|
|
||||||
override val size: Int get() = values.size
|
override val size: Int get() = values.size
|
||||||
|
|
||||||
override fun get(index: Int): Double? = if (isValid(index)) values[index] else null
|
override operator fun get(index: Int): Double? = if (isValid(index)) values[index] else null
|
||||||
|
|
||||||
override fun iterator(): Iterator<Double?> = values.indices.asSequence().map {
|
override operator fun iterator(): Iterator<Double?> = values.indices.asSequence().map {
|
||||||
if (isValid(it)) values[it] else null
|
if (isValid(it)) values[it] else null
|
||||||
}.iterator()
|
}.iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
|
inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
|
||||||
for (i in indices) {
|
contract { callsInPlace(block) }
|
||||||
if (isValid(i)) {
|
|
||||||
block(values[i])
|
indices
|
||||||
}
|
.asSequence()
|
||||||
}
|
.filter(::isValid)
|
||||||
|
.forEach { block(values[it]) }
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specialized [MutableBuffer] implementation over [FloatArray].
|
* Specialized [MutableBuffer] implementation over [FloatArray].
|
||||||
*
|
*
|
||||||
@ -8,13 +11,13 @@ package scientifik.kmath.structures
|
|||||||
inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> {
|
inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> {
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override fun get(index: Int): Float = array[index]
|
override operator fun get(index: Int): Float = array[index]
|
||||||
|
|
||||||
override fun set(index: Int, value: Float) {
|
override operator fun set(index: Int, value: Float) {
|
||||||
array[index] = value
|
array[index] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): FloatIterator = array.iterator()
|
override operator fun iterator(): FloatIterator = array.iterator()
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Float> =
|
override fun copy(): MutableBuffer<Float> =
|
||||||
FloatBuffer(array.copyOf())
|
FloatBuffer(array.copyOf())
|
||||||
@ -27,7 +30,11 @@ inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> {
|
|||||||
* The function [init] is called for each array element sequentially starting from the first one.
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
* It should return the value for an buffer element given its index.
|
* It should return the value for an buffer element given its index.
|
||||||
*/
|
*/
|
||||||
inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer = FloatBuffer(FloatArray(size) { init(it) })
|
@OptIn(ExperimentalContracts::class)
|
||||||
|
inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer {
|
||||||
|
contract { callsInPlace(init) }
|
||||||
|
return FloatBuffer(FloatArray(size) { init(it) })
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a new [FloatBuffer] of given elements.
|
* Returns a new [FloatBuffer] of given elements.
|
||||||
|
@ -1,5 +1,9 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specialized [MutableBuffer] implementation over [IntArray].
|
* Specialized [MutableBuffer] implementation over [IntArray].
|
||||||
*
|
*
|
||||||
@ -8,17 +12,16 @@ package scientifik.kmath.structures
|
|||||||
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
|
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override fun get(index: Int): Int = array[index]
|
override operator fun get(index: Int): Int = array[index]
|
||||||
|
|
||||||
override fun set(index: Int, value: Int) {
|
override operator fun set(index: Int, value: Int) {
|
||||||
array[index] = value
|
array[index] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): IntIterator = array.iterator()
|
override operator fun iterator(): IntIterator = array.iterator()
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Int> =
|
override fun copy(): MutableBuffer<Int> =
|
||||||
IntBuffer(array.copyOf())
|
IntBuffer(array.copyOf())
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -28,7 +31,11 @@ inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
|
|||||||
* The function [init] is called for each array element sequentially starting from the first one.
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
* It should return the value for an buffer element given its index.
|
* It should return the value for an buffer element given its index.
|
||||||
*/
|
*/
|
||||||
inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer = IntBuffer(IntArray(size) { init(it) })
|
@OptIn(ExperimentalContracts::class)
|
||||||
|
inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer {
|
||||||
|
contract { callsInPlace(init) }
|
||||||
|
return IntBuffer(IntArray(size) { init(it) })
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a new [IntBuffer] of given elements.
|
* Returns a new [IntBuffer] of given elements.
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specialized [MutableBuffer] implementation over [LongArray].
|
* Specialized [MutableBuffer] implementation over [LongArray].
|
||||||
*
|
*
|
||||||
@ -8,13 +11,13 @@ package scientifik.kmath.structures
|
|||||||
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
|
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override fun get(index: Int): Long = array[index]
|
override operator fun get(index: Int): Long = array[index]
|
||||||
|
|
||||||
override fun set(index: Int, value: Long) {
|
override operator fun set(index: Int, value: Long) {
|
||||||
array[index] = value
|
array[index] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): LongIterator = array.iterator()
|
override operator fun iterator(): LongIterator = array.iterator()
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Long> =
|
override fun copy(): MutableBuffer<Long> =
|
||||||
LongBuffer(array.copyOf())
|
LongBuffer(array.copyOf())
|
||||||
@ -28,7 +31,11 @@ inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
|
|||||||
* The function [init] is called for each array element sequentially starting from the first one.
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
* It should return the value for an buffer element given its index.
|
* It should return the value for an buffer element given its index.
|
||||||
*/
|
*/
|
||||||
inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer = LongBuffer(LongArray(size) { init(it) })
|
@OptIn(ExperimentalContracts::class)
|
||||||
|
inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer {
|
||||||
|
contract { callsInPlace(init) }
|
||||||
|
return LongBuffer(LongArray(size) { init(it) })
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a new [LongBuffer] of given elements.
|
* Returns a new [LongBuffer] of given elements.
|
||||||
|
@ -14,10 +14,8 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
|
|||||||
|
|
||||||
private val reader: MemoryReader = memory.reader()
|
private val reader: MemoryReader = memory.reader()
|
||||||
|
|
||||||
override fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
|
override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
|
||||||
|
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
|
||||||
override fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
|
|
||||||
|
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =
|
fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =
|
||||||
@ -48,8 +46,7 @@ class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : Memory
|
|||||||
|
|
||||||
private val writer: MemoryWriter = memory.writer()
|
private val writer: MemoryWriter = memory.writer()
|
||||||
|
|
||||||
override fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value)
|
override operator fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value)
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec)
|
override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec)
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
@ -26,19 +26,20 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
|
|||||||
fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement =
|
fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement =
|
||||||
NDField.real(*shape).produce(initializer)
|
NDField.real(*shape).produce(initializer)
|
||||||
|
|
||||||
|
inline fun real1D(dim: Int, crossinline initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement =
|
||||||
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement =
|
|
||||||
real(intArrayOf(dim)) { initializer(it[0]) }
|
real(intArrayOf(dim)) { initializer(it[0]) }
|
||||||
|
|
||||||
|
inline fun real2D(
|
||||||
|
dim1: Int,
|
||||||
|
dim2: Int,
|
||||||
|
crossinline initializer: (Int, Int) -> Double = { _, _ -> 0.0 }
|
||||||
|
): RealNDElement = real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
||||||
|
|
||||||
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): RealNDElement =
|
inline fun real3D(
|
||||||
real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
|
||||||
|
|
||||||
fun real3D(
|
|
||||||
dim1: Int,
|
dim1: Int,
|
||||||
dim2: Int,
|
dim2: Int,
|
||||||
dim3: Int,
|
dim3: Int,
|
||||||
initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }
|
crossinline initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }
|
||||||
): RealNDElement = real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
): RealNDElement = real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||||
|
|
||||||
|
|
||||||
@ -72,7 +73,6 @@ fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index
|
|||||||
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T): NDElement<T, C, N> =
|
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T): NDElement<T, C, N> =
|
||||||
context.map(unwrap(), transform).wrap()
|
context.map(unwrap(), transform).wrap()
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Element by element application of any operation on elements to the whole [NDElement]
|
* Element by element application of any operation on elements to the whole [NDElement]
|
||||||
*/
|
*/
|
||||||
@ -107,7 +107,6 @@ operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg:
|
|||||||
operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T): NDElement<T, F, N> =
|
operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T): NDElement<T, F, N> =
|
||||||
map { value -> arg / value }
|
map { value -> arg / value }
|
||||||
|
|
||||||
|
|
||||||
// /**
|
// /**
|
||||||
// * Reverse sum operation
|
// * Reverse sum operation
|
||||||
// */
|
// */
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.contract
|
||||||
import kotlin.jvm.JvmName
|
import kotlin.jvm.JvmName
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
@ -138,10 +140,10 @@ interface MutableNDStructure<T> : NDStructure<T> {
|
|||||||
operator fun set(index: IntArray, value: T)
|
operator fun set(index: IntArray, value: T)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
|
inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
|
||||||
elements().forEach { (index, oldValue) ->
|
contract { callsInPlace(action) }
|
||||||
this[index] = action(index, oldValue)
|
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -200,14 +202,12 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
|
|||||||
}.toList()
|
}.toList()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun offset(index: IntArray): Int {
|
override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
||||||
return index.mapIndexed { i, value ->
|
if (value < 0 || value >= this.shape[i])
|
||||||
if (value < 0 || value >= this.shape[i]) {
|
throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
||||||
throw RuntimeException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
|
||||||
}
|
value * strides[i]
|
||||||
value * strides[i]
|
}.sum()
|
||||||
}.sum()
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun index(offset: Int): IntArray {
|
override fun index(offset: Int): IntArray {
|
||||||
val res = IntArray(shape.size)
|
val res = IntArray(shape.size)
|
||||||
@ -259,7 +259,7 @@ abstract class NDBuffer<T> : NDStructure<T> {
|
|||||||
*/
|
*/
|
||||||
abstract val strides: Strides
|
abstract val strides: Strides
|
||||||
|
|
||||||
override fun get(index: IntArray): T = buffer[strides.offset(index)]
|
override operator fun get(index: IntArray): T = buffer[strides.offset(index)]
|
||||||
|
|
||||||
override val shape: IntArray get() = strides.shape
|
override val shape: IntArray get() = strides.shape
|
||||||
|
|
||||||
@ -319,13 +319,13 @@ class MutableBufferNDStructure<T>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value)
|
override operator fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value)
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun <reified T : Any> NDStructure<T>.combine(
|
inline fun <reified T : Any> NDStructure<T>.combine(
|
||||||
struct: NDStructure<T>,
|
struct: NDStructure<T>,
|
||||||
crossinline block: (T, T) -> T
|
crossinline block: (T, T) -> T
|
||||||
): NDStructure<T> {
|
): NDStructure<T> {
|
||||||
if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination")
|
require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" }
|
||||||
return NDStructure.auto(shape) { block(this[it], struct[it]) }
|
return NDStructure.auto(shape) { block(this[it], struct[it]) }
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specialized [MutableBuffer] implementation over [DoubleArray].
|
* Specialized [MutableBuffer] implementation over [DoubleArray].
|
||||||
*
|
*
|
||||||
@ -8,13 +11,13 @@ package scientifik.kmath.structures
|
|||||||
inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override fun get(index: Int): Double = array[index]
|
override operator fun get(index: Int): Double = array[index]
|
||||||
|
|
||||||
override fun set(index: Int, value: Double) {
|
override operator fun set(index: Int, value: Double) {
|
||||||
array[index] = value
|
array[index] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): DoubleIterator = array.iterator()
|
override operator fun iterator(): DoubleIterator = array.iterator()
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Double> =
|
override fun copy(): MutableBuffer<Double> =
|
||||||
RealBuffer(array.copyOf())
|
RealBuffer(array.copyOf())
|
||||||
@ -27,7 +30,11 @@ inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
|||||||
* The function [init] is called for each array element sequentially starting from the first one.
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
* It should return the value for an buffer element given its index.
|
* It should return the value for an buffer element given its index.
|
||||||
*/
|
*/
|
||||||
inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) })
|
@OptIn(ExperimentalContracts::class)
|
||||||
|
inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer {
|
||||||
|
contract { callsInPlace(init) }
|
||||||
|
return RealBuffer(DoubleArray(size) { init(it) })
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a new [RealBuffer] of given elements.
|
* Returns a new [RealBuffer] of given elements.
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specialized [MutableBuffer] implementation over [ShortArray].
|
* Specialized [MutableBuffer] implementation over [ShortArray].
|
||||||
*
|
*
|
||||||
@ -8,17 +11,16 @@ package scientifik.kmath.structures
|
|||||||
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
|
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override fun get(index: Int): Short = array[index]
|
override operator fun get(index: Int): Short = array[index]
|
||||||
|
|
||||||
override fun set(index: Int, value: Short) {
|
override operator fun set(index: Int, value: Short) {
|
||||||
array[index] = value
|
array[index] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): ShortIterator = array.iterator()
|
override operator fun iterator(): ShortIterator = array.iterator()
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Short> =
|
override fun copy(): MutableBuffer<Short> =
|
||||||
ShortBuffer(array.copyOf())
|
ShortBuffer(array.copyOf())
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -28,7 +30,11 @@ inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
|
|||||||
* The function [init] is called for each array element sequentially starting from the first one.
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
* It should return the value for an buffer element given its index.
|
* It should return the value for an buffer element given its index.
|
||||||
*/
|
*/
|
||||||
inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer = ShortBuffer(ShortArray(size) { init(it) })
|
@OptIn(ExperimentalContracts::class)
|
||||||
|
inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer {
|
||||||
|
contract { callsInPlace(init) }
|
||||||
|
return ShortBuffer(ShortArray(size) { init(it) })
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a new [ShortBuffer] of given elements.
|
* Returns a new [ShortBuffer] of given elements.
|
||||||
|
@ -6,12 +6,12 @@ package scientifik.kmath.structures
|
|||||||
interface Structure1D<T> : NDStructure<T>, Buffer<T> {
|
interface Structure1D<T> : NDStructure<T>, Buffer<T> {
|
||||||
override val dimension: Int get() = 1
|
override val dimension: Int get() = 1
|
||||||
|
|
||||||
override fun get(index: IntArray): T {
|
override operator fun get(index: IntArray): T {
|
||||||
if (index.size != 1) error("Index dimension mismatch. Expected 1 but found ${index.size}")
|
require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" }
|
||||||
return get(index[0])
|
return get(index[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
|
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -22,7 +22,7 @@ private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Stru
|
|||||||
override val shape: IntArray get() = structure.shape
|
override val shape: IntArray get() = structure.shape
|
||||||
override val size: Int get() = structure.shape[0]
|
override val size: Int get() = structure.shape[0]
|
||||||
|
|
||||||
override fun get(index: Int): T = structure[index]
|
override operator fun get(index: Int): T = structure[index]
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||||
}
|
}
|
||||||
@ -39,7 +39,7 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
|
|||||||
override fun elements(): Sequence<Pair<IntArray, T>> =
|
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||||
asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
|
asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
|
||||||
|
|
||||||
override fun get(index: Int): T = buffer[index]
|
override operator fun get(index: Int): T = buffer[index]
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -9,8 +9,8 @@ interface Structure2D<T> : NDStructure<T> {
|
|||||||
|
|
||||||
operator fun get(i: Int, j: Int): T
|
operator fun get(i: Int, j: Int): T
|
||||||
|
|
||||||
override fun get(index: IntArray): T {
|
override operator fun get(index: IntArray): T {
|
||||||
if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}")
|
require(index.size == 2) { "Index dimension mismatch. Expected 2 but found ${index.size}" }
|
||||||
return get(index[0], index[1])
|
return get(index[0], index[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,10 +39,10 @@ interface Structure2D<T> : NDStructure<T> {
|
|||||||
* A 2D wrapper for nd-structure
|
* A 2D wrapper for nd-structure
|
||||||
*/
|
*/
|
||||||
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
|
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
|
||||||
override fun get(i: Int, j: Int): T = structure[i, j]
|
|
||||||
|
|
||||||
override val shape: IntArray get() = structure.shape
|
override val shape: IntArray get() = structure.shape
|
||||||
|
|
||||||
|
override operator fun get(i: Int, j: Int): T = structure[i, j]
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ package scientifik.kmath.expressions
|
|||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.operations.ComplexField
|
import scientifik.kmath.operations.ComplexField
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
@ -10,10 +11,12 @@ class ExpressionFieldTest {
|
|||||||
@Test
|
@Test
|
||||||
fun testExpression() {
|
fun testExpression() {
|
||||||
val context = FunctionalExpressionField(RealField)
|
val context = FunctionalExpressionField(RealField)
|
||||||
val expression = with(context) {
|
|
||||||
|
val expression = context {
|
||||||
val x = variable("x", 2.0)
|
val x = variable("x", 2.0)
|
||||||
x * x + 2 * x + one
|
x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(expression("x" to 1.0), 4.0)
|
assertEquals(expression("x" to 1.0), 4.0)
|
||||||
assertEquals(expression(), 9.0)
|
assertEquals(expression(), 9.0)
|
||||||
}
|
}
|
||||||
@ -21,10 +24,12 @@ class ExpressionFieldTest {
|
|||||||
@Test
|
@Test
|
||||||
fun testComplex() {
|
fun testComplex() {
|
||||||
val context = FunctionalExpressionField(ComplexField)
|
val context = FunctionalExpressionField(ComplexField)
|
||||||
val expression = with(context) {
|
|
||||||
|
val expression = context {
|
||||||
val x = variable("x", Complex(2.0, 0.0))
|
val x = variable("x", Complex(2.0, 0.0))
|
||||||
x * x + 2 * x + one
|
x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0))
|
assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0))
|
||||||
assertEquals(expression(), Complex(9.0, 0.0))
|
assertEquals(expression(), Complex(9.0, 0.0))
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,6 @@ import kotlin.test.Test
|
|||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class MatrixTest {
|
class MatrixTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testTranspose() {
|
fun testTranspose() {
|
||||||
val matrix = MatrixContext.real.one(3, 3)
|
val matrix = MatrixContext.real.one(3, 3)
|
||||||
@ -51,6 +50,7 @@ class MatrixTest {
|
|||||||
fun test2DDot() {
|
fun test2DDot() {
|
||||||
val firstMatrix = NDStructure.auto(2, 3) { (i, j) -> (i + j).toDouble() }.as2D()
|
val firstMatrix = NDStructure.auto(2, 3) { (i, j) -> (i + j).toDouble() }.as2D()
|
||||||
val secondMatrix = NDStructure.auto(3, 2) { (i, j) -> (i + j).toDouble() }.as2D()
|
val secondMatrix = NDStructure.auto(3, 2) { (i, j) -> (i + j).toDouble() }.as2D()
|
||||||
|
|
||||||
MatrixContext.real.run {
|
MatrixContext.real.run {
|
||||||
// val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() }
|
// val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() }
|
||||||
// val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() }
|
// val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() }
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import scientifik.kmath.operations.Norm
|
import scientifik.kmath.operations.Norm
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.structures.NDElement.Companion.real2D
|
import scientifik.kmath.structures.NDElement.Companion.real2D
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
@ -56,17 +57,12 @@ class NumberNDFieldTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
object L2Norm : Norm<NDStructure<out Number>, Double> {
|
object L2Norm : Norm<NDStructure<out Number>, Double> {
|
||||||
override fun norm(arg: NDStructure<out Number>): Double {
|
override fun norm(arg: NDStructure<out Number>): Double =
|
||||||
return kotlin.math.sqrt(arg.elements().sumByDouble { it.second.toDouble() })
|
kotlin.math.sqrt(arg.elements().sumByDouble { it.second.toDouble() })
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testInternalContext() {
|
fun testInternalContext() {
|
||||||
NDField.real(*array1.shape).run {
|
(NDField.real(*array1.shape)) { with(L2Norm) { 1 + norm(array1) + exp(array2) } }
|
||||||
with(L2Norm) {
|
|
||||||
1 + norm(array1) + exp(array2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,10 +17,10 @@ object JBigIntegerField : Field<BigInteger> {
|
|||||||
override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong())
|
override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong())
|
||||||
override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b)
|
override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b)
|
||||||
override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b)
|
override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b)
|
||||||
override fun BigInteger.minus(b: BigInteger): BigInteger = this.subtract(b)
|
override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b)
|
||||||
override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger())
|
override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger())
|
||||||
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
|
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
|
||||||
override fun BigInteger.unaryMinus(): BigInteger = negate()
|
override operator fun BigInteger.unaryMinus(): BigInteger = negate()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -38,7 +38,7 @@ abstract class JBigDecimalFieldBase internal constructor(val mathContext: MathCo
|
|||||||
get() = BigDecimal.ONE
|
get() = BigDecimal.ONE
|
||||||
|
|
||||||
override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b)
|
override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b)
|
||||||
override fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b)
|
override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b)
|
||||||
override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble())
|
override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble())
|
||||||
|
|
||||||
override fun multiply(a: BigDecimal, k: Number): BigDecimal =
|
override fun multiply(a: BigDecimal, k: Number): BigDecimal =
|
||||||
@ -48,8 +48,7 @@ abstract class JBigDecimalFieldBase internal constructor(val mathContext: MathCo
|
|||||||
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext)
|
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext)
|
||||||
override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext)
|
override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext)
|
||||||
override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext)
|
override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext)
|
||||||
override fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext)
|
override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -139,9 +139,10 @@ fun <T, R> Chain<T>.map(func: suspend (T) -> R): Chain<R> = object : Chain<R> {
|
|||||||
fun <T> Chain<T>.filter(block: (T) -> Boolean): Chain<T> = object : Chain<T> {
|
fun <T> Chain<T>.filter(block: (T) -> Boolean): Chain<T> = object : Chain<T> {
|
||||||
override suspend fun next(): T {
|
override suspend fun next(): T {
|
||||||
var next: T
|
var next: T
|
||||||
do {
|
|
||||||
next = this@filter.next()
|
do next = this@filter.next()
|
||||||
} while (!block(next))
|
while (!block(next))
|
||||||
|
|
||||||
return next
|
return next
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -159,6 +160,7 @@ fun <T, R> Chain<T>.collect(mapper: suspend (Chain<T>) -> R): Chain<R> = object
|
|||||||
fun <T, S, R> Chain<T>.collectWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain<T>) -> R): Chain<R> =
|
fun <T, S, R> Chain<T>.collectWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain<T>) -> R): Chain<R> =
|
||||||
object : Chain<R> {
|
object : Chain<R> {
|
||||||
override suspend fun next(): R = state.mapper(this@collectWithState)
|
override suspend fun next(): R = state.mapper(this@collectWithState)
|
||||||
|
|
||||||
override fun fork(): Chain<R> =
|
override fun fork(): Chain<R> =
|
||||||
this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper)
|
this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper)
|
||||||
}
|
}
|
||||||
@ -168,6 +170,5 @@ fun <T, S, R> Chain<T>.collectWithState(state: S, stateFork: (S) -> S, mapper: s
|
|||||||
*/
|
*/
|
||||||
fun <T, U, R> Chain<T>.zip(other: Chain<U>, block: suspend (T, U) -> R): Chain<R> = object : Chain<R> {
|
fun <T, U, R> Chain<T>.zip(other: Chain<U>, block: suspend (T, U) -> R): Chain<R> = object : Chain<R> {
|
||||||
override suspend fun next(): R = block(this@zip.next(), other.next())
|
override suspend fun next(): R = block(this@zip.next(), other.next())
|
||||||
|
|
||||||
override fun fork(): Chain<R> = this@zip.fork().zip(other.fork(), block)
|
override fun fork(): Chain<R> = this@zip.fork().zip(other.fork(), block)
|
||||||
}
|
}
|
||||||
|
@ -7,15 +7,16 @@ import kotlinx.coroutines.flow.scan
|
|||||||
import kotlinx.coroutines.flow.scanReduce
|
import kotlinx.coroutines.flow.scanReduce
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
import scientifik.kmath.operations.SpaceOperations
|
import scientifik.kmath.operations.SpaceOperations
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
|
|
||||||
|
|
||||||
@ExperimentalCoroutinesApi
|
@ExperimentalCoroutinesApi
|
||||||
fun <T> Flow<T>.cumulativeSum(space: SpaceOperations<T>): Flow<T> = with(space) {
|
fun <T> Flow<T>.cumulativeSum(space: SpaceOperations<T>): Flow<T> = space {
|
||||||
scanReduce { sum: T, element: T -> sum + element }
|
scanReduce { sum: T, element: T -> sum + element }
|
||||||
}
|
}
|
||||||
|
|
||||||
@ExperimentalCoroutinesApi
|
@ExperimentalCoroutinesApi
|
||||||
fun <T> Flow<T>.mean(space: Space<T>): Flow<T> = with(space) {
|
fun <T> Flow<T>.mean(space: Space<T>): Flow<T> = space {
|
||||||
class Accumulator(var sum: T, var num: Int)
|
class Accumulator(var sum: T, var num: Int)
|
||||||
|
|
||||||
scan(Accumulator(zero, 0)) { sum, element ->
|
scan(Accumulator(zero, 0)) { sum, element ->
|
||||||
|
@ -3,6 +3,8 @@ package scientifik.kmath.coroutines
|
|||||||
import kotlinx.coroutines.*
|
import kotlinx.coroutines.*
|
||||||
import kotlinx.coroutines.channels.produce
|
import kotlinx.coroutines.channels.produce
|
||||||
import kotlinx.coroutines.flow.*
|
import kotlinx.coroutines.flow.*
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
val Dispatchers.Math: CoroutineDispatcher
|
val Dispatchers.Math: CoroutineDispatcher
|
||||||
get() = Default
|
get() = Default
|
||||||
@ -81,21 +83,24 @@ suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
@ExperimentalCoroutinesApi
|
@ExperimentalCoroutinesApi
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, action: suspend (value: T) -> Unit) {
|
suspend inline fun <T> AsyncFlow<T>.collect(concurrency: Int, crossinline action: suspend (value: T) -> Unit) {
|
||||||
|
contract { callsInPlace(action) }
|
||||||
|
|
||||||
collect(concurrency, object : FlowCollector<T> {
|
collect(concurrency, object : FlowCollector<T> {
|
||||||
override suspend fun emit(value: T): Unit = action(value)
|
override suspend fun emit(value: T): Unit = action(value)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
@ExperimentalCoroutinesApi
|
@ExperimentalCoroutinesApi
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
fun <T, R> Flow<T>.mapParallel(
|
inline fun <T, R> Flow<T>.mapParallel(
|
||||||
dispatcher: CoroutineDispatcher = Dispatchers.Default,
|
dispatcher: CoroutineDispatcher = Dispatchers.Default,
|
||||||
transform: suspend (T) -> R
|
crossinline transform: suspend (T) -> R
|
||||||
): Flow<R> {
|
): Flow<R> {
|
||||||
return flatMapMerge { value ->
|
contract { callsInPlace(transform) }
|
||||||
flow { emit(transform(value)) }
|
return flatMapMerge { value -> flow { emit(transform(value)) } }.flowOn(dispatcher)
|
||||||
}.flowOn(dispatcher)
|
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,7 @@ class RingBuffer<T>(
|
|||||||
override var size: Int = size
|
override var size: Int = size
|
||||||
private set
|
private set
|
||||||
|
|
||||||
override fun get(index: Int): T {
|
override operator fun get(index: Int): T {
|
||||||
require(index >= 0) { "Index must be positive" }
|
require(index >= 0) { "Index must be positive" }
|
||||||
require(index < size) { "Index $index is out of circular buffer size $size" }
|
require(index < size) { "Index $index is out of circular buffer size $size" }
|
||||||
return buffer[startIndex.forward(index)] as T
|
return buffer[startIndex.forward(index)] as T
|
||||||
@ -31,15 +31,13 @@ class RingBuffer<T>(
|
|||||||
/**
|
/**
|
||||||
* Iterator could provide wrong results if buffer is changed in initialization (iteration is safe)
|
* Iterator could provide wrong results if buffer is changed in initialization (iteration is safe)
|
||||||
*/
|
*/
|
||||||
override fun iterator(): Iterator<T> = object : AbstractIterator<T>() {
|
override operator fun iterator(): Iterator<T> = object : AbstractIterator<T>() {
|
||||||
private var count = size
|
private var count = size
|
||||||
private var index = startIndex
|
private var index = startIndex
|
||||||
val copy = buffer.copy()
|
val copy = buffer.copy()
|
||||||
|
|
||||||
override fun computeNext() {
|
override fun computeNext() {
|
||||||
if (count == 0) {
|
if (count == 0) done() else {
|
||||||
done()
|
|
||||||
} else {
|
|
||||||
setNext(copy[index] as T)
|
setNext(copy[index] as T)
|
||||||
index = index.forward(1)
|
index = index.forward(1)
|
||||||
count--
|
count--
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package scientifik.kmath.chains
|
package scientifik.kmath.chains
|
||||||
|
|
||||||
import kotlinx.coroutines.runBlocking
|
import kotlinx.coroutines.runBlocking
|
||||||
import kotlin.sequences.Sequence
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represent a chain as regular iterator (uses blocking calls)
|
* Represent a chain as regular iterator (uses blocking calls)
|
||||||
@ -15,6 +14,4 @@ operator fun <R> Chain<R>.iterator(): Iterator<R> = object : Iterator<R> {
|
|||||||
/**
|
/**
|
||||||
* Represent a chain as a sequence
|
* Represent a chain as a sequence
|
||||||
*/
|
*/
|
||||||
fun <R> Chain<R>.asSequence(): Sequence<R> = object : Sequence<R> {
|
fun <R> Chain<R>.asSequence(): Sequence<R> = Sequence { this@asSequence.iterator() }
|
||||||
override fun iterator(): Iterator<R> = this@asSequence.iterator()
|
|
||||||
}
|
|
@ -18,7 +18,7 @@ class LazyNDStructure<T>(
|
|||||||
|
|
||||||
suspend fun await(index: IntArray): T = deferred(index).await()
|
suspend fun await(index: IntArray): T = deferred(index).await()
|
||||||
|
|
||||||
override fun get(index: IntArray): T = runBlocking {
|
override operator fun get(index: IntArray): T = runBlocking {
|
||||||
deferred(index).await()
|
deferred(index).await()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,10 +52,12 @@ suspend fun <T> NDStructure<T>.await(index: IntArray): T =
|
|||||||
/**
|
/**
|
||||||
* PENDING would benefit from KEEP-176
|
* PENDING would benefit from KEEP-176
|
||||||
*/
|
*/
|
||||||
fun <T, R> NDStructure<T>.mapAsyncIndexed(
|
inline fun <T, R> NDStructure<T>.mapAsyncIndexed(
|
||||||
scope: CoroutineScope,
|
scope: CoroutineScope,
|
||||||
function: suspend (T, index: IntArray) -> R
|
crossinline function: suspend (T, index: IntArray) -> R
|
||||||
): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index), index) }
|
): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index), index) }
|
||||||
|
|
||||||
fun <T, R> NDStructure<T>.mapAsync(scope: CoroutineScope, function: suspend (T) -> R): LazyNDStructure<R> =
|
inline fun <T, R> NDStructure<T>.mapAsync(
|
||||||
LazyNDStructure(scope, shape) { index -> function(get(index)) }
|
scope: CoroutineScope,
|
||||||
|
crossinline function: suspend (T) -> R
|
||||||
|
): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index)) }
|
||||||
|
@ -4,7 +4,9 @@ import scientifik.kmath.linear.GenericMatrixContext
|
|||||||
import scientifik.kmath.linear.MatrixContext
|
import scientifik.kmath.linear.MatrixContext
|
||||||
import scientifik.kmath.linear.Point
|
import scientifik.kmath.linear.Point
|
||||||
import scientifik.kmath.linear.transpose
|
import scientifik.kmath.linear.transpose
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
import scientifik.kmath.structures.Structure2D
|
import scientifik.kmath.structures.Structure2D
|
||||||
|
|
||||||
@ -42,7 +44,7 @@ inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
|||||||
val structure: Structure2D<T>
|
val structure: Structure2D<T>
|
||||||
) : DMatrix<T, R, C> {
|
) : DMatrix<T, R, C> {
|
||||||
override val shape: IntArray get() = structure.shape
|
override val shape: IntArray get() = structure.shape
|
||||||
override fun get(i: Int, j: Int): T = structure[i, j]
|
override operator fun get(i: Int, j: Int): T = structure[i, j]
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -70,9 +72,9 @@ inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) :
|
|||||||
DPoint<T, D> {
|
DPoint<T, D> {
|
||||||
override val size: Int get() = point.size
|
override val size: Int get() = point.size
|
||||||
|
|
||||||
override fun get(index: Int): T = point[index]
|
override operator fun get(index: Int): T = point[index]
|
||||||
|
|
||||||
override fun iterator(): Iterator<T> = point.iterator()
|
override operator fun iterator(): Iterator<T> = point.iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -82,12 +84,14 @@ inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) :
|
|||||||
inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixContext<T, Ri>) {
|
inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixContext<T, Ri>) {
|
||||||
|
|
||||||
inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
|
inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
|
||||||
if (rowNum != Dimension.dim<R>().toInt()) {
|
check(
|
||||||
error("Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum")
|
rowNum == Dimension.dim<R>().toInt()
|
||||||
}
|
) { "Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum" }
|
||||||
if (colNum != Dimension.dim<C>().toInt()) {
|
|
||||||
error("Column number mismatch: expected ${Dimension.dim<C>()} but found $colNum")
|
check(
|
||||||
}
|
colNum == Dimension.dim<C>().toInt()
|
||||||
|
) { "Column number mismatch: expected ${Dimension.dim<C>()} but found $colNum" }
|
||||||
|
|
||||||
return DMatrix.coerceUnsafe(this)
|
return DMatrix.coerceUnsafe(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,11 +101,12 @@ inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixCon
|
|||||||
inline fun <reified R : Dimension, reified C : Dimension> produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix<T, R, C> {
|
inline fun <reified R : Dimension, reified C : Dimension> produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix<T, R, C> {
|
||||||
val rows = Dimension.dim<R>()
|
val rows = Dimension.dim<R>()
|
||||||
val cols = Dimension.dim<C>()
|
val cols = Dimension.dim<C>()
|
||||||
return context.produce(rows.toInt(), cols.toInt(), initializer).coerce<R,C>()
|
return context.produce(rows.toInt(), cols.toInt(), initializer).coerce<R, C>()
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun <reified D : Dimension> point(noinline initializer: (Int) -> T): DPoint<T, D> {
|
inline fun <reified D : Dimension> point(noinline initializer: (Int) -> T): DPoint<T, D> {
|
||||||
val size = Dimension.dim<D>()
|
val size = Dimension.dim<D>()
|
||||||
|
|
||||||
return DPoint.coerceUnsafe(
|
return DPoint.coerceUnsafe(
|
||||||
context.point(
|
context.point(
|
||||||
size.toInt(),
|
size.toInt(),
|
||||||
@ -112,37 +117,28 @@ inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixCon
|
|||||||
|
|
||||||
inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
|
inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
|
||||||
other: DMatrix<T, C1, C2>
|
other: DMatrix<T, C1, C2>
|
||||||
): DMatrix<T, R1, C2> {
|
): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce()
|
||||||
return context.run { this@dot dot other }.coerce()
|
|
||||||
}
|
|
||||||
|
|
||||||
inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> {
|
inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> =
|
||||||
return DPoint.coerceUnsafe(context.run { this@dot dot vector })
|
DPoint.coerceUnsafe(context { this@dot dot vector })
|
||||||
}
|
|
||||||
|
|
||||||
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.times(value: T): DMatrix<T, R, C> {
|
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.times(value: T): DMatrix<T, R, C> =
|
||||||
return context.run { this@times.times(value) }.coerce()
|
context { this@times.times(value) }.coerce()
|
||||||
}
|
|
||||||
|
|
||||||
inline operator fun <reified R : Dimension, reified C : Dimension> T.times(m: DMatrix<T, R, C>): DMatrix<T, R, C> =
|
inline operator fun <reified R : Dimension, reified C : Dimension> T.times(m: DMatrix<T, R, C>): DMatrix<T, R, C> =
|
||||||
m * this
|
m * this
|
||||||
|
|
||||||
|
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.plus(other: DMatrix<T, C, R>): DMatrix<T, C, R> =
|
||||||
|
context { this@plus + other }.coerce()
|
||||||
|
|
||||||
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.plus(other: DMatrix<T, C, R>): DMatrix<T, C, R> {
|
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.minus(other: DMatrix<T, C, R>): DMatrix<T, C, R> =
|
||||||
return context.run { this@plus + other }.coerce()
|
context { this@minus + other }.coerce()
|
||||||
}
|
|
||||||
|
|
||||||
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.minus(other: DMatrix<T, C, R>): DMatrix<T, C, R> {
|
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.unaryMinus(): DMatrix<T, C, R> =
|
||||||
return context.run { this@minus + other }.coerce()
|
context { this@unaryMinus.unaryMinus() }.coerce()
|
||||||
}
|
|
||||||
|
|
||||||
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.unaryMinus(): DMatrix<T, C, R> {
|
inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> =
|
||||||
return context.run { this@unaryMinus.unaryMinus() }.coerce()
|
context { (this@transpose as Matrix<T>).transpose() }.coerce()
|
||||||
}
|
|
||||||
|
|
||||||
inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> {
|
|
||||||
return context.run { (this@transpose as Matrix<T>).transpose() }.coerce()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A square unit matrix
|
* A square unit matrix
|
||||||
@ -156,6 +152,6 @@ inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixCon
|
|||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
val real = DMatrixContext(MatrixContext.real)
|
val real: DMatrixContext<Double, RealField> = DMatrixContext(MatrixContext.real)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,11 +5,10 @@ import scientifik.kmath.dimensions.D3
|
|||||||
import scientifik.kmath.dimensions.DMatrixContext
|
import scientifik.kmath.dimensions.DMatrixContext
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
|
||||||
class DMatrixContextTest {
|
class DMatrixContextTest {
|
||||||
@Test
|
@Test
|
||||||
fun testDimensionSafeMatrix() {
|
fun testDimensionSafeMatrix() {
|
||||||
val res = DMatrixContext.real.run {
|
val res = with(DMatrixContext.real) {
|
||||||
val m = produce<D2, D2> { i, j -> (i + j).toDouble() }
|
val m = produce<D2, D2> { i, j -> (i + j).toDouble() }
|
||||||
|
|
||||||
//The dimension of `one()` is inferred from type
|
//The dimension of `one()` is inferred from type
|
||||||
@ -19,7 +18,7 @@ class DMatrixContextTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testTypeCheck() {
|
fun testTypeCheck() {
|
||||||
val res = DMatrixContext.real.run {
|
val res = with(DMatrixContext.real) {
|
||||||
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
|
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
|
||||||
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
|
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
|
||||||
|
|
||||||
|
@ -14,8 +14,8 @@ import kotlin.math.sqrt
|
|||||||
|
|
||||||
typealias RealPoint = Point<Double>
|
typealias RealPoint = Point<Double>
|
||||||
|
|
||||||
fun DoubleArray.asVector() = RealVector(this.asBuffer())
|
fun DoubleArray.asVector(): RealVector = RealVector(this.asBuffer())
|
||||||
fun List<Double>.asVector() = RealVector(this.asBuffer())
|
fun List<Double>.asVector(): RealVector = RealVector(this.asBuffer())
|
||||||
|
|
||||||
object VectorL2Norm : Norm<Point<out Number>, Double> {
|
object VectorL2Norm : Norm<Point<out Number>, Double> {
|
||||||
override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() })
|
override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() })
|
||||||
@ -32,15 +32,14 @@ inline class RealVector(private val point: Point<Double>) :
|
|||||||
|
|
||||||
override val size: Int get() = point.size
|
override val size: Int get() = point.size
|
||||||
|
|
||||||
override fun get(index: Int): Double = point[index]
|
override operator fun get(index: Int): Double = point[index]
|
||||||
|
|
||||||
override fun iterator(): Iterator<Double> = point.iterator()
|
override operator fun iterator(): Iterator<Double> = point.iterator()
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
private val spaceCache: MutableMap<Int, BufferVectorSpace<Double, RealField>> = hashMapOf()
|
||||||
|
|
||||||
private val spaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
|
inline operator fun invoke(dim: Int, initializer: (Int) -> Double): RealVector =
|
||||||
|
|
||||||
inline operator fun invoke(dim: Int, initializer: (Int) -> Double) =
|
|
||||||
RealVector(RealBuffer(dim, initializer))
|
RealVector(RealBuffer(dim, initializer))
|
||||||
|
|
||||||
operator fun invoke(vararg values: Double): RealVector = values.asVector()
|
operator fun invoke(vararg values: Double): RealVector = values.asVector()
|
||||||
@ -49,4 +48,4 @@ inline class RealVector(private val point: Point<Double>) :
|
|||||||
BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) }
|
BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,11 +3,14 @@ package scientifik.kmath.real
|
|||||||
import scientifik.kmath.linear.MatrixContext
|
import scientifik.kmath.linear.MatrixContext
|
||||||
import scientifik.kmath.linear.RealMatrixContext.elementContext
|
import scientifik.kmath.linear.RealMatrixContext.elementContext
|
||||||
import scientifik.kmath.linear.VirtualMatrix
|
import scientifik.kmath.linear.VirtualMatrix
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.operations.sum
|
import scientifik.kmath.operations.sum
|
||||||
import scientifik.kmath.structures.Buffer
|
import scientifik.kmath.structures.Buffer
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
import scientifik.kmath.structures.RealBuffer
|
import scientifik.kmath.structures.RealBuffer
|
||||||
import scientifik.kmath.structures.asIterable
|
import scientifik.kmath.structures.asIterable
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.contract
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -27,7 +30,7 @@ typealias RealMatrix = Matrix<Double>
|
|||||||
fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix =
|
fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix =
|
||||||
MatrixContext.real.produce(rowNum, colNum, initializer)
|
MatrixContext.real.produce(rowNum, colNum, initializer)
|
||||||
|
|
||||||
fun Array<DoubleArray>.toMatrix(): RealMatrix{
|
fun Array<DoubleArray>.toMatrix(): RealMatrix {
|
||||||
return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] }
|
return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] }
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,13 +120,17 @@ operator fun Matrix<Double>.minus(other: Matrix<Double>): RealMatrix =
|
|||||||
* Operations on columns
|
* Operations on columns
|
||||||
*/
|
*/
|
||||||
|
|
||||||
inline fun Matrix<Double>.appendColumn(crossinline mapper: (Buffer<Double>) -> Double) =
|
@OptIn(ExperimentalContracts::class)
|
||||||
MatrixContext.real.produce(rowNum, colNum + 1) { row, col ->
|
inline fun Matrix<Double>.appendColumn(crossinline mapper: (Buffer<Double>) -> Double): Matrix<Double> {
|
||||||
|
contract { callsInPlace(mapper) }
|
||||||
|
|
||||||
|
return MatrixContext.real.produce(rowNum, colNum + 1) { row, col ->
|
||||||
if (col < colNum)
|
if (col < colNum)
|
||||||
this[row, col]
|
this[row, col]
|
||||||
else
|
else
|
||||||
mapper(rows[row])
|
mapper(rows[row])
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fun Matrix<Double>.extractColumns(columnRange: IntRange): RealMatrix =
|
fun Matrix<Double>.extractColumns(columnRange: IntRange): RealMatrix =
|
||||||
MatrixContext.real.produce(rowNum, columnRange.count()) { row, col ->
|
MatrixContext.real.produce(rowNum, columnRange.count()) { row, col ->
|
||||||
@ -135,17 +142,15 @@ fun Matrix<Double>.extractColumn(columnIndex: Int): RealMatrix =
|
|||||||
|
|
||||||
fun Matrix<Double>.sumByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
fun Matrix<Double>.sumByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||||
val column = columns[j]
|
val column = columns[j]
|
||||||
with(elementContext) {
|
elementContext { sum(column.asIterable()) }
|
||||||
sum(column.asIterable())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fun Matrix<Double>.minByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
fun Matrix<Double>.minByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||||
columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column")
|
columns[j].asIterable().min() ?: error("Cannot produce min on empty column")
|
||||||
}
|
}
|
||||||
|
|
||||||
fun Matrix<Double>.maxByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
fun Matrix<Double>.maxByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||||
columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column")
|
columns[j].asIterable().max() ?: error("Cannot produce min on empty column")
|
||||||
}
|
}
|
||||||
|
|
||||||
fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||||
@ -156,10 +161,7 @@ fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
|||||||
* Operations processing all elements
|
* Operations processing all elements
|
||||||
*/
|
*/
|
||||||
|
|
||||||
fun Matrix<Double>.sum() = elements().map { (_, value) -> value }.sum()
|
fun Matrix<Double>.sum(): Double = elements().map { (_, value) -> value }.sum()
|
||||||
|
fun Matrix<Double>.min(): Double? = elements().map { (_, value) -> value }.min()
|
||||||
fun Matrix<Double>.min() = elements().map { (_, value) -> value }.min()
|
fun Matrix<Double>.max(): Double? = elements().map { (_, value) -> value }.max()
|
||||||
|
fun Matrix<Double>.average(): Double = elements().map { (_, value) -> value }.average()
|
||||||
fun Matrix<Double>.max() = elements().map { (_, value) -> value }.max()
|
|
||||||
|
|
||||||
fun Matrix<Double>.average() = elements().map { (_, value) -> value }.average()
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.real.RealVector
|
import scientifik.kmath.real.RealVector
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
@ -24,14 +25,10 @@ class VectorTest {
|
|||||||
fun testDot() {
|
fun testDot() {
|
||||||
val vector1 = RealVector(5) { it.toDouble() }
|
val vector1 = RealVector(5) { it.toDouble() }
|
||||||
val vector2 = RealVector(5) { 5 - it.toDouble() }
|
val vector2 = RealVector(5) { 5 - it.toDouble() }
|
||||||
|
|
||||||
val matrix1 = vector1.asMatrix()
|
val matrix1 = vector1.asMatrix()
|
||||||
val matrix2 = vector2.asMatrix().transpose()
|
val matrix2 = vector2.asMatrix().transpose()
|
||||||
val product = MatrixContext.real.run { matrix1 dot matrix2 }
|
val product = MatrixContext.real { matrix1 dot matrix2 }
|
||||||
|
|
||||||
|
|
||||||
assertEquals(5.0, product[1, 0])
|
assertEquals(5.0, product[1, 0])
|
||||||
assertEquals(6.0, product[2, 2])
|
assertEquals(6.0, product[2, 2])
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
|
||||||
|
@ -2,6 +2,10 @@ package scientifik.kmath.functions
|
|||||||
|
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
@ -13,20 +17,21 @@ inline class Polynomial<T : Any>(val coefficients: List<T>) {
|
|||||||
constructor(vararg coefficients: T) : this(coefficients.toList())
|
constructor(vararg coefficients: T) : this(coefficients.toList())
|
||||||
}
|
}
|
||||||
|
|
||||||
fun Polynomial<Double>.value() =
|
fun Polynomial<Double>.value(): Double =
|
||||||
coefficients.reduceIndexed { index: Int, acc: Double, d: Double -> acc + d.pow(index) }
|
coefficients.reduceIndexed { index: Int, acc: Double, d: Double -> acc + d.pow(index) }
|
||||||
|
|
||||||
|
fun <T : Any, C : Ring<T>> Polynomial<T>.value(ring: C, arg: T): T = ring {
|
||||||
fun <T : Any, C : Ring<T>> Polynomial<T>.value(ring: C, arg: T): T = ring.run {
|
if (coefficients.isEmpty()) return@ring zero
|
||||||
if (coefficients.isEmpty()) return@run zero
|
|
||||||
var res = coefficients.first()
|
var res = coefficients.first()
|
||||||
var powerArg = arg
|
var powerArg = arg
|
||||||
|
|
||||||
for (index in 1 until coefficients.size) {
|
for (index in 1 until coefficients.size) {
|
||||||
res += coefficients[index] * powerArg
|
res += coefficients[index] * powerArg
|
||||||
//recalculating power on each step to avoid power costs on long polynomials
|
//recalculating power on each step to avoid power costs on long polynomials
|
||||||
powerArg *= arg
|
powerArg *= arg
|
||||||
}
|
}
|
||||||
return@run res
|
|
||||||
|
res
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -34,7 +39,7 @@ fun <T : Any, C : Ring<T>> Polynomial<T>.value(ring: C, arg: T): T = ring.run {
|
|||||||
*/
|
*/
|
||||||
fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, out C, T> = object :
|
fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, out C, T> = object :
|
||||||
MathFunction<T, C, T> {
|
MathFunction<T, C, T> {
|
||||||
override fun C.invoke(arg: T): T = value(this, arg)
|
override operator fun C.invoke(arg: T): T = value(this, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -49,18 +54,16 @@ class PolynomialSpace<T : Any, C : Ring<T>>(val ring: C) : Space<Polynomial<T>>
|
|||||||
|
|
||||||
override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> {
|
override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> {
|
||||||
val dim = max(a.coefficients.size, b.coefficients.size)
|
val dim = max(a.coefficients.size, b.coefficients.size)
|
||||||
ring.run {
|
|
||||||
return Polynomial(List(dim) { index ->
|
return ring {
|
||||||
|
Polynomial(List(dim) { index ->
|
||||||
a.coefficients.getOrElse(index) { zero } + b.coefficients.getOrElse(index) { zero }
|
a.coefficients.getOrElse(index) { zero } + b.coefficients.getOrElse(index) { zero }
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Polynomial<T>, k: Number): Polynomial<T> {
|
override fun multiply(a: Polynomial<T>, k: Number): Polynomial<T> =
|
||||||
ring.run {
|
ring { Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * k }) }
|
||||||
return Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * k })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override val zero: Polynomial<T> =
|
override val zero: Polynomial<T> =
|
||||||
Polynomial(emptyList())
|
Polynomial(emptyList())
|
||||||
@ -68,6 +71,8 @@ class PolynomialSpace<T : Any, C : Ring<T>>(val ring: C) : Space<Polynomial<T>>
|
|||||||
operator fun Polynomial<T>.invoke(arg: T): T = value(ring, arg)
|
operator fun Polynomial<T>.invoke(arg: T): T = value(ring, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : Any, C : Ring<T>, R> C.polynomial(block: PolynomialSpace<T, C>.() -> R): R {
|
@OptIn(ExperimentalContracts::class)
|
||||||
return PolynomialSpace(this).run(block)
|
inline fun <T : Any, C : Ring<T>, R> C.polynomial(block: PolynomialSpace<T, C>.() -> R): R {
|
||||||
}
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
return PolynomialSpace(this).block()
|
||||||
|
}
|
||||||
|
@ -4,13 +4,13 @@ import scientifik.kmath.functions.OrderedPiecewisePolynomial
|
|||||||
import scientifik.kmath.functions.PiecewisePolynomial
|
import scientifik.kmath.functions.PiecewisePolynomial
|
||||||
import scientifik.kmath.functions.Polynomial
|
import scientifik.kmath.functions.Polynomial
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java
|
* Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java
|
||||||
*/
|
*/
|
||||||
class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T>) : PolynomialInterpolator<T> {
|
class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T>) : PolynomialInterpolator<T> {
|
||||||
|
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra {
|
||||||
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra.run {
|
|
||||||
require(points.size > 0) { "Point array should not be empty" }
|
require(points.size > 0) { "Point array should not be empty" }
|
||||||
insureSorted(points)
|
insureSorted(points)
|
||||||
|
|
||||||
@ -23,4 +23,4 @@ class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T>) : Po
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import scientifik.kmath.functions.OrderedPiecewisePolynomial
|
|||||||
import scientifik.kmath.functions.PiecewisePolynomial
|
import scientifik.kmath.functions.PiecewisePolynomial
|
||||||
import scientifik.kmath.functions.Polynomial
|
import scientifik.kmath.functions.Polynomial
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.structures.MutableBufferFactory
|
import scientifik.kmath.structures.MutableBufferFactory
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -17,7 +18,7 @@ class SplineInterpolator<T : Comparable<T>>(
|
|||||||
|
|
||||||
//TODO possibly optimize zeroed buffers
|
//TODO possibly optimize zeroed buffers
|
||||||
|
|
||||||
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra.run {
|
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra {
|
||||||
if (points.size < 3) {
|
if (points.size < 3) {
|
||||||
error("Can't use spline interpolator with less than 3 points")
|
error("Can't use spline interpolator with less than 3 points")
|
||||||
}
|
}
|
||||||
|
@ -14,9 +14,7 @@ interface XYZPointSet<X, Y, Z> : XYPointSet<X, Y> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
internal fun <T : Comparable<T>> insureSorted(points: XYPointSet<T, *>) {
|
internal fun <T : Comparable<T>> insureSorted(points: XYPointSet<T, *>) {
|
||||||
for (i in 0 until points.size - 1) {
|
for (i in 0 until points.size - 1) require(points.x[i + 1] > points.x[i]) { "Input data is not sorted at index $i" }
|
||||||
if (points.x[i + 1] <= points.x[i]) error("Input data is not sorted at index $i")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class NDStructureColumn<T>(val structure: Structure2D<T>, val column: Int) : Buffer<T> {
|
class NDStructureColumn<T>(val structure: Structure2D<T>, val column: Int) : Buffer<T> {
|
||||||
@ -26,9 +24,9 @@ class NDStructureColumn<T>(val structure: Structure2D<T>, val column: Int) : Buf
|
|||||||
|
|
||||||
override val size: Int get() = structure.rowNum
|
override val size: Int get() = structure.rowNum
|
||||||
|
|
||||||
override fun get(index: Int): T = structure[index, column]
|
override operator fun get(index: Int): T = structure[index, column]
|
||||||
|
|
||||||
override fun iterator(): Iterator<T> = sequence {
|
override operator fun iterator(): Iterator<T> = sequence {
|
||||||
repeat(size) {
|
repeat(size) {
|
||||||
yield(get(it))
|
yield(get(it))
|
||||||
}
|
}
|
||||||
|
@ -9,25 +9,21 @@ import kotlin.math.sqrt
|
|||||||
interface Vector2D : Point<Double>, Vector, SpaceElement<Vector2D, Vector2D, Euclidean2DSpace> {
|
interface Vector2D : Point<Double>, Vector, SpaceElement<Vector2D, Vector2D, Euclidean2DSpace> {
|
||||||
val x: Double
|
val x: Double
|
||||||
val y: Double
|
val y: Double
|
||||||
|
override val context: Euclidean2DSpace get() = Euclidean2DSpace
|
||||||
override val size: Int get() = 2
|
override val size: Int get() = 2
|
||||||
|
|
||||||
override fun get(index: Int): Double = when (index) {
|
override operator fun get(index: Int): Double = when (index) {
|
||||||
1 -> x
|
1 -> x
|
||||||
2 -> y
|
2 -> y
|
||||||
else -> error("Accessing outside of point bounds")
|
else -> error("Accessing outside of point bounds")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): Iterator<Double> = listOf(x, y).iterator()
|
override operator fun iterator(): Iterator<Double> = listOf(x, y).iterator()
|
||||||
|
|
||||||
override val context: Euclidean2DSpace get() = Euclidean2DSpace
|
|
||||||
|
|
||||||
override fun unwrap(): Vector2D = this
|
override fun unwrap(): Vector2D = this
|
||||||
|
|
||||||
override fun Vector2D.wrap(): Vector2D = this
|
override fun Vector2D.wrap(): Vector2D = this
|
||||||
}
|
}
|
||||||
|
|
||||||
val Vector2D.r: Double get() = Euclidean2DSpace.run { sqrt(norm()) }
|
val Vector2D.r: Double get() = Euclidean2DSpace { sqrt(norm()) }
|
||||||
|
|
||||||
@Suppress("FunctionName")
|
@Suppress("FunctionName")
|
||||||
fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y)
|
fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y)
|
||||||
|
@ -2,6 +2,7 @@ package scientifik.kmath.geometry
|
|||||||
|
|
||||||
import scientifik.kmath.linear.Point
|
import scientifik.kmath.linear.Point
|
||||||
import scientifik.kmath.operations.SpaceElement
|
import scientifik.kmath.operations.SpaceElement
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
|
||||||
@ -9,19 +10,17 @@ interface Vector3D : Point<Double>, Vector, SpaceElement<Vector3D, Vector3D, Euc
|
|||||||
val x: Double
|
val x: Double
|
||||||
val y: Double
|
val y: Double
|
||||||
val z: Double
|
val z: Double
|
||||||
|
override val context: Euclidean3DSpace get() = Euclidean3DSpace
|
||||||
override val size: Int get() = 3
|
override val size: Int get() = 3
|
||||||
|
|
||||||
override fun get(index: Int): Double = when (index) {
|
override operator fun get(index: Int): Double = when (index) {
|
||||||
1 -> x
|
1 -> x
|
||||||
2 -> y
|
2 -> y
|
||||||
3 -> z
|
3 -> z
|
||||||
else -> error("Accessing outside of point bounds")
|
else -> error("Accessing outside of point bounds")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): Iterator<Double> = listOf(x, y, z).iterator()
|
override operator fun iterator(): Iterator<Double> = listOf(x, y, z).iterator()
|
||||||
|
|
||||||
override val context: Euclidean3DSpace get() = Euclidean3DSpace
|
|
||||||
|
|
||||||
override fun unwrap(): Vector3D = this
|
override fun unwrap(): Vector3D = this
|
||||||
|
|
||||||
@ -31,7 +30,7 @@ interface Vector3D : Point<Double>, Vector, SpaceElement<Vector3D, Vector3D, Euc
|
|||||||
@Suppress("FunctionName")
|
@Suppress("FunctionName")
|
||||||
fun Vector3D(x: Double, y: Double, z: Double): Vector3D = Vector3DImpl(x, y, z)
|
fun Vector3D(x: Double, y: Double, z: Double): Vector3D = Vector3DImpl(x, y, z)
|
||||||
|
|
||||||
val Vector3D.r: Double get() = Euclidean3DSpace.run { sqrt(norm()) }
|
val Vector3D.r: Double get() = Euclidean3DSpace { sqrt(norm()) }
|
||||||
|
|
||||||
private data class Vector3DImpl(
|
private data class Vector3DImpl(
|
||||||
override val x: Double,
|
override val x: Double,
|
||||||
@ -54,4 +53,4 @@ object Euclidean3DSpace : GeometrySpace<Vector3D> {
|
|||||||
|
|
||||||
override fun Vector3D.dot(other: Vector3D): Double =
|
override fun Vector3D.dot(other: Vector3D): Double =
|
||||||
x * other.x + y * other.y + z * other.z
|
x * other.x + y * other.y + z * other.z
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,9 @@ import scientifik.kmath.domains.Domain
|
|||||||
import scientifik.kmath.linear.Point
|
import scientifik.kmath.linear.Point
|
||||||
import scientifik.kmath.structures.ArrayBuffer
|
import scientifik.kmath.structures.ArrayBuffer
|
||||||
import scientifik.kmath.structures.RealBuffer
|
import scientifik.kmath.structures.RealBuffer
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The bin in the histogram. The histogram is by definition always done in the real space
|
* The bin in the histogram. The histogram is by definition always done in the real space
|
||||||
@ -37,20 +40,20 @@ interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
|
|||||||
*/
|
*/
|
||||||
fun putWithWeight(point: Point<out T>, weight: Double)
|
fun putWithWeight(point: Point<out T>, weight: Double)
|
||||||
|
|
||||||
fun put(point: Point<out T>) = putWithWeight(point, 1.0)
|
fun put(point: Point<out T>): Unit = putWithWeight(point, 1.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point))
|
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T): Unit = put(ArrayBuffer(point))
|
||||||
|
|
||||||
fun MutableHistogram<Double, *>.put(vararg point: Number) =
|
fun MutableHistogram<Double, *>.put(vararg point: Number): Unit =
|
||||||
put(RealBuffer(point.map { it.toDouble() }.toDoubleArray()))
|
put(RealBuffer(point.map { it.toDouble() }.toDoubleArray()))
|
||||||
|
|
||||||
fun MutableHistogram<Double, *>.put(vararg point: Double) = put(RealBuffer(point))
|
fun MutableHistogram<Double, *>.put(vararg point: Double): Unit = put(RealBuffer(point))
|
||||||
|
|
||||||
fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) }
|
fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>): Unit = sequence.forEach { put(it) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Pass a sequence builder into histogram
|
* Pass a sequence builder into histogram
|
||||||
*/
|
*/
|
||||||
fun <T : Any> MutableHistogram<T, *>.fill(buider: suspend SequenceScope<Point<T>>.() -> Unit) =
|
fun <T : Any> MutableHistogram<T, *>.fill(block: suspend SequenceScope<Point<T>>.() -> Unit): Unit =
|
||||||
fill(sequence(buider).asIterable())
|
fill(sequence(block).asIterable())
|
||||||
|
@ -2,6 +2,7 @@ package scientifik.kmath.histogram
|
|||||||
|
|
||||||
import scientifik.kmath.linear.Point
|
import scientifik.kmath.linear.Point
|
||||||
import scientifik.kmath.operations.SpaceOperations
|
import scientifik.kmath.operations.SpaceOperations
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.real.asVector
|
import scientifik.kmath.real.asVector
|
||||||
import scientifik.kmath.structures.*
|
import scientifik.kmath.structures.*
|
||||||
import kotlin.math.floor
|
import kotlin.math.floor
|
||||||
@ -9,19 +10,16 @@ import kotlin.math.floor
|
|||||||
|
|
||||||
data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val center: Point<T>, val sizes: Point<T>) {
|
data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val center: Point<T>, val sizes: Point<T>) {
|
||||||
fun contains(vector: Point<out T>): Boolean {
|
fun contains(vector: Point<out T>): Boolean {
|
||||||
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
|
require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" }
|
||||||
val upper = space.run { center + sizes / 2.0 }
|
val upper = space { center + sizes / 2.0 }
|
||||||
val lower = space.run { center - sizes / 2.0 }
|
val lower = space { center - sizes / 2.0 }
|
||||||
return vector.asSequence().mapIndexed { i, value ->
|
return vector.asSequence().mapIndexed { i, value -> value in lower[i]..upper[i] }.all { it }
|
||||||
value in lower[i]..upper[i]
|
|
||||||
}.all { it }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> {
|
class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> {
|
||||||
|
override operator fun contains(point: Point<T>): Boolean = def.contains(point)
|
||||||
override fun contains(point: Point<T>): Boolean = def.contains(point)
|
|
||||||
|
|
||||||
override val dimension: Int
|
override val dimension: Int
|
||||||
get() = def.center.size
|
get() = def.center.size
|
||||||
@ -39,47 +37,34 @@ class RealHistogram(
|
|||||||
private val upper: Buffer<Double>,
|
private val upper: Buffer<Double>,
|
||||||
private val binNums: IntArray = IntArray(lower.size) { 20 }
|
private val binNums: IntArray = IntArray(lower.size) { 20 }
|
||||||
) : MutableHistogram<Double, MultivariateBin<Double>> {
|
) : MutableHistogram<Double, MultivariateBin<Double>> {
|
||||||
|
|
||||||
|
|
||||||
private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 })
|
private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 })
|
||||||
|
|
||||||
private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() }
|
private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() }
|
||||||
|
|
||||||
private val weights: NDStructure<DoubleCounter> = NDStructure.auto(strides) { DoubleCounter() }
|
private val weights: NDStructure<DoubleCounter> = NDStructure.auto(strides) { DoubleCounter() }
|
||||||
|
|
||||||
override val dimension: Int get() = lower.size
|
override val dimension: Int get() = lower.size
|
||||||
|
|
||||||
|
|
||||||
private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
|
private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
|
||||||
|
|
||||||
init {
|
init {
|
||||||
// argument checks
|
// argument checks
|
||||||
if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.")
|
require(lower.size == upper.size) { "Dimension mismatch in histogram lower and upper limits." }
|
||||||
if (lower.size != binNums.size) error("Dimension mismatch in bin count.")
|
require(lower.size == binNums.size) { "Dimension mismatch in bin count." }
|
||||||
if ((0 until dimension).any { upper[it] - lower[it] < 0 }) error("Range for one of axis is not strictly positive")
|
require(!(0 until dimension).any { upper[it] - lower[it] < 0 }) { "Range for one of axis is not strictly positive" }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get internal [NDStructure] bin index for given axis
|
* Get internal [NDStructure] bin index for given axis
|
||||||
*/
|
*/
|
||||||
private fun getIndex(axis: Int, value: Double): Int {
|
private fun getIndex(axis: Int, value: Double): Int = when {
|
||||||
return when {
|
value >= upper[axis] -> binNums[axis] + 1 // overflow
|
||||||
value >= upper[axis] -> binNums[axis] + 1 // overflow
|
value < lower[axis] -> 0 // underflow
|
||||||
value < lower[axis] -> 0 // underflow
|
else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1
|
||||||
else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun getIndex(point: Buffer<out Double>): IntArray = IntArray(dimension) { getIndex(it, point[it]) }
|
private fun getIndex(point: Buffer<out Double>): IntArray = IntArray(dimension) { getIndex(it, point[it]) }
|
||||||
|
|
||||||
private fun getValue(index: IntArray): Long {
|
private fun getValue(index: IntArray): Long = values[index].sum()
|
||||||
return values[index].sum()
|
|
||||||
}
|
|
||||||
|
|
||||||
fun getValue(point: Buffer<out Double>): Long {
|
fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point))
|
||||||
return getValue(getIndex(point))
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun getDef(index: IntArray): BinDef<Double> {
|
private fun getDef(index: IntArray): BinDef<Double> {
|
||||||
val center = index.mapIndexed { axis, i ->
|
val center = index.mapIndexed { axis, i ->
|
||||||
@ -89,14 +74,13 @@ class RealHistogram(
|
|||||||
else -> lower[axis] + (i.toDouble() - 0.5) * binSize[axis]
|
else -> lower[axis] + (i.toDouble() - 0.5) * binSize[axis]
|
||||||
}
|
}
|
||||||
}.asBuffer()
|
}.asBuffer()
|
||||||
|
|
||||||
return BinDef(RealBufferFieldOperations, center, binSize)
|
return BinDef(RealBufferFieldOperations, center, binSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun getDef(point: Buffer<out Double>): BinDef<Double> {
|
fun getDef(point: Buffer<out Double>): BinDef<Double> = getDef(getIndex(point))
|
||||||
return getDef(getIndex(point))
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun get(point: Buffer<out Double>): MultivariateBin<Double>? {
|
override operator fun get(point: Buffer<out Double>): MultivariateBin<Double>? {
|
||||||
val index = getIndex(point)
|
val index = getIndex(point)
|
||||||
return MultivariateBin(getDef(index), getValue(index))
|
return MultivariateBin(getDef(index), getValue(index))
|
||||||
}
|
}
|
||||||
@ -112,26 +96,21 @@ class RealHistogram(
|
|||||||
weights[index].add(weight)
|
weights[index].add(weight)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): Iterator<MultivariateBin<Double>> = weights.elements().map { (index, value) ->
|
override operator fun iterator(): Iterator<MultivariateBin<Double>> = weights.elements().map { (index, value) ->
|
||||||
MultivariateBin(getDef(index), value.sum())
|
MultivariateBin(getDef(index), value.sum())
|
||||||
}.iterator()
|
}.iterator()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert this histogram into NDStructure containing bin values but not bin descriptions
|
* Convert this histogram into NDStructure containing bin values but not bin descriptions
|
||||||
*/
|
*/
|
||||||
fun values(): NDStructure<Number> {
|
fun values(): NDStructure<Number> = NDStructure.auto(values.shape) { values[it].sum() }
|
||||||
return NDStructure.auto(values.shape) { values[it].sum() }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sum of weights
|
* Sum of weights
|
||||||
*/
|
*/
|
||||||
fun weights():NDStructure<Double>{
|
fun weights(): NDStructure<Double> = NDStructure.auto(weights.shape) { weights[it].sum() }
|
||||||
return NDStructure.auto(weights.shape) { weights[it].sum() }
|
|
||||||
}
|
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Use it like
|
* Use it like
|
||||||
* ```
|
* ```
|
||||||
@ -141,12 +120,10 @@ class RealHistogram(
|
|||||||
*)
|
*)
|
||||||
*```
|
*```
|
||||||
*/
|
*/
|
||||||
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): RealHistogram {
|
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): RealHistogram = RealHistogram(
|
||||||
return RealHistogram(
|
ranges.map { it.start }.asVector(),
|
||||||
ranges.map { it.start }.asVector(),
|
ranges.map { it.endInclusive }.asVector()
|
||||||
ranges.map { it.endInclusive }.asVector()
|
)
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Use it like
|
* Use it like
|
||||||
@ -157,13 +134,10 @@ class RealHistogram(
|
|||||||
*)
|
*)
|
||||||
*```
|
*```
|
||||||
*/
|
*/
|
||||||
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): RealHistogram {
|
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): RealHistogram = RealHistogram(
|
||||||
return RealHistogram(
|
ListBuffer(ranges.map { it.first.start }),
|
||||||
ListBuffer(ranges.map { it.first.start }),
|
ListBuffer(ranges.map { it.first.endInclusive }),
|
||||||
ListBuffer(ranges.map { it.first.endInclusive }),
|
ranges.map { it.second }.toIntArray()
|
||||||
ranges.map { it.second }.toIntArray()
|
)
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
|
||||||
|
@ -46,11 +46,11 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
|
|||||||
synchronized(this) { bins.put(it.position, it) }
|
synchronized(this) { bins.put(it.position, it) }
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun get(point: Buffer<out Double>): UnivariateBin? = get(point[0])
|
override operator fun get(point: Buffer<out Double>): UnivariateBin? = get(point[0])
|
||||||
|
|
||||||
override val dimension: Int get() = 1
|
override val dimension: Int get() = 1
|
||||||
|
|
||||||
override fun iterator(): Iterator<UnivariateBin> = bins.values.iterator()
|
override operator fun iterator(): Iterator<UnivariateBin> = bins.values.iterator()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Thread safe put operation
|
* Thread safe put operation
|
||||||
@ -65,15 +65,14 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
|
|||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram {
|
fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram = UnivariateHistogram { value ->
|
||||||
return UnivariateHistogram { value ->
|
val center = start + binSize * floor((value - start) / binSize + 0.5)
|
||||||
val center = start + binSize * floor((value - start) / binSize + 0.5)
|
UnivariateBin(center, binSize)
|
||||||
UnivariateBin(center, binSize)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fun custom(borders: DoubleArray): UnivariateHistogram {
|
fun custom(borders: DoubleArray): UnivariateHistogram {
|
||||||
val sorted = borders.sortedArray()
|
val sorted = borders.sortedArray()
|
||||||
|
|
||||||
return UnivariateHistogram { value ->
|
return UnivariateHistogram { value ->
|
||||||
when {
|
when {
|
||||||
value < sorted.first() -> UnivariateBin(
|
value < sorted.first() -> UnivariateBin(
|
||||||
|
@ -3,16 +3,16 @@ package scientifik.kmath.linear
|
|||||||
import koma.extensions.fill
|
import koma.extensions.fill
|
||||||
import koma.matrix.MatrixFactory
|
import koma.matrix.MatrixFactory
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
import scientifik.kmath.structures.NDStructure
|
import scientifik.kmath.structures.NDStructure
|
||||||
|
|
||||||
class KomaMatrixContext<T : Any>(
|
class KomaMatrixContext<T : Any>(
|
||||||
private val factory: MatrixFactory<koma.matrix.Matrix<T>>,
|
private val factory: MatrixFactory<koma.matrix.Matrix<T>>,
|
||||||
private val space: Space<T>
|
private val space: Space<T>
|
||||||
) :
|
) : MatrixContext<T> {
|
||||||
MatrixContext<T> {
|
|
||||||
|
|
||||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T) =
|
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): KomaMatrix<T> =
|
||||||
KomaMatrix(factory.zeros(rows, columns).fill(initializer))
|
KomaMatrix(factory.zeros(rows, columns).fill(initializer))
|
||||||
|
|
||||||
fun Matrix<T>.toKoma(): KomaMatrix<T> = if (this is KomaMatrix) {
|
fun Matrix<T>.toKoma(): KomaMatrix<T> = if (this is KomaMatrix) {
|
||||||
@ -28,31 +28,28 @@ class KomaMatrixContext<T : Any>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override fun Matrix<T>.dot(other: Matrix<T>) =
|
override fun Matrix<T>.dot(other: Matrix<T>): KomaMatrix<T> =
|
||||||
KomaMatrix(this.toKoma().origin * other.toKoma().origin)
|
KomaMatrix(toKoma().origin * other.toKoma().origin)
|
||||||
|
|
||||||
override fun Matrix<T>.dot(vector: Point<T>) =
|
override fun Matrix<T>.dot(vector: Point<T>): KomaVector<T> =
|
||||||
KomaVector(this.toKoma().origin * vector.toKoma().origin)
|
KomaVector(toKoma().origin * vector.toKoma().origin)
|
||||||
|
|
||||||
override fun Matrix<T>.unaryMinus() =
|
override operator fun Matrix<T>.unaryMinus(): KomaMatrix<T> =
|
||||||
KomaMatrix(this.toKoma().origin.unaryMinus())
|
KomaMatrix(toKoma().origin.unaryMinus())
|
||||||
|
|
||||||
override fun add(a: Matrix<T>, b: Matrix<T>) =
|
override fun add(a: Matrix<T>, b: Matrix<T>): KomaMatrix<T> =
|
||||||
KomaMatrix(a.toKoma().origin + b.toKoma().origin)
|
KomaMatrix(a.toKoma().origin + b.toKoma().origin)
|
||||||
|
|
||||||
override fun Matrix<T>.minus(b: Matrix<T>) =
|
override operator fun Matrix<T>.minus(b: Matrix<T>): KomaMatrix<T> =
|
||||||
KomaMatrix(this.toKoma().origin - b.toKoma().origin)
|
KomaMatrix(toKoma().origin - b.toKoma().origin)
|
||||||
|
|
||||||
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
||||||
produce(a.rowNum, a.colNum) { i, j -> space.run { a[i, j] * k } }
|
produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } }
|
||||||
|
|
||||||
override fun Matrix<T>.times(value: T) =
|
override operator fun Matrix<T>.times(value: T): KomaMatrix<T> =
|
||||||
KomaMatrix(this.toKoma().origin * value)
|
KomaMatrix(toKoma().origin * value)
|
||||||
|
|
||||||
companion object {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Matrix<T>) =
|
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Matrix<T>) =
|
||||||
@ -70,10 +67,11 @@ class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<Matri
|
|||||||
|
|
||||||
override val shape: IntArray get() = intArrayOf(origin.numRows(), origin.numCols())
|
override val shape: IntArray get() = intArrayOf(origin.numRows(), origin.numCols())
|
||||||
|
|
||||||
override val features: Set<MatrixFeature> = features ?: setOf(
|
override val features: Set<MatrixFeature> = features ?: hashSetOf(
|
||||||
object : DeterminantFeature<T> {
|
object : DeterminantFeature<T> {
|
||||||
override val determinant: T get() = origin.det()
|
override val determinant: T get() = origin.det()
|
||||||
},
|
},
|
||||||
|
|
||||||
object : LUPDecompositionFeature<T> {
|
object : LUPDecompositionFeature<T> {
|
||||||
private val lup by lazy { origin.LU() }
|
private val lup by lazy { origin.LU() }
|
||||||
override val l: FeaturedMatrix<T> get() = KomaMatrix(lup.second)
|
override val l: FeaturedMatrix<T> get() = KomaMatrix(lup.second)
|
||||||
@ -85,7 +83,7 @@ class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<Matri
|
|||||||
override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> =
|
override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> =
|
||||||
KomaMatrix(this.origin, this.features + features)
|
KomaMatrix(this.origin, this.features + features)
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
|
override operator fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||||
@ -101,14 +99,12 @@ class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<Matri
|
|||||||
}
|
}
|
||||||
|
|
||||||
class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> {
|
class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> {
|
||||||
init {
|
|
||||||
if (origin.numCols() != 1) error("Only single column matrices are allowed")
|
|
||||||
}
|
|
||||||
|
|
||||||
override val size: Int get() = origin.numRows()
|
override val size: Int get() = origin.numRows()
|
||||||
|
|
||||||
override fun get(index: Int): T = origin.getGeneric(index)
|
init {
|
||||||
|
require(origin.numCols() == 1) { error("Only single column matrices are allowed") }
|
||||||
|
}
|
||||||
|
|
||||||
override fun iterator(): Iterator<T> = origin.toIterable().iterator()
|
override operator fun get(index: Int): T = origin.getGeneric(index)
|
||||||
|
override operator fun iterator(): Iterator<T> = origin.toIterable().iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,5 +1,9 @@
|
|||||||
package scientifik.memory
|
package scientifik.memory
|
||||||
|
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents a display of certain memory structure.
|
* Represents a display of certain memory structure.
|
||||||
*/
|
*/
|
||||||
@ -80,7 +84,9 @@ interface MemoryReader {
|
|||||||
/**
|
/**
|
||||||
* Uses the memory for read then releases the reader.
|
* Uses the memory for read then releases the reader.
|
||||||
*/
|
*/
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
inline fun Memory.read(block: MemoryReader.() -> Unit) {
|
inline fun Memory.read(block: MemoryReader.() -> Unit) {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
reader().apply(block).release()
|
reader().apply(block).release()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -132,7 +138,9 @@ interface MemoryWriter {
|
|||||||
/**
|
/**
|
||||||
* Uses the memory for write then releases the writer.
|
* Uses the memory for write then releases the writer.
|
||||||
*/
|
*/
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
inline fun Memory.write(block: MemoryWriter.() -> Unit) {
|
inline fun Memory.write(block: MemoryWriter.() -> Unit) {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
writer().apply(block).release()
|
writer().apply(block).release()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,11 +38,7 @@ fun <T : Any> MemoryWriter.write(spec: MemorySpec<T>, offset: Int, value: T): Un
|
|||||||
* Reads array of [size] objects mapped by [spec] at certain [offset].
|
* Reads array of [size] objects mapped by [spec] at certain [offset].
|
||||||
*/
|
*/
|
||||||
inline fun <reified T : Any> MemoryReader.readArray(spec: MemorySpec<T>, offset: Int, size: Int): Array<T> =
|
inline fun <reified T : Any> MemoryReader.readArray(spec: MemorySpec<T>, offset: Int, size: Int): Array<T> =
|
||||||
Array(size) { i ->
|
Array(size) { i -> with(spec) { read(offset + i * objectSize) } }
|
||||||
spec.run {
|
|
||||||
read(offset + i * objectSize)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Writes [array] of objects mapped by [spec] at certain [offset].
|
* Writes [array] of objects mapped by [spec] at certain [offset].
|
||||||
|
@ -1,12 +1,17 @@
|
|||||||
package scientifik.memory
|
package scientifik.memory
|
||||||
|
|
||||||
|
import java.io.IOException
|
||||||
import java.nio.ByteBuffer
|
import java.nio.ByteBuffer
|
||||||
import java.nio.channels.FileChannel
|
import java.nio.channels.FileChannel
|
||||||
import java.nio.file.Files
|
import java.nio.file.Files
|
||||||
import java.nio.file.Path
|
import java.nio.file.Path
|
||||||
import java.nio.file.StandardOpenOption
|
import java.nio.file.StandardOpenOption
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
private class ByteBufferMemory(
|
@PublishedApi
|
||||||
|
internal class ByteBufferMemory(
|
||||||
val buffer: ByteBuffer,
|
val buffer: ByteBuffer,
|
||||||
val startOffset: Int = 0,
|
val startOffset: Int = 0,
|
||||||
override val size: Int = buffer.limit()
|
override val size: Int = buffer.limit()
|
||||||
@ -112,7 +117,12 @@ fun ByteBuffer.asMemory(startOffset: Int = 0, size: Int = limit()): Memory =
|
|||||||
/**
|
/**
|
||||||
* Uses direct memory-mapped buffer from file to read something and close it afterwards.
|
* Uses direct memory-mapped buffer from file to read something and close it afterwards.
|
||||||
*/
|
*/
|
||||||
fun <R> Path.readAsMemory(position: Long = 0, size: Long = Files.size(this), block: Memory.() -> R): R =
|
@OptIn(ExperimentalContracts::class)
|
||||||
FileChannel.open(this, StandardOpenOption.READ).use {
|
@Throws(IOException::class)
|
||||||
ByteBufferMemory(it.map(FileChannel.MapMode.READ_ONLY, position, size)).block()
|
inline fun <R> Path.readAsMemory(position: Long = 0, size: Long = Files.size(this), block: Memory.() -> R): R {
|
||||||
}
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
|
||||||
|
return FileChannel
|
||||||
|
.open(this, StandardOpenOption.READ)
|
||||||
|
.use { ByteBufferMemory(it.map(FileChannel.MapMode.READ_ONLY, position, size)).block() }
|
||||||
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package scientifik.kmath.prob
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
import scientifik.kmath.chains.Chain
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A possibly stateful chain producing random values.
|
* A possibly stateful chain producing random values.
|
||||||
@ -11,4 +12,4 @@ class RandomChain<out R>(val generator: RandomGenerator, private val gen: suspen
|
|||||||
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
|
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
|
fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
|
||||||
|
@ -5,6 +5,7 @@ import scientifik.kmath.chains.ConstantChain
|
|||||||
import scientifik.kmath.chains.map
|
import scientifik.kmath.chains.map
|
||||||
import scientifik.kmath.chains.zip
|
import scientifik.kmath.chains.zip
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
|
|
||||||
class BasicSampler<T : Any>(val chainBuilder: (RandomGenerator) -> Chain<T>) : Sampler<T> {
|
class BasicSampler<T : Any>(val chainBuilder: (RandomGenerator) -> Chain<T>) : Sampler<T> {
|
||||||
override fun sample(generator: RandomGenerator): Chain<T> = chainBuilder(generator)
|
override fun sample(generator: RandomGenerator): Chain<T> = chainBuilder(generator)
|
||||||
@ -22,10 +23,10 @@ class SamplerSpace<T : Any>(val space: Space<T>) : Space<Sampler<T>> {
|
|||||||
override val zero: Sampler<T> = ConstantSampler(space.zero)
|
override val zero: Sampler<T> = ConstantSampler(space.zero)
|
||||||
|
|
||||||
override fun add(a: Sampler<T>, b: Sampler<T>): Sampler<T> = BasicSampler { generator ->
|
override fun add(a: Sampler<T>, b: Sampler<T>): Sampler<T> = BasicSampler { generator ->
|
||||||
a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space.run { aValue + bValue } }
|
a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space { aValue + bValue } }
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Sampler<T>, k: Number): Sampler<T> = BasicSampler { generator ->
|
override fun multiply(a: Sampler<T>, k: Number): Sampler<T> = BasicSampler { generator ->
|
||||||
a.sample(generator).map { space.run { it * k.toDouble() } }
|
a.sample(generator).map { space { it * k.toDouble() } }
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -29,8 +29,10 @@ interface Statistic<T, R> {
|
|||||||
interface ComposableStatistic<T, I, R> : Statistic<T, R> {
|
interface ComposableStatistic<T, I, R> : Statistic<T, R> {
|
||||||
//compute statistic on a single block
|
//compute statistic on a single block
|
||||||
suspend fun computeIntermediate(data: Buffer<T>): I
|
suspend fun computeIntermediate(data: Buffer<T>): I
|
||||||
|
|
||||||
//Compose two blocks
|
//Compose two blocks
|
||||||
suspend fun composeIntermediate(first: I, second: I): I
|
suspend fun composeIntermediate(first: I, second: I): I
|
||||||
|
|
||||||
//Transform block to result
|
//Transform block to result
|
||||||
suspend fun toResult(intermediate: I): R
|
suspend fun toResult(intermediate: I): R
|
||||||
|
|
||||||
@ -58,26 +60,26 @@ private fun <T, I, R> ComposableStatistic<T, I, R>.flowIntermediate(
|
|||||||
fun <T, I, R> ComposableStatistic<T, I, R>.flow(
|
fun <T, I, R> ComposableStatistic<T, I, R>.flow(
|
||||||
flow: Flow<Buffer<T>>,
|
flow: Flow<Buffer<T>>,
|
||||||
dispatcher: CoroutineDispatcher = Dispatchers.Default
|
dispatcher: CoroutineDispatcher = Dispatchers.Default
|
||||||
): Flow<R> = flowIntermediate(flow,dispatcher).map(::toResult)
|
): Flow<R> = flowIntermediate(flow, dispatcher).map(::toResult)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Arithmetic mean
|
* Arithmetic mean
|
||||||
*/
|
*/
|
||||||
class Mean<T>(val space: Space<T>) : ComposableStatistic<T, Pair<T, Int>, T> {
|
class Mean<T>(val space: Space<T>) : ComposableStatistic<T, Pair<T, Int>, T> {
|
||||||
override suspend fun computeIntermediate(data: Buffer<T>): Pair<T, Int> =
|
override suspend fun computeIntermediate(data: Buffer<T>): Pair<T, Int> =
|
||||||
space.run { sum(data.asIterable()) } to data.size
|
space { sum(data.asIterable()) } to data.size
|
||||||
|
|
||||||
override suspend fun composeIntermediate(first: Pair<T, Int>, second: Pair<T, Int>): Pair<T, Int> =
|
override suspend fun composeIntermediate(first: Pair<T, Int>, second: Pair<T, Int>): Pair<T, Int> =
|
||||||
space.run { first.first + second.first } to (first.second + second.second)
|
space { first.first + second.first } to (first.second + second.second)
|
||||||
|
|
||||||
override suspend fun toResult(intermediate: Pair<T, Int>): T =
|
override suspend fun toResult(intermediate: Pair<T, Int>): T =
|
||||||
space.run { intermediate.first / intermediate.second }
|
space { intermediate.first / intermediate.second }
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
//TODO replace with optimized version which respects overflow
|
//TODO replace with optimized version which respects overflow
|
||||||
val real = Mean(RealField)
|
val real: Mean<Double> = Mean(RealField)
|
||||||
val int = Mean(IntRing)
|
val int: Mean<Int> = Mean(IntRing)
|
||||||
val long = Mean(LongRing)
|
val long: Mean<Long> = Mean(LongRing)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -85,11 +87,10 @@ class Mean<T>(val space: Space<T>) : ComposableStatistic<T, Pair<T, Int>, T> {
|
|||||||
* Non-composable median
|
* Non-composable median
|
||||||
*/
|
*/
|
||||||
class Median<T>(private val comparator: Comparator<T>) : Statistic<T, T> {
|
class Median<T>(private val comparator: Comparator<T>) : Statistic<T, T> {
|
||||||
override suspend fun invoke(data: Buffer<T>): T {
|
override suspend fun invoke(data: Buffer<T>): T =
|
||||||
return data.asSequence().sortedWith(comparator).toList()[data.size / 2] //TODO check if this is correct
|
data.asSequence().sortedWith(comparator).toList()[data.size / 2] //TODO check if this is correct
|
||||||
}
|
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
val real = Median(Comparator { a: Double, b: Double -> a.compareTo(b) })
|
val real: Median<Double> = Median(Comparator { a: Double, b: Double -> a.compareTo(b) })
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -16,5 +16,5 @@ inline class ViktorBuffer(val flatArray: F64FlatArray) : MutableBuffer<Double> {
|
|||||||
return ViktorBuffer(flatArray.copy().flatten())
|
return ViktorBuffer(flatArray.copy().flatten())
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): Iterator<Double> = flatArray.data.iterator()
|
override operator fun iterator(): Iterator<Double> = flatArray.data.iterator()
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user