Code style update + RealNDField tweaks

This commit is contained in:
Alexander Nozik 2019-01-04 18:12:28 +03:00
parent 53d752dee8
commit c0a43c1bd1
40 changed files with 665 additions and 341 deletions

View File

@ -19,7 +19,7 @@ open class ArrayBenchmark {
arrayBuffer = IntBuffer.wrap(array) arrayBuffer = IntBuffer.wrap(array)
nativeBuffer = IntBuffer.allocate(10000) nativeBuffer = IntBuffer.allocate(10000)
for (i in 0 until 10000) { for (i in 0 until 10000) {
nativeBuffer.put(i,i) nativeBuffer.put(i, i)
} }
} }

View File

@ -22,12 +22,12 @@ open class BufferBenchmark {
@Benchmark @Benchmark
fun complexBufferReadWrite() { fun complexBufferReadWrite() {
val buffer = Complex.createBuffer(size/2) val buffer = Complex.createBuffer(size / 2)
(0 until size/2).forEach { (0 until size / 2).forEach {
buffer[it] = Complex(it.toDouble(), -it.toDouble()) buffer[it] = Complex(it.toDouble(), -it.toDouble())
} }
(0 until size/2).forEach { (0 until size / 2).forEach {
buffer[it] buffer[it]
} }
} }

View File

@ -5,26 +5,50 @@ import kotlin.system.measureTimeMillis
fun main(args: Array<String>) { fun main(args: Array<String>) {
val dim = 1000 val dim = 1000
val n = 1000 val n = 10000
val genericField = NDField.generic(intArrayOf(dim, dim), DoubleField) val bufferedField = NDField.buffered(intArrayOf(dim, dim), DoubleField)
val doubleField = NDField.inline(intArrayOf(dim, dim), DoubleField)
val specializedField = NDField.real(intArrayOf(dim, dim)) val specializedField = NDField.real(intArrayOf(dim, dim))
val genericField = NDField.generic(intArrayOf(dim, dim), DoubleField)
// val action: NDField<Double, DoubleField, NDStructure<Double>>.() -> Unit = {
// var res = one
// repeat(n) {
// res += 1.0
// }
// }
val doubleTime = measureTimeMillis { val doubleTime = measureTimeMillis {
var res = doubleField.produce { one }
bufferedField.run {
var res: NDBuffer<Double> = one
repeat(n) {
res += 1.0
}
}
}
println("Buffered addition completed in $doubleTime millis")
val elementTime = measureTimeMillis {
var res = bufferedField.produce { one }
repeat(n) { repeat(n) {
res += 1.0 res += 1.0
} }
} }
println("Inlined addition completed in $doubleTime millis") println("Element addition completed in $elementTime millis")
val specializedTime = measureTimeMillis { val specializedTime = measureTimeMillis {
var res = specializedField.produce { one } //specializedField.run(action)
repeat(n) { specializedField.run {
res += 1.0 var res: NDBuffer<Double> = one
repeat(n) {
res += 1.0
}
} }
} }
@ -32,9 +56,12 @@ fun main(args: Array<String>) {
val genericTime = measureTimeMillis { val genericTime = measureTimeMillis {
var res = genericField.produce { one } //genericField.run(action)
repeat(n) { genericField.run {
res += 1.0 var res = one
repeat(n) {
res += 1.0
}
} }
} }

View File

@ -3,7 +3,6 @@ package scientifik.kmath.structures
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
fun main(args: Array<String>) { fun main(args: Array<String>) {
val n = 6000 val n = 6000
@ -17,21 +16,21 @@ fun main(args: Array<String>) {
} }
println("Structure mapping finished in $time1 millis") println("Structure mapping finished in $time1 millis")
val array = DoubleArray(n*n){1.0} val array = DoubleArray(n * n) { 1.0 }
val time2 = measureTimeMillis { val time2 = measureTimeMillis {
val target = DoubleArray(n*n) val target = DoubleArray(n * n)
val res = array.forEachIndexed{index, value -> val res = array.forEachIndexed { index, value ->
target[index] = value + 1 target[index] = value + 1
} }
} }
println("Array mapping finished in $time2 millis") println("Array mapping finished in $time2 millis")
val buffer = DoubleBuffer(DoubleArray(n*n){1.0}) val buffer = DoubleBuffer(DoubleArray(n * n) { 1.0 })
val time3 = measureTimeMillis { val time3 = measureTimeMillis {
val target = DoubleBuffer(DoubleArray(n*n)) val target = DoubleBuffer(DoubleArray(n * n))
val res = array.forEachIndexed{index, value -> val res = array.forEachIndexed { index, value ->
target[index] = value + 1 target[index] = value + 1
} }
} }

View File

@ -1,6 +1,6 @@
buildscript { buildscript {
extra["kotlinVersion"] = "1.3.20-eap-52" extra["kotlinVersion"] = "1.3.20-eap-52"
extra["ioVersion"] = "0.1.2-dev-2" extra["ioVersion"] = "0.1.2"
extra["coroutinesVersion"] = "1.1.0" extra["coroutinesVersion"] = "1.1.0"
val kotlinVersion: String by extra val kotlinVersion: String by extra
@ -8,7 +8,7 @@ buildscript {
val coroutinesVersion: String by extra val coroutinesVersion: String by extra
repositories { repositories {
maven ("https://dl.bintray.com/kotlin/kotlin-eap") maven("https://dl.bintray.com/kotlin/kotlin-eap")
jcenter() jcenter()
} }
@ -19,7 +19,7 @@ buildscript {
} }
plugins { plugins {
id("com.jfrog.artifactory") version "4.8.1" apply false id("com.jfrog.artifactory") version "4.8.1" apply false
// id("org.jetbrains.kotlin.multiplatform") apply false // id("org.jetbrains.kotlin.multiplatform") apply false
} }
@ -28,14 +28,14 @@ allprojects {
apply(plugin = "com.jfrog.artifactory") apply(plugin = "com.jfrog.artifactory")
group = "scientifik" group = "scientifik"
version = "0.0.2-dev-1" version = "0.0.3-dev-1"
repositories{ repositories {
maven ("https://dl.bintray.com/kotlin/kotlin-eap") maven("https://dl.bintray.com/kotlin/kotlin-eap")
jcenter() jcenter()
} }
} }
if(file("artifactory.gradle").exists()){ if (file("artifactory.gradle").exists()) {
apply(from = "artifactory.gradle") apply(from = "artifactory.gradle")
} }

View File

@ -26,23 +26,28 @@ internal class ConstantExpression<T>(val value: T) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = value override fun invoke(arguments: Map<String, T>): T = value
} }
internal class SumExpression<T>(val context: Space<T>, val first: Expression<T>, val second: Expression<T>) : Expression<T> { internal class SumExpression<T>(val context: Space<T>, val first: Expression<T>, val second: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.add(first.invoke(arguments), second.invoke(arguments)) override fun invoke(arguments: Map<String, T>): T = context.add(first.invoke(arguments), second.invoke(arguments))
} }
internal class ProductExpression<T>(val context: Field<T>, val first: Expression<T>, val second: Expression<T>) : Expression<T> { internal class ProductExpression<T>(val context: Field<T>, val first: Expression<T>, val second: Expression<T>) :
override fun invoke(arguments: Map<String, T>): T = context.multiply(first.invoke(arguments), second.invoke(arguments)) Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
context.multiply(first.invoke(arguments), second.invoke(arguments))
} }
internal class ConstProductExpession<T>(val context: Field<T>, val expr: Expression<T>, val const: Double) : Expression<T> { internal class ConstProductExpession<T>(val context: Field<T>, val expr: Expression<T>, val const: Double) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const) override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
} }
internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) : Expression<T> { internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments)) override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
} }
class FieldExpressionContext<T>(val field: Field<T>) : Field<Expression<T>>, ExpressionContext<T> { class ExpressionField<T>(val field: Field<T>) : Field<Expression<T>>, ExpressionContext<T> {
override val zero: Expression<T> = ConstantExpression(field.zero) override val zero: Expression<T> = ConstantExpression(field.zero)
@ -59,4 +64,15 @@ class FieldExpressionContext<T>(val field: Field<T>) : Field<Expression<T>>, Exp
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b) override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b) override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
operator fun Expression<T>.plus(arg: T) = this + const(arg)
operator fun Expression<T>.minus(arg: T) = this - const(arg)
operator fun Expression<T>.times(arg: T) = this * const(arg)
operator fun Expression<T>.div(arg: T) = this / const(arg)
operator fun T.plus(arg: Expression<T>) = arg + this
operator fun T.minus(arg: Expression<T>) = arg - this
operator fun T.times(arg: Expression<T>) = arg * this
operator fun T.div(arg: Expression<T>) = arg / this
} }

View File

@ -5,16 +5,16 @@ package scientifik.kmath.histogram
*/ */
expect class LongCounter(){ expect class LongCounter() {
fun decrement() fun decrement()
fun increment() fun increment()
fun reset() fun reset()
fun sum(): Long fun sum(): Long
fun add(l:Long) fun add(l: Long)
} }
expect class DoubleCounter(){ expect class DoubleCounter() {
fun reset() fun reset()
fun sum(): Double fun sum(): Double
fun add(d: Double) fun add(d: Double)
} }

View File

@ -6,15 +6,16 @@ import kotlin.math.floor
private operator fun RealPoint.minus(other: RealPoint) = ListBuffer((0 until size).map { get(it) - other[it] }) private operator fun RealPoint.minus(other: RealPoint) = ListBuffer((0 until size).map { get(it) - other[it] })
private inline fun <T> Buffer<out Double>.mapIndexed(crossinline mapper: (Int, Double) -> T): Sequence<T> = (0 until size).asSequence().map { mapper(it, get(it)) } private inline fun <T> Buffer<out Double>.mapIndexed(crossinline mapper: (Int, Double) -> T): Sequence<T> =
(0 until size).asSequence().map { mapper(it, get(it)) }
/** /**
* Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions. * Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions.
*/ */
class FastHistogram( class FastHistogram(
private val lower: RealPoint, private val lower: RealPoint,
private val upper: RealPoint, private val upper: RealPoint,
private val binNums: IntArray = IntArray(lower.size) { 20 } private val binNums: IntArray = IntArray(lower.size) { 20 }
) : MutableHistogram<Double, PhantomBin<Double>> { ) : MutableHistogram<Double, PhantomBin<Double>> {
@ -25,7 +26,8 @@ class FastHistogram(
//private val weight: NDStructure<DoubleCounter?> = ndStructure(strides){null} //private val weight: NDStructure<DoubleCounter?> = ndStructure(strides){null}
//TODO optimize binSize performance if needed //TODO optimize binSize performance if needed
private val binSize: RealPoint = ListBuffer((upper - lower).mapIndexed { index, value -> value / binNums[index] }.toList()) private val binSize: RealPoint =
ListBuffer((upper - lower).mapIndexed { index, value -> value / binNums[index] }.toList())
init { init {
// argument checks // argument checks
@ -130,9 +132,9 @@ class FastHistogram(
*/ */
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): FastHistogram { fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): FastHistogram {
return FastHistogram( return FastHistogram(
ListBuffer(ranges.map { it.first.start }), ListBuffer(ranges.map { it.first.start }),
ListBuffer(ranges.map { it.first.endInclusive }), ListBuffer(ranges.map { it.first.endInclusive }),
ranges.map { it.second }.toIntArray() ranges.map { it.second }.toIntArray()
) )
} }
} }

View File

@ -1,10 +1,10 @@
package scientifik.kmath.histogram package scientifik.kmath.histogram
import scientifik.kmath.linear.Point
import scientifik.kmath.structures.ArrayBuffer import scientifik.kmath.structures.ArrayBuffer
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.DoubleBuffer
typealias Point<T> = Buffer<T>
typealias RealPoint = Buffer<Double> typealias RealPoint = Buffer<Double>
@ -12,7 +12,7 @@ typealias RealPoint = Buffer<Double>
* A simple geometric domain * A simple geometric domain
* TODO move to geometry module * TODO move to geometry module
*/ */
interface Domain<T: Any> { interface Domain<T : Any> {
operator fun contains(vector: Point<out T>): Boolean operator fun contains(vector: Point<out T>): Boolean
val dimension: Int val dimension: Int
} }
@ -20,7 +20,7 @@ interface Domain<T: Any> {
/** /**
* The bin in the histogram. The histogram is by definition always done in the real space * The bin in the histogram. The histogram is by definition always done in the real space
*/ */
interface Bin<T: Any> : Domain<T> { interface Bin<T : Any> : Domain<T> {
/** /**
* The value of this bin * The value of this bin
*/ */
@ -28,7 +28,7 @@ interface Bin<T: Any> : Domain<T> {
val center: Point<T> val center: Point<T>
} }
interface Histogram<T: Any, out B : Bin<T>> : Iterable<B> { interface Histogram<T : Any, out B : Bin<T>> : Iterable<B> {
/** /**
* Find existing bin, corresponding to given coordinates * Find existing bin, corresponding to given coordinates
@ -42,7 +42,7 @@ interface Histogram<T: Any, out B : Bin<T>> : Iterable<B> {
} }
interface MutableHistogram<T: Any, out B : Bin<T>>: Histogram<T,B>{ interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
/** /**
* Increment appropriate bin * Increment appropriate bin
@ -50,14 +50,17 @@ interface MutableHistogram<T: Any, out B : Bin<T>>: Histogram<T,B>{
fun put(point: Point<out T>, weight: Double = 1.0) fun put(point: Point<out T>, weight: Double = 1.0)
} }
fun <T: Any> MutableHistogram<T,*>.put(vararg point: T) = put(ArrayBuffer(point)) fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point))
fun MutableHistogram<Double,*>.put(vararg point: Number) = put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray())) fun MutableHistogram<Double, *>.put(vararg point: Number) =
fun MutableHistogram<Double,*>.put(vararg point: Double) = put(DoubleBuffer(point)) put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray()))
fun <T: Any> MutableHistogram<T,*>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) } fun MutableHistogram<Double, *>.put(vararg point: Double) = put(DoubleBuffer(point))
fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) }
/** /**
* Pass a sequence builder into histogram * Pass a sequence builder into histogram
*/ */
fun <T: Any> MutableHistogram<T, *>.fill(buider: suspend SequenceScope<Point<T>>.() -> Unit) = fill(sequence(buider).asIterable()) fun <T : Any> MutableHistogram<T, *>.fill(buider: suspend SequenceScope<Point<T>>.() -> Unit) =
fill(sequence(buider).asIterable())

View File

@ -1,5 +1,6 @@
package scientifik.kmath.histogram package scientifik.kmath.histogram
import scientifik.kmath.linear.Point
import scientifik.kmath.linear.Vector import scientifik.kmath.linear.Vector
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.structures.NDStructure import scientifik.kmath.structures.NDStructure
@ -8,8 +9,8 @@ import scientifik.kmath.structures.asSequence
data class BinTemplate<T : Comparable<T>>(val center: Vector<T, *>, val sizes: Point<T>) { data class BinTemplate<T : Comparable<T>>(val center: Vector<T, *>, val sizes: Point<T>) {
fun contains(vector: Point<out T>): Boolean { fun contains(vector: Point<out T>): Boolean {
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}") if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
val upper = center.context.run { center + sizes / 2.0} val upper = center.context.run { center + sizes / 2.0 }
val lower = center.context.run {center - sizes / 2.0} val lower = center.context.run { center - sizes / 2.0 }
return vector.asSequence().mapIndexed { i, value -> return vector.asSequence().mapIndexed { i, value ->
value in lower[i]..upper[i] value in lower[i]..upper[i]
}.all { it } }.all { it }
@ -44,8 +45,8 @@ class PhantomBin<T : Comparable<T>>(val template: BinTemplate<T>, override val v
* @param bins map a template into structure index * @param bins map a template into structure index
*/ */
class PhantomHistogram<T : Comparable<T>>( class PhantomHistogram<T : Comparable<T>>(
val bins: Map<BinTemplate<T>, IntArray>, val bins: Map<BinTemplate<T>, IntArray>,
val data: NDStructure<Number> val data: NDStructure<Number>
) : Histogram<T, PhantomBin<T>> { ) : Histogram<T, PhantomBin<T>> {
override val dimension: Int override val dimension: Int

View File

@ -104,7 +104,10 @@ abstract class LUDecomposition<T : Comparable<T>, F : Field<T>>(val matrix: Matr
val m = matrix.numCols val m = matrix.numCols
val pivot = IntArray(matrix.numRows) val pivot = IntArray(matrix.numRows)
//TODO fix performance //TODO fix performance
val lu: MutableNDStructure<T> = mutableNdStructure(intArrayOf(matrix.numRows, matrix.numCols), ::boxingMutableBuffer) { index: IntArray -> matrix[index[0], index[1]] } val lu: MutableNDStructure<T> = mutableNdStructure(
intArrayOf(matrix.numRows, matrix.numCols),
::boxingMutableBuffer
) { index: IntArray -> matrix[index[0], index[1]] }
with(matrix.context.ring) { with(matrix.context.ring) {
@ -180,7 +183,8 @@ abstract class LUDecomposition<T : Comparable<T>, F : Field<T>>(val matrix: Matr
} }
class RealLUDecomposition(matrix: RealMatrix, private val singularityThreshold: Double = DEFAULT_TOO_SMALL) : LUDecomposition<Double, DoubleField>(matrix) { class RealLUDecomposition(matrix: RealMatrix, private val singularityThreshold: Double = DEFAULT_TOO_SMALL) :
LUDecomposition<Double, DoubleField>(matrix) {
override fun isSingular(value: Double): Boolean { override fun isSingular(value: Double): Boolean {
return value.absoluteValue < singularityThreshold return value.absoluteValue < singularityThreshold
} }
@ -195,7 +199,8 @@ class RealLUDecomposition(matrix: RealMatrix, private val singularityThreshold:
/** Specialized solver. */ /** Specialized solver. */
object RealLUSolver : LinearSolver<Double, DoubleField> { object RealLUSolver : LinearSolver<Double, DoubleField> {
fun decompose(mat: Matrix<Double, DoubleField>, threshold: Double = 1e-11): RealLUDecomposition = RealLUDecomposition(mat, threshold) fun decompose(mat: Matrix<Double, DoubleField>, threshold: Double = 1e-11): RealLUDecomposition =
RealLUDecomposition(mat, threshold)
override fun solve(a: RealMatrix, b: RealMatrix): RealMatrix { override fun solve(a: RealMatrix, b: RealMatrix): RealMatrix {
val decomposition = decompose(a, a.context.ring.zero) val decomposition = decompose(a, a.context.ring.zero)

View File

@ -1,6 +1,5 @@
package scientifik.kmath.linear package scientifik.kmath.linear
import scientifik.kmath.histogram.Point
import scientifik.kmath.operations.DoubleField import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
@ -33,25 +32,34 @@ interface MatrixSpace<T : Any, R : Ring<T>> : Space<Matrix<T, R>> {
val one get() = produce { i, j -> if (i == j) ring.one else ring.zero } val one get() = produce { i, j -> if (i == j) ring.one else ring.zero }
override fun add(a: Matrix<T, R>, b: Matrix<T, R>): Matrix<T, R> = produce(rowNum, colNum) { i, j -> ring.run { a[i, j] + b[i, j] } } override fun add(a: Matrix<T, R>, b: Matrix<T, R>): Matrix<T, R> =
produce(rowNum, colNum) { i, j -> ring.run { a[i, j] + b[i, j] } }
override fun multiply(a: Matrix<T, R>, k: Double): Matrix<T, R> = produce(rowNum, colNum) { i, j -> ring.run { a[i, j] * k } } override fun multiply(a: Matrix<T, R>, k: Double): Matrix<T, R> =
produce(rowNum, colNum) { i, j -> ring.run { a[i, j] * k } }
companion object { companion object {
/** /**
* Non-boxing double matrix * Non-boxing double matrix
*/ */
fun real(rows: Int, columns: Int): MatrixSpace<Double, DoubleField> = StructureMatrixSpace(rows, columns, DoubleField, DoubleBufferFactory) fun real(rows: Int, columns: Int): MatrixSpace<Double, DoubleField> =
StructureMatrixSpace(rows, columns, DoubleField, DoubleBufferFactory)
/** /**
* A structured matrix with custom buffer * A structured matrix with custom buffer
*/ */
fun <T : Any, R : Ring<T>> buffered(rows: Int, columns: Int, ring: R, bufferFactory: BufferFactory<T> = ::boxingBuffer): MatrixSpace<T, R> = StructureMatrixSpace(rows, columns, ring, bufferFactory) fun <T : Any, R : Ring<T>> buffered(
rows: Int,
columns: Int,
ring: R,
bufferFactory: BufferFactory<T> = ::boxingBuffer
): MatrixSpace<T, R> = StructureMatrixSpace(rows, columns, ring, bufferFactory)
/** /**
* Automatic buffered matrix, unboxed if it is possible * Automatic buffered matrix, unboxed if it is possible
*/ */
inline fun <reified T : Any, R : Ring<T>> smart(rows: Int, columns: Int, ring: R): MatrixSpace<T, R> = buffered(rows, columns, ring, ::inlineBuffer) inline fun <reified T : Any, R : Ring<T>> smart(rows: Int, columns: Int, ring: R): MatrixSpace<T, R> =
buffered(rows, columns, ring, ::inlineBuffer)
} }
} }
@ -59,7 +67,7 @@ interface MatrixSpace<T : Any, R : Ring<T>> : Space<Matrix<T, R>> {
/** /**
* Specialized 2-d structure * Specialized 2-d structure
*/ */
interface Matrix<T : Any, R : Ring<T>> : NDStructure<T>, SpaceElement<Matrix<T, R>, MatrixSpace<T, R>> { interface Matrix<T : Any, R : Ring<T>> : NDStructure<T>, SpaceElement<Matrix<T, R>, Matrix<T, R>, MatrixSpace<T, R>> {
operator fun get(i: Int, j: Int): T operator fun get(i: Int, j: Int): T
override fun get(index: IntArray): T = get(index[0], index[1]) override fun get(index: IntArray): T = get(index[0], index[1])
@ -82,7 +90,7 @@ interface Matrix<T : Any, R : Ring<T>> : NDStructure<T>, SpaceElement<Matrix<T,
companion object { companion object {
fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
MatrixSpace.real(rows, columns).produce(rows, columns, initializer) MatrixSpace.real(rows, columns).produce(rows, columns, initializer)
} }
} }
@ -110,10 +118,10 @@ infix fun <T : Any, R : Ring<T>> Matrix<T, R>.dot(vector: Point<T>): Point<T> {
} }
data class StructureMatrixSpace<T : Any, R : Ring<T>>( data class StructureMatrixSpace<T : Any, R : Ring<T>>(
override val rowNum: Int, override val rowNum: Int,
override val colNum: Int, override val colNum: Int,
override val ring: R, override val ring: R,
private val bufferFactory: BufferFactory<T> private val bufferFactory: BufferFactory<T>
) : MatrixSpace<T, R> { ) : MatrixSpace<T, R> {
override val shape: IntArray = intArrayOf(rowNum, colNum) override val shape: IntArray = intArrayOf(rowNum, colNum)
@ -134,15 +142,21 @@ data class StructureMatrixSpace<T : Any, R : Ring<T>>(
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer) override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
} }
data class StructureMatrix<T : Any, R : Ring<T>>(override val context: StructureMatrixSpace<T, R>, val structure: NDStructure<T>) : Matrix<T, R> { data class StructureMatrix<T : Any, R : Ring<T>>(
override val context: StructureMatrixSpace<T, R>,
val structure: NDStructure<T>
) : Matrix<T, R> {
init { init {
if (structure.shape.size != 2 || structure.shape[0] != context.rowNum || structure.shape[1] != context.colNum) { if (structure.shape.size != 2 || structure.shape[0] != context.rowNum || structure.shape[1] != context.colNum) {
error("Dimension mismatch for structure, (${context.rowNum}, ${context.colNum}) expected, but ${structure.shape} found") error("Dimension mismatch for structure, (${context.rowNum}, ${context.colNum}) expected, but ${structure.shape} found")
} }
} }
override fun unwrap(): Matrix<T, R> = this
override fun Matrix<T, R>.wrap(): Matrix<T, R> = this
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
override val self: Matrix<T, R> get() = this
override fun get(index: IntArray): T = structure[index] override fun get(index: IntArray): T = structure[index]
@ -153,4 +167,4 @@ data class StructureMatrix<T : Any, R : Ring<T>>(override val context: Structure
//TODO produce transposed matrix via reference without creating new space and structure //TODO produce transposed matrix via reference without creating new space and structure
fun <T : Any, R : Ring<T>> Matrix<T, R>.transpose(): Matrix<T, R> = fun <T : Any, R : Ring<T>> Matrix<T, R>.transpose(): Matrix<T, R> =
context.produce(numCols, numRows) { i, j -> get(j, i) } context.produce(numCols, numRows) { i, j -> get(j, i) }

View File

@ -1,11 +1,12 @@
package scientifik.kmath.linear package scientifik.kmath.linear
import scientifik.kmath.histogram.Point
import scientifik.kmath.operations.DoubleField import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.operations.SpaceElement import scientifik.kmath.operations.SpaceElement
import scientifik.kmath.structures.* import scientifik.kmath.structures.*
typealias Point<T> = Buffer<T>
/** /**
* A linear space for vectors. * A linear space for vectors.
* Could be used on any point-like structure * Could be used on any point-like structure
@ -45,12 +46,17 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
/** /**
* A structured vector space with custom buffer * A structured vector space with custom buffer
*/ */
fun <T : Any, S : Space<T>> buffered(size: Int, space: S, bufferFactory: BufferFactory<T> = ::boxingBuffer): VectorSpace<T, S> = BufferVectorSpace(size, space, bufferFactory) fun <T : Any, S : Space<T>> buffered(
size: Int,
space: S,
bufferFactory: BufferFactory<T> = ::boxingBuffer
): VectorSpace<T, S> = BufferVectorSpace(size, space, bufferFactory)
/** /**
* Automatic buffered vector, unboxed if it is possible * Automatic buffered vector, unboxed if it is possible
*/ */
inline fun <reified T : Any, S : Space<T>> smart(size: Int, space: S): VectorSpace<T, S> = buffered(size, space, ::inlineBuffer) inline fun <reified T : Any, S : Space<T>> smart(size: Int, space: S): VectorSpace<T, S> =
buffered(size, space, ::inlineBuffer)
} }
} }
@ -58,38 +64,42 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
/** /**
* A point coupled to the linear space * A point coupled to the linear space
*/ */
interface Vector<T : Any, S : Space<T>> : SpaceElement<Vector<T,S>, VectorSpace<T, S>>, Point<T> { interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, Vector<T, S>, VectorSpace<T, S>>, Point<T> {
override val size: Int get() = context.size override val size: Int get() = context.size
override operator fun plus(b: Point<T>): Vector<T, S> = context.add(self, b) override operator fun plus(b: Point<T>): Vector<T, S> = context.add(this, b).wrap()
override operator fun minus(b: Point<T>): Vector<T, S> = context.add(self, context.multiply(b, -1.0)) override operator fun minus(b: Point<T>): Vector<T, S> = context.add(this, context.multiply(b, -1.0)).wrap()
override operator fun times(k: Number): Vector<T, S> = context.multiply(self, k.toDouble()) override operator fun times(k: Number): Vector<T, S> = context.multiply(this, k.toDouble()).wrap()
override operator fun div(k: Number): Vector<T, S> = context.multiply(self, 1.0 / k.toDouble()) override operator fun div(k: Number): Vector<T, S> = context.multiply(this, 1.0 / k.toDouble()).wrap()
companion object { companion object {
/** /**
* Create vector with custom field * Create vector with custom field
*/ */
fun <T : Any, S : Space<T>> generic(size: Int, field: S, initializer: (Int) -> T): Vector<T, S> = fun <T : Any, S : Space<T>> generic(size: Int, field: S, initializer: (Int) -> T): Vector<T, S> =
VectorSpace.buffered(size, field).produceElement(initializer) VectorSpace.buffered(size, field).produceElement(initializer)
fun real(size: Int, initializer: (Int) -> Double): Vector<Double,DoubleField> = VectorSpace.real(size).produceElement(initializer) fun real(size: Int, initializer: (Int) -> Double): Vector<Double, DoubleField> =
fun ofReal(vararg elements: Double): Vector<Double,DoubleField> = VectorSpace.real(elements.size).produceElement { elements[it] } VectorSpace.real(size).produceElement(initializer)
fun ofReal(vararg elements: Double): Vector<Double, DoubleField> =
VectorSpace.real(elements.size).produceElement { elements[it] }
} }
} }
data class BufferVectorSpace<T : Any, S : Space<T>>( data class BufferVectorSpace<T : Any, S : Space<T>>(
override val size: Int, override val size: Int,
override val space: S, override val space: S,
val bufferFactory: BufferFactory<T> val bufferFactory: BufferFactory<T>
) : VectorSpace<T, S> { ) : VectorSpace<T, S> {
override fun produce(initializer: (Int) -> T) = bufferFactory(size, initializer) override fun produce(initializer: (Int) -> T) = bufferFactory(size, initializer)
override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer)) override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer))
} }
data class BufferVector<T : Any, S : Space<T>>(override val context: VectorSpace<T, S>, val buffer: Buffer<T>) : Vector<T, S> { data class BufferVector<T : Any, S : Space<T>>(override val context: VectorSpace<T, S>, val buffer: Buffer<T>) :
Vector<T, S> {
init { init {
if (context.size != buffer.size) { if (context.size != buffer.size) {
@ -101,10 +111,13 @@ data class BufferVector<T : Any, S : Space<T>>(override val context: VectorSpace
return buffer[index] return buffer[index]
} }
override fun getSelf(): BufferVector<T, S override fun unwrap(): Point<T> = this
override fun Point<T>.wrap(): Vector<T, S> = BufferVector(context, this)
override fun iterator(): Iterator<T> = (0 until size).map { buffer[it] }.iterator() override fun iterator(): Iterator<T> = (0 until size).map { buffer[it] }.iterator()
override fun toString(): String = this.asSequence().joinToString(prefix = "[", postfix = "]", separator = ", ") { it.toString() } override fun toString(): String =
this.asSequence().joinToString(prefix = "[", postfix = "]", separator = ", ") { it.toString() }
} }

View File

@ -27,33 +27,34 @@ fun <T, R> Sequence<T>.cumulative(initial: R, operation: (T, R) -> R): Sequence<
override fun iterator(): Iterator<R> = this@cumulative.iterator().cumulative(initial, operation) override fun iterator(): Iterator<R> = this@cumulative.iterator().cumulative(initial, operation)
} }
fun <T, R> List<T>.cumulative(initial: R, operation: (T, R) -> R): List<R> = this.iterator().cumulative(initial, operation).asSequence().toList() fun <T, R> List<T>.cumulative(initial: R, operation: (T, R) -> R): List<R> =
this.iterator().cumulative(initial, operation).asSequence().toList()
//Cumulative sum //Cumulative sum
@JvmName("cumulativeSumOfDouble") @JvmName("cumulativeSumOfDouble")
fun Iterable<Double>.cumulativeSum() = this.cumulative(0.0){ element, sum -> sum + element} fun Iterable<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt") @JvmName("cumulativeSumOfInt")
fun Iterable<Int>.cumulativeSum() = this.cumulative(0){ element, sum -> sum + element} fun Iterable<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong") @JvmName("cumulativeSumOfLong")
fun Iterable<Long>.cumulativeSum() = this.cumulative(0L){ element, sum -> sum + element} fun Iterable<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }
@JvmName("cumulativeSumOfDouble") @JvmName("cumulativeSumOfDouble")
fun Sequence<Double>.cumulativeSum() = this.cumulative(0.0){ element, sum -> sum + element} fun Sequence<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt") @JvmName("cumulativeSumOfInt")
fun Sequence<Int>.cumulativeSum() = this.cumulative(0){ element, sum -> sum + element} fun Sequence<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong") @JvmName("cumulativeSumOfLong")
fun Sequence<Long>.cumulativeSum() = this.cumulative(0L){ element, sum -> sum + element} fun Sequence<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }
@JvmName("cumulativeSumOfDouble") @JvmName("cumulativeSumOfDouble")
fun List<Double>.cumulativeSum() = this.cumulative(0.0){ element, sum -> sum + element} fun List<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt") @JvmName("cumulativeSumOfInt")
fun List<Int>.cumulativeSum() = this.cumulative(0){ element, sum -> sum + element} fun List<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong") @JvmName("cumulativeSumOfLong")
fun List<Long>.cumulativeSum() = this.cumulative(0L){ element, sum -> sum + element} fun List<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }

View File

@ -33,13 +33,25 @@ interface Space<T> {
operator fun T.times(k: Number) = multiply(this, k.toDouble()) operator fun T.times(k: Number) = multiply(this, k.toDouble())
operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble()) operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble())
operator fun Number.times(b: T) = b * this operator fun Number.times(b: T) = b * this
//TODO move to external extensions when they are available
fun Iterable<T>.sum(): T = fold(zero) { left, right -> left + right } fun Iterable<T>.sum(): T = fold(zero) { left, right -> left + right }
fun Sequence<T>.sum(): T = fold(zero) { left, right -> left + right } fun Sequence<T>.sum(): T = fold(zero) { left, right -> left + right }
} }
abstract class AbstractSpace<T> : Space<T> {
//TODO move to external extensions when they are available
final override operator fun T.unaryMinus(): T = multiply(this, -1.0)
final override operator fun T.plus(b: T): T = add(this, b)
final override operator fun T.minus(b: T): T = add(this, -b)
final override operator fun T.times(k: Number) = multiply(this, k.toDouble())
final override operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble())
final override operator fun Number.times(b: T) = b * this
final override fun Iterable<T>.sum(): T = fold(zero) { left, right -> left + right }
final override fun Sequence<T>.sum(): T = fold(zero) { left, right -> left + right }
}
/** /**
* The same as {@link Space} but with additional multiplication operation * The same as {@link Space} but with additional multiplication operation
*/ */
@ -56,6 +68,15 @@ interface Ring<T> : Space<T> {
operator fun T.times(b: T): T = multiply(this, b) operator fun T.times(b: T): T = multiply(this, b)
// operator fun T.plus(b: Number) = this.plus(b * one)
// operator fun Number.plus(b: T) = b + this
//
// operator fun T.minus(b: Number) = this.minus(b * one)
// operator fun Number.minus(b: T) = -b + this
}
abstract class AbstractRing<T: Any> : AbstractSpace<T>(), Ring<T> {
final override operator fun T.times(b: T): T = multiply(this, b)
} }
/** /**
@ -66,10 +87,9 @@ interface Field<T> : Ring<T> {
operator fun T.div(b: T): T = divide(this, b) operator fun T.div(b: T): T = divide(this, b)
operator fun Number.div(b: T) = this * divide(one, b) operator fun Number.div(b: T) = this * divide(one, b)
}
operator fun T.plus(b: Number) = this.plus(b * one) abstract class AbstractField<T: Any> : AbstractRing<T>(), Field<T> {
operator fun Number.plus(b: T) = b + this final override operator fun T.div(b: T): T = divide(this, b)
final override operator fun Number.div(b: T) = this * divide(one, b)
operator fun T.minus(b: Number) = this.minus(b * one)
operator fun Number.minus(b: T) = -b + this
} }

View File

@ -14,7 +14,8 @@ object ComplexField : Field<Complex> {
override fun multiply(a: Complex, k: Double): Complex = Complex(a.re * k, a.im * k) override fun multiply(a: Complex, k: Double): Complex = Complex(a.re * k, a.im * k)
override fun multiply(a: Complex, b: Complex): Complex = Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re) override fun multiply(a: Complex, b: Complex): Complex =
Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
override fun divide(a: Complex, b: Complex): Complex { override fun divide(a: Complex, b: Complex): Complex {
val norm = b.square val norm = b.square
@ -35,8 +36,10 @@ object ComplexField : Field<Complex> {
/** /**
* Complex number class * Complex number class
*/ */
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, ComplexField> { data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField> {
override val self: Complex get() = this override fun unwrap(): Complex = this
override fun Complex.wrap(): Complex = this
override val context: ComplexField override val context: ComplexField
get() = ComplexField get() = ComplexField

View File

@ -6,11 +6,10 @@ import kotlin.math.pow
* Advanced Number-like field that implements basic operations * Advanced Number-like field that implements basic operations
*/ */
interface ExtendedField<T : Any> : interface ExtendedField<T : Any> :
Field<T>, Field<T>,
TrigonometricOperations<T>, TrigonometricOperations<T>,
PowerOperations<T>, PowerOperations<T>,
ExponentialOperations<T> ExponentialOperations<T>
/** /**
@ -33,7 +32,7 @@ inline class Real(val value: Double) : FieldElement<Double, Real, DoubleField> {
/** /**
* A field for double without boxing. Does not produce appropriate field element * A field for double without boxing. Does not produce appropriate field element
*/ */
object DoubleField : ExtendedField<Double>, Norm<Double, Double> { object DoubleField : AbstractField<Double>(),ExtendedField<Double>, Norm<Double, Double> {
override val zero: Double = 0.0 override val zero: Double = 0.0
override fun add(a: Double, b: Double): Double = a + b override fun add(a: Double, b: Double): Double = a + b
override fun multiply(a: Double, @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") b: Double): Double = a * b override fun multiply(a: Double, @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") b: Double): Double = a * b

View File

@ -19,10 +19,10 @@ interface TrigonometricOperations<T> : Field<T> {
fun ctg(arg: T): T = cos(arg) / sin(arg) fun ctg(arg: T): T = cos(arg) / sin(arg)
} }
fun <T : FieldElement<T, out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg) fun <T : MathElement<out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
fun <T : FieldElement<T, out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg) fun <T : MathElement<out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
fun <T : FieldElement<T, out TrigonometricOperations<T>>> tg(arg: T): T = arg.context.tg(arg) fun <T : MathElement<out TrigonometricOperations<T>>> tg(arg: T): T = arg.context.tg(arg)
fun <T : FieldElement<T, out TrigonometricOperations<T>>> ctg(arg: T): T = arg.context.ctg(arg) fun <T : MathElement<out TrigonometricOperations<T>>> ctg(arg: T): T = arg.context.ctg(arg)
/* Power and roots */ /* Power and roots */
@ -31,11 +31,12 @@ fun <T : FieldElement<T, out TrigonometricOperations<T>>> ctg(arg: T): T = arg.c
*/ */
interface PowerOperations<T> { interface PowerOperations<T> {
fun power(arg: T, pow: Double): T fun power(arg: T, pow: Double): T
fun sqrt(arg: T) = power(arg, 0.5)
} }
infix fun <T : MathElement<T, out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power) infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
fun <T : MathElement<T, out PowerOperations<T>>> sqrt(arg: T): T = arg pow 0.5 fun <T : MathElement<out PowerOperations<T>>> sqrt(arg: T): T = arg pow 0.5
fun <T : MathElement<T, out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0 fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
/* Exponential */ /* Exponential */
@ -44,11 +45,11 @@ interface ExponentialOperations<T> {
fun ln(arg: T): T fun ln(arg: T): T
} }
fun <T : MathElement<T, out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg) fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg)
fun <T : MathElement<T, out ExponentialOperations<T>>> ln(arg: T): T = arg.context.ln(arg) fun <T : MathElement<out ExponentialOperations<T>>> ln(arg: T): T = arg.context.ln(arg)
interface Norm<in T, out R> { interface Norm<in T: Any, out R> {
fun norm(arg: T): R fun norm(arg: T): R
} }
fun <T : MathElement<T, out Norm<T, R>>, R> norm(arg: T): R = arg.context.norm(arg) fun <T : MathElement<out Norm<T, R>>, R> norm(arg: T): R = arg.context.norm(arg)

View File

@ -3,71 +3,120 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.FieldElement import scientifik.kmath.operations.FieldElement
open class BufferNDField<T, F : Field<T>>(final override val shape: IntArray, final override val field: F, val bufferFactory: BufferFactory<T>) : NDField<T, F, NDBuffer<T>> { abstract class StridedNDField<T, F : Field<T>>(shape: IntArray, elementField: F) :
AbstractNDField<T, F, NDBuffer<T>>(shape, elementField) {
abstract val bufferFactory: BufferFactory<T>
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
}
override fun produce(initializer: F.(IntArray) -> T) =
BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) })
open fun NDBuffer<T>.map(transform: F.(T) -> T) = class BufferNDField<T, F : Field<T>>(
BufferNDElement(this@BufferNDField, bufferFactory(strides.linearSize) { offset -> field.transform(buffer[offset]) }) shape: IntArray,
elementField: F,
override val bufferFactory: BufferFactory<T>
) :
StridedNDField<T, F>(shape, elementField) {
open fun NDBuffer<T>.mapIndexed(transform: F.(index: IntArray, T) -> T) = override fun check(vararg elements: NDBuffer<T>) {
BufferNDElement(this@BufferNDField, bufferFactory(strides.linearSize) { offset -> field.transform(strides.index(offset), buffer[offset]) }) if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
}
open fun combine(a: NDBuffer<T>, b: NDBuffer<T>, transform: F.(T, T) -> T) = override val zero by lazy { produce { zero } }
BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.transform(a[offset], b[offset]) }) override val one by lazy { produce { one } }
@Suppress("OVERRIDE_BY_INLINE")
override inline fun produce(crossinline initializer: F.(IntArray) -> T): BufferNDElement<T, F> =
BufferNDElement(
this,
bufferFactory(strides.linearSize) { offset -> elementField.initializer(strides.index(offset)) })
@Suppress("OVERRIDE_BY_INLINE")
override inline fun NDBuffer<T>.map(crossinline transform: F.(T) -> T): BufferNDElement<T, F> {
check(this)
return BufferNDElement(
this@BufferNDField,
bufferFactory(strides.linearSize) { offset -> elementField.transform(buffer[offset]) })
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun NDBuffer<T>.mapIndexed(crossinline transform: F.(index: IntArray, T) -> T): BufferNDElement<T, F> {
check(this)
return BufferNDElement(
this@BufferNDField,
bufferFactory(strides.linearSize) { offset ->
elementField.transform(
strides.index(offset),
buffer[offset]
)
})
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun combine(
a: NDBuffer<T>,
b: NDBuffer<T>,
crossinline transform: F.(T, T) -> T
): BufferNDElement<T, F> {
check(a, b)
return BufferNDElement(
this,
bufferFactory(strides.linearSize) { offset -> elementField.transform(a.buffer[offset], b.buffer[offset]) })
}
/** /**
* Convert any [NDStructure] to buffered structure using strides from this context. * Convert any [NDStructure] to buffered structure using strides from this context.
* If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over indexes * If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over indexes
*
* If the argument is [NDBuffer] with different strides structure, the new element will be produced.
*/ */
fun NDStructure<T>.toBuffer(): NDBuffer<T> = fun NDStructure<T>.toBuffer(): NDBuffer<T> {
this as? NDBuffer<T> ?: produce { index -> get(index) } return if (this is NDBuffer<T> && this.strides == this@BufferNDField.strides) {
this
override val zero: NDBuffer<T> by lazy { produce { field.zero } } } else {
produce { index -> get(index) }
override fun add(a: NDBuffer<T>, b: NDBuffer<T>): NDBuffer<T> = combine(a, b) { aValue, bValue -> add(aValue, bValue) } }
}
override fun multiply(a: NDBuffer<T>, k: Double): NDBuffer<T> = a.map { it * k }
override val one: NDBuffer<T> by lazy { produce { field.one } }
override fun multiply(a: NDBuffer<T>, b: NDBuffer<T>): NDBuffer<T> = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
override fun divide(a: NDBuffer<T>, b: NDBuffer<T>): NDBuffer<T> = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
} }
class BufferNDElement<T, F : Field<T>>(override val context: BufferNDField<T, F>, override val buffer: Buffer<T>) : class BufferNDElement<T, F : Field<T>>(override val context: StridedNDField<T, F>, override val buffer: Buffer<T>) :
NDBuffer<T>, NDBuffer<T>,
FieldElement<NDBuffer<T>, BufferNDElement<T, F>, BufferNDField<T, F>>, FieldElement<NDBuffer<T>, BufferNDElement<T, F>, StridedNDField<T, F>>,
NDElement<T, F> { NDElement<T, F> {
override val elementField: F get() = context.field override val elementField: F
get() = context.elementField
override fun unwrap(): NDBuffer<T> = this override fun unwrap(): NDBuffer<T> =
this
override fun NDBuffer<T>.wrap(): BufferNDElement<T, F> = BufferNDElement(context, this.buffer) override fun NDBuffer<T>.wrap(): BufferNDElement<T, F> =
BufferNDElement(context, this.buffer)
override val strides get() = context.strides override val strides
get() = context.strides
override val shape: IntArray get() = context.shape override val shape: IntArray
get() = context.shape
override fun get(index: IntArray): T = buffer[strides.offset(index)] override fun get(index: IntArray): T =
buffer[strides.offset(index)]
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to get(it) } override fun elements(): Sequence<Pair<IntArray, T>> =
strides.indices().map { it to get(it) }
override fun map(action: F.(T) -> T): BufferNDElement<T, F> = context.run { map(action) } override fun map(action: F.(T) -> T): BufferNDElement<T, F> =
context.run { map(action) }
override fun mapIndexed(transform: F.(index: IntArray, T) -> T): BufferNDElement<T, F> = context.run { mapIndexed(transform) } override fun mapIndexed(transform: F.(index: IntArray, T) -> T): BufferNDElement<T, F> =
context.run { mapIndexed(transform) }
} }
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array. Just like in numpy
*/ */
operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferNDElement<T, F>) = operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferNDElement<T, F>) =
ndElement.context.run { ndElement.map { invoke(it) } } ndElement.context.run { ndElement.map { invoke(it) } }
/* plus and minus */ /* plus and minus */
@ -75,13 +124,13 @@ operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferNDE
* Summation operation for [BufferNDElement] and single element * Summation operation for [BufferNDElement] and single element
*/ */
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.plus(arg: T) = operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.plus(arg: T) =
context.run { map { it + arg } } context.run { map { it + arg } }
/** /**
* Subtraction operation between [BufferNDElement] and single element * Subtraction operation between [BufferNDElement] and single element
*/ */
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.minus(arg: T) = operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.minus(arg: T) =
context.run { map { it - arg } } context.run { map { it - arg } }
/* prod and div */ /* prod and div */
@ -89,10 +138,10 @@ operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.minus(arg: T) =
* Product operation for [BufferNDElement] and single element * Product operation for [BufferNDElement] and single element
*/ */
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.times(arg: T) = operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.times(arg: T) =
context.run { map { it * arg } } context.run { map { it * arg } }
/** /**
* Division operation between [BufferNDElement] and single element * Division operation between [BufferNDElement] and single element
*/ */
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.div(arg: T) = operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.div(arg: T) =
context.run { map { it / arg } } context.run { map { it / arg } }

View File

@ -12,7 +12,8 @@ interface Buffer<T> {
operator fun iterator(): Iterator<T> operator fun iterator(): Iterator<T>
fun contentEquals(other: Buffer<*>): Boolean = asSequence().mapIndexed { index, value -> value == other[index] }.all { it } fun contentEquals(other: Buffer<*>): Boolean =
asSequence().mapIndexed { index, value -> value == other[index] }.all { it }
} }
fun <T> Buffer<T>.asSequence(): Sequence<T> = iterator().asSequence() fun <T> Buffer<T>.asSequence(): Sequence<T> = iterator().asSequence()
@ -151,7 +152,8 @@ inline fun <reified T : Any> inlineBuffer(size: Int, initializer: (Int) -> T): B
/** /**
* Create a boxing mutable buffer of given type * Create a boxing mutable buffer of given type
*/ */
inline fun <T : Any> boxingMutableBuffer(size: Int, initializer: (Int) -> T): MutableBuffer<T> = MutableListBuffer(MutableList(size, initializer)) inline fun <T : Any> boxingMutableBuffer(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
MutableListBuffer(MutableList(size, initializer))
/** /**
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible * Create most appropriate mutable buffer for given type avoiding boxing wherever possible

View File

@ -7,40 +7,41 @@ import scientifik.kmath.operations.TrigonometricOperations
interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<out T>> : interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<out T>> :
NDField<T, F, N>, NDField<T, F, N>,
TrigonometricOperations<N>, TrigonometricOperations<N>,
PowerOperations<N>, PowerOperations<N>,
ExponentialOperations<N> ExponentialOperations<N>
/** /**
* NDField that supports [ExtendedField] operations on its elements * NDField that supports [ExtendedField] operations on its elements
*/ */
class ExtendedNDFieldWrapper<T : Any, F : ExtendedField<T>, N : NDStructure<out T>>(private val ndField: NDField<T, F, N>) : ExtendedNDField<T, F, N>, NDField<T,F,N> by ndField { class ExtendedNDFieldWrapper<T : Any, F : ExtendedField<T>, N : NDStructure<out T>>(private val ndField: NDField<T, F, N>) :
ExtendedNDField<T, F, N>, NDField<T, F, N> by ndField {
override val shape: IntArray get() = ndField.shape override val shape: IntArray get() = ndField.shape
override val field: F get() = ndField.field override val elementField: F get() = ndField.elementField
override fun produce(initializer: F.(IntArray) -> T) = ndField.produce(initializer) override fun produce(initializer: F.(IntArray) -> T) = ndField.produce(initializer)
override fun power(arg: N, pow: Double): N { override fun power(arg: N, pow: Double): N {
return produce { with(field) { power(arg[it], pow) } } return produce { with(elementField) { power(arg[it], pow) } }
} }
override fun exp(arg: N): N { override fun exp(arg: N): N {
return produce { with(field) { exp(arg[it]) } } return produce { with(elementField) { exp(arg[it]) } }
} }
override fun ln(arg: N): N { override fun ln(arg: N): N {
return produce { with(field) { ln(arg[it]) } } return produce { with(elementField) { ln(arg[it]) } }
} }
override fun sin(arg: N): N { override fun sin(arg: N): N {
return produce { with(field) { sin(arg[it]) } } return produce { with(elementField) { sin(arg[it]) } }
} }
override fun cos(arg: N): N { override fun cos(arg: N): N {
return produce { with(field) { cos(arg[it]) } } return produce { with(elementField) { cos(arg[it]) } }
} }
} }

View File

@ -18,30 +18,38 @@ object NDElements {
* Create a optimized NDArray of doubles * Create a optimized NDArray of doubles
*/ */
fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }) = fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }) =
NDField.real(shape).produce(initializer) NDField.real(shape).produce(initializer)
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }) = fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }) =
real(intArrayOf(dim)) { initializer(it[0]) } real(intArrayOf(dim)) { initializer(it[0]) }
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }) = fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }) =
real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }) = fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }) =
real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
/** /**
* Simple boxing NDArray * Simple boxing NDArray
*/ */
fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): GenericNDElement<T, F> { fun <T : Any, F : Field<T>> generic(
shape: IntArray,
field: F,
initializer: F.(IntArray) -> T
): GenericNDElement<T, F> {
val ndField = GenericNDField(shape, field) val ndField = GenericNDField(shape, field)
val structure = ndStructure(shape) { index -> field.initializer(index) } val structure = ndStructure(shape) { index -> field.initializer(index) }
return GenericNDElement(ndField, structure) return GenericNDElement(ndField, structure)
} }
inline fun <reified T : Any, F : Field<T>> inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): GenericNDElement<T, F> { inline fun <reified T : Any, F : Field<T>> inline(
shape: IntArray,
field: F,
noinline initializer: F.(IntArray) -> T
): GenericNDElement<T, F> {
val ndField = GenericNDField(shape, field) val ndField = GenericNDField(shape, field)
val structure = ndStructure(shape, ::inlineBuffer) { index -> field.initializer(index) } val structure = ndStructure(shape, ::inlineBuffer) { index -> field.initializer(index) }
return GenericNDElement(ndField, structure) return GenericNDElement(ndField, structure)
@ -52,31 +60,36 @@ object NDElements {
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array. Just like in numpy
*/ */
operator fun <T, F : Field<T>> Function1<T, T>.invoke(ndElement: NDElement<T, F>) = ndElement.map { value -> this@invoke(value) } operator fun <T, F : Field<T>> Function1<T, T>.invoke(ndElement: NDElement<T, F>) =
ndElement.map { value -> this@invoke(value) }
/* plus and minus */ /* plus and minus */
/** /**
* Summation operation for [NDElements] and single element * Summation operation for [NDElements] and single element
*/ */
operator fun <T, F : Field<T>> NDElement<T, F>.plus(arg: T): NDElement<T, F> = this.map { value -> elementField.run { arg + value } } operator fun <T, F : Field<T>> NDElement<T, F>.plus(arg: T): NDElement<T, F> =
this.map { value -> elementField.run { arg + value } }
/** /**
* Subtraction operation between [NDElements] and single element * Subtraction operation between [NDElements] and single element
*/ */
operator fun <T, F : Field<T>> NDElement<T, F>.minus(arg: T): NDElement<T, F> = this.map { value -> elementField.run { arg - value } } operator fun <T, F : Field<T>> NDElement<T, F>.minus(arg: T): NDElement<T, F> =
this.map { value -> elementField.run { arg - value } }
/* prod and div */ /* prod and div */
/** /**
* Product operation for [NDElements] and single element * Product operation for [NDElements] and single element
*/ */
operator fun <T, F : Field<T>> NDElement<T, F>.times(arg: T): NDElement<T, F> = this.map { value -> elementField.run { arg * value } } operator fun <T, F : Field<T>> NDElement<T, F>.times(arg: T): NDElement<T, F> =
this.map { value -> elementField.run { arg * value } }
/** /**
* Division operation between [NDElements] and single element * Division operation between [NDElements] and single element
*/ */
operator fun <T, F : Field<T>> NDElement<T, F>.div(arg: T): NDElement<T, F> = this.map { value -> elementField.run { arg / value } } operator fun <T, F : Field<T>> NDElement<T, F>.div(arg: T): NDElement<T, F> =
this.map { value -> elementField.run { arg / value } }
// /** // /**
@ -111,16 +124,19 @@ operator fun <T, F : Field<T>> NDElement<T, F>.div(arg: T): NDElement<T, F> = th
/** /**
* Read-only [NDStructure] coupled to the context. * Read-only [NDStructure] coupled to the context.
*/ */
class GenericNDElement<T, F : Field<T>>(override val context: NDField<T, F, NDStructure<T>>, private val structure: NDStructure<T>) : class GenericNDElement<T, F : Field<T>>(
NDStructure<T> by structure, override val context: NDField<T, F, NDStructure<T>>,
NDElement<T, F>, private val structure: NDStructure<T>
FieldElement<NDStructure<T>, GenericNDElement<T, F>, NDField<T, F, NDStructure<T>>> { ) :
override val elementField: F get() = context.field NDStructure<T> by structure,
NDElement<T, F>,
FieldElement<NDStructure<T>, GenericNDElement<T, F>, NDField<T, F, NDStructure<T>>> {
override val elementField: F get() = context.elementField
override fun unwrap(): NDStructure<T> = structure override fun unwrap(): NDStructure<T> = structure
override fun NDStructure<T>.wrap() = GenericNDElement(context, this) override fun NDStructure<T>.wrap() = GenericNDElement(context, this)
override fun mapIndexed(transform: F.(index: IntArray, T) -> T) = override fun mapIndexed(transform: F.(index: IntArray, T) -> T) =
ndStructure(context.shape) { index: IntArray -> context.field.transform(index, get(index)) }.wrap() ndStructure(context.shape) { index: IntArray -> context.elementField.transform(index, get(index)) }.wrap()
} }

View File

@ -1,5 +1,6 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import scientifik.kmath.operations.AbstractField
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
/** /**
@ -10,29 +11,53 @@ class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : Run
/** /**
* Field for n-dimensional arrays. * Field for n-dimensional arrays.
* @param shape - the list of dimensions of the array * @param shape - the list of dimensions of the array
* @param field - operations field defined on individual array element * @param elementField - operations field defined on individual array element
* @param T - the type of the element contained in ND structure * @param T - the type of the element contained in ND structure
* @param F - field over structure elements * @param F - field over structure elements
* @param R - actual nd-element type of this field * @param R - actual nd-element type of this field
*/ */
interface NDField<T, F : Field<T>, N : NDStructure<out T>> : Field<N> { interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N> {
val shape: IntArray val shape: IntArray
val field: F val elementField: F
/**
* Check the shape of given NDArray and throw exception if it does not coincide with shape of the field
*/
fun checkShape(vararg elements: N) {
elements.forEach {
if (!shape.contentEquals(it.shape)) {
throw ShapeMismatchException(shape, it.shape)
}
}
}
fun produce(initializer: F.(IntArray) -> T): N fun produce(initializer: F.(IntArray) -> T): N
fun N.map(transform: F.(T) -> T): N
fun N.mapIndexed(transform: F.(index: IntArray, T) -> T): N
fun combine(a: N, b: N, transform: F.(T, T) -> T): N
/**
* Element by element application of any operation on elements to the whole array. Just like in numpy
*/
operator fun Function1<T, T>.invoke(structure: N): N
/**
* Summation operation for [NDElements] and single element
*/
operator fun N.plus(arg: T): N
/**
* Subtraction operation between [NDElements] and single element
*/
operator fun N.minus(arg: T): N
/**
* Product operation for [NDElements] and single element
*/
operator fun N.times(arg: T): N
/**
* Division operation between [NDElements] and single element
*/
operator fun N.div(arg: T): N
operator fun T.plus(arg: N): N
operator fun T.minus(arg: N): N
operator fun T.times(arg: N): N
operator fun T.div(arg: N): N
companion object { companion object {
/** /**
@ -48,48 +73,85 @@ interface NDField<T, F : Field<T>, N : NDStructure<out T>> : Field<N> {
/** /**
* Create a most suitable implementation for nd-field using reified class * Create a most suitable implementation for nd-field using reified class
*/ */
inline fun <reified T : Any, F : Field<T>> inline(shape: IntArray, field: F) = BufferNDField(shape, field, ::inlineBuffer) inline fun <reified T : Any, F : Field<T>> buffered(shape: IntArray, field: F) =
BufferNDField(shape, field, ::inlineBuffer)
} }
} }
class GenericNDField<T : Any, F : Field<T>>(override val shape: IntArray, override val field: F, val bufferFactory: BufferFactory<T> = ::boxingBuffer) : NDField<T, F, NDStructure<T>> { abstract class AbstractNDField<T, F : Field<T>, N : NDStructure<T>>(
override fun produce(initializer: F.(IntArray) -> T): NDStructure<T> = ndStructure(shape, bufferFactory) { field.initializer(it) } override val shape: IntArray,
override val elementField: F
) : AbstractField<N>(), NDField<T, F, N> {
override val zero: N by lazy { produce { zero } }
override val zero: NDStructure<T> by lazy { produce { zero } } override val one: N by lazy { produce { one } }
override val one: NDStructure<T> by lazy { produce { one } } final override operator fun Function1<T, T>.invoke(structure: N) = structure.map { value -> this@invoke(value) }
final override operator fun N.plus(arg: T) = this.map { value -> elementField.run { arg + value } }
final override operator fun N.minus(arg: T) = this.map { value -> elementField.run { arg - value } }
final override operator fun N.times(arg: T) = this.map { value -> elementField.run { arg * value } }
final override operator fun N.div(arg: T) = this.map { value -> elementField.run { arg / value } }
final override operator fun T.plus(arg: N) = arg + this
final override operator fun T.minus(arg: N) = arg - this
final override operator fun T.times(arg: N) = arg * this
final override operator fun T.div(arg: N) = arg / this
/** /**
* Element-by-element addition * Element-by-element addition
*/ */
override fun add(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> { override fun add(a: N, b: N): N =
checkShape(a, b) combine(a, b) { aValue, bValue -> aValue + bValue }
return produce { field.run { a[it] + b[it] } }
}
/** /**
* Multiply all elements by cinstant * Multiply all elements by cinstant
*/ */
override fun multiply(a: NDStructure<T>, k: Double): NDStructure<T> { override fun multiply(a: N, k: Double): N =
checkShape(a) a.map { it * k }
return produce { field.run { a[it] * k } }
}
/** /**
* Element-by-element multiplication * Element-by-element multiplication
*/ */
override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> { override fun multiply(a: N, b: N): N =
checkShape(a) combine(a, b) { aValue, bValue -> aValue * bValue }
return produce { field.run { a[it] * b[it] } }
}
/** /**
* Element-by-element division * Element-by-element division
*/ */
override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> { override fun divide(a: N, b: N): N =
checkShape(a) combine(a, b) { aValue, bValue -> aValue / bValue }
return produce { field.run { a[it] / b[it] } }
/**
* Check if given objects are compatible with this context. Throw exception if they are not
*/
open fun check(vararg elements: N) {
elements.forEach {
if (!shape.contentEquals(it.shape)) {
throw ShapeMismatchException(shape, it.shape)
}
}
} }
}
class GenericNDField<T : Any, F : Field<T>>(
shape: IntArray,
elementField: F,
val bufferFactory: BufferFactory<T> = ::boxingBuffer
) :
AbstractNDField<T, F, NDStructure<T>>(shape, elementField) {
override fun produce(initializer: F.(IntArray) -> T): NDStructure<T> =
ndStructure(shape, bufferFactory) { elementField.initializer(it) }
override fun NDStructure<T>.map(transform: F.(T) -> T): NDStructure<T> =
produce { index -> transform(get(index)) }
override fun NDStructure<T>.mapIndexed(transform: F.(index: IntArray, T) -> T): NDStructure<T> =
produce { index -> transform(index, get(index)) }
override fun combine(a: NDStructure<T>, b: NDStructure<T>, transform: F.(T, T) -> T): NDStructure<T> =
produce { index -> transform(a[index], b[index]) }
} }

View File

@ -127,8 +127,8 @@ interface NDBuffer<T> : NDStructure<T> {
* Boxing generic [NDStructure] * Boxing generic [NDStructure]
*/ */
data class BufferNDStructure<T>( data class BufferNDStructure<T>(
override val strides: Strides, override val strides: Strides,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : NDBuffer<T> { ) : NDBuffer<T> {
init { init {
@ -137,6 +137,12 @@ data class BufferNDStructure<T>(
} }
} }
override fun get(index: IntArray): T = buffer[strides.offset(index)]
override val shape: IntArray get() = strides.shape
override fun elements() = strides.indices().map { it to this[it] }
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
return when { return when {
this === other -> true this === other -> true
@ -156,7 +162,10 @@ data class BufferNDStructure<T>(
/** /**
* Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure] * Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure]
*/ */
inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(factory: BufferFactory<R> = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure<R> { inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
factory: BufferFactory<R> = ::inlineBuffer,
crossinline transform: (T) -> R
): BufferNDStructure<R> {
return if (this is BufferNDStructure<T>) { return if (this is BufferNDStructure<T>) {
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) }) BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
} else { } else {
@ -171,26 +180,26 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(factory: BufferFactor
* Strides should be reused if possible * Strides should be reused if possible
*/ */
fun <T> ndStructure(strides: Strides, bufferFactory: BufferFactory<T> = ::boxingBuffer, initializer: (IntArray) -> T) = fun <T> ndStructure(strides: Strides, bufferFactory: BufferFactory<T> = ::boxingBuffer, initializer: (IntArray) -> T) =
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
/** /**
* Inline create NDStructure with non-boxing buffer implementation if it is possible * Inline create NDStructure with non-boxing buffer implementation if it is possible
*/ */
inline fun <reified T : Any> inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
BufferNDStructure(strides, inlineBuffer(strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, inlineBuffer(strides.linearSize) { i -> initializer(strides.index(i)) })
fun <T> ndStructure(shape: IntArray, bufferFactory: BufferFactory<T> = ::boxingBuffer, initializer: (IntArray) -> T) = fun <T> ndStructure(shape: IntArray, bufferFactory: BufferFactory<T> = ::boxingBuffer, initializer: (IntArray) -> T) =
ndStructure(DefaultStrides(shape), bufferFactory, initializer) ndStructure(DefaultStrides(shape), bufferFactory, initializer)
inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
inlineNDStructure(DefaultStrides(shape), initializer) inlineNDStructure(DefaultStrides(shape), initializer)
/** /**
* Mutable ND buffer based on linear [inlineBuffer] * Mutable ND buffer based on linear [inlineBuffer]
*/ */
class MutableBufferNDStructure<T>( class MutableBufferNDStructure<T>(
override val strides: Strides, override val strides: Strides,
override val buffer: MutableBuffer<T> override val buffer: MutableBuffer<T>
) : NDBuffer<T>, MutableNDStructure<T> { ) : NDBuffer<T>, MutableNDStructure<T> {
init { init {
@ -205,19 +214,30 @@ class MutableBufferNDStructure<T>(
/** /**
* The same as [inlineNDStructure], but mutable * The same as [inlineNDStructure], but mutable
*/ */
fun <T : Any> mutableNdStructure(strides: Strides, bufferFactory: MutableBufferFactory<T> = ::boxingMutableBuffer, initializer: (IntArray) -> T) = fun <T : Any> mutableNdStructure(
MutableBufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) strides: Strides,
bufferFactory: MutableBufferFactory<T> = ::boxingMutableBuffer,
initializer: (IntArray) -> T
) =
MutableBufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
inline fun <reified T : Any> inlineMutableNdStructure(strides: Strides, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> inlineMutableNdStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
MutableBufferNDStructure(strides, inlineMutableBuffer(strides.linearSize) { i -> initializer(strides.index(i)) }) MutableBufferNDStructure(strides, inlineMutableBuffer(strides.linearSize) { i -> initializer(strides.index(i)) })
fun <T : Any> mutableNdStructure(shape: IntArray, bufferFactory: MutableBufferFactory<T> = ::boxingMutableBuffer, initializer: (IntArray) -> T) = fun <T : Any> mutableNdStructure(
mutableNdStructure(DefaultStrides(shape), bufferFactory, initializer) shape: IntArray,
bufferFactory: MutableBufferFactory<T> = ::boxingMutableBuffer,
initializer: (IntArray) -> T
) =
mutableNdStructure(DefaultStrides(shape), bufferFactory, initializer)
inline fun <reified T : Any> inlineMutableNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> inlineMutableNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
inlineMutableNdStructure(DefaultStrides(shape), initializer) inlineMutableNdStructure(DefaultStrides(shape), initializer)
inline fun <reified T : Any> NDStructure<T>.combine(struct: NDStructure<T>, crossinline block: (T, T) -> T): NDStructure<T> { inline fun <reified T : Any> NDStructure<T>.combine(
struct: NDStructure<T>,
crossinline block: (T, T) -> T
): NDStructure<T> {
if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination") if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination")
return inlineNdStructure(shape) { block(this[it], struct[it]) } return inlineNdStructure(shape) { block(this[it], struct[it]) }
} }

View File

@ -5,47 +5,78 @@ import scientifik.kmath.operations.DoubleField
typealias RealNDElement = BufferNDElement<Double, DoubleField> typealias RealNDElement = BufferNDElement<Double, DoubleField>
class RealNDField(shape: IntArray) : class RealNDField(shape: IntArray) :
BufferNDField<Double, DoubleField>(shape, DoubleField, DoubleBufferFactory), StridedNDField<Double, DoubleField>(shape, DoubleField),
ExtendedNDField<Double, DoubleField, NDBuffer<Double>> { ExtendedNDField<Double, DoubleField, NDBuffer<Double>> {
override val bufferFactory: BufferFactory<Double>
get() = DoubleBufferFactory
/** /**
* Inline map an NDStructure to * Inline map an NDStructure to
*/ */
private inline fun NDBuffer<Double>.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement { @Suppress("OVERRIDE_BY_INLINE")
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.operation(buffer[offset]) } override inline fun NDBuffer<Double>.map(crossinline transform: DoubleField.(Double) -> Double): RealNDElement {
check(this)
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.transform(buffer[offset]) }
return BufferNDElement(this@RealNDField, DoubleBuffer(array)) return BufferNDElement(this@RealNDField, DoubleBuffer(array))
} }
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
override inline fun produce(initializer: DoubleField.(IntArray) -> Double): RealNDElement { override inline fun produce(crossinline initializer: DoubleField.(IntArray) -> Double): RealNDElement {
val array = DoubleArray(strides.linearSize) { offset -> field.initializer(strides.index(offset)) } val array = DoubleArray(strides.linearSize) { offset -> elementField.initializer(strides.index(offset)) }
return BufferNDElement(this, DoubleBuffer(array)) return BufferNDElement(this, DoubleBuffer(array))
} }
override fun power(arg: NDBuffer<Double>, pow: Double) = arg.mapInline { power(it, pow) } @Suppress("OVERRIDE_BY_INLINE")
override inline fun NDBuffer<Double>.mapIndexed(crossinline transform: DoubleField.(index: IntArray, Double) -> Double): BufferNDElement<Double, DoubleField> {
check(this)
return BufferNDElement(
this@RealNDField,
bufferFactory(strides.linearSize) { offset ->
elementField.transform(
strides.index(offset),
buffer[offset]
)
})
}
override fun exp(arg: NDBuffer<Double>) = arg.mapInline { exp(it) } @Suppress("OVERRIDE_BY_INLINE")
override inline fun combine(
a: NDBuffer<Double>,
b: NDBuffer<Double>,
crossinline transform: DoubleField.(Double, Double) -> Double
): BufferNDElement<Double, DoubleField> {
check(a, b)
return BufferNDElement(
this,
bufferFactory(strides.linearSize) { offset -> elementField.transform(a.buffer[offset], b.buffer[offset]) })
}
override fun ln(arg: NDBuffer<Double>) = arg.mapInline { ln(it) } override fun power(arg: NDBuffer<Double>, pow: Double) = arg.map { power(it, pow) }
override fun sin(arg: NDBuffer<Double>) = arg.mapInline { sin(it) } override fun exp(arg: NDBuffer<Double>) = arg.map { exp(it) }
override fun cos(arg: NDBuffer<Double>) = arg.mapInline { cos(it) } override fun ln(arg: NDBuffer<Double>) = arg.map { ln(it) }
override fun NDBuffer<Double>.times(k: Number) = mapInline { value -> value * k.toDouble() } override fun sin(arg: NDBuffer<Double>) = arg.map { sin(it) }
override fun NDBuffer<Double>.div(k: Number) = mapInline { value -> value / k.toDouble() } override fun cos(arg: NDBuffer<Double>) = arg.map { cos(it) }
//
override fun Number.times(b: NDBuffer<Double>) = b * this // override fun NDBuffer<Double>.times(k: Number) = mapInline { value -> value * k.toDouble() }
//
override fun Number.div(b: NDBuffer<Double>) = b * (1.0 / this.toDouble()) // override fun NDBuffer<Double>.div(k: Number) = mapInline { value -> value / k.toDouble() }
//
// override fun Number.times(b: NDBuffer<Double>) = b * this
//
// override fun Number.div(b: NDBuffer<Double>) = b * (1.0 / this.toDouble())
} }
/** /**
* Fast element production using function inlining * Fast element production using function inlining
*/ */
inline fun BufferNDField<Double, DoubleField>.produceInline(crossinline initializer: DoubleField.(Int) -> Double): RealNDElement { inline fun StridedNDField<Double, DoubleField>.produceInline(crossinline initializer: DoubleField.(Int) -> Double): RealNDElement {
val array = DoubleArray(strides.linearSize) { offset -> field.initializer(offset) } val array = DoubleArray(strides.linearSize) { offset -> elementField.initializer(offset) }
return BufferNDElement(this, DoubleBuffer(array)) return BufferNDElement(this, DoubleBuffer(array))
} }
@ -53,7 +84,8 @@ inline fun BufferNDField<Double, DoubleField>.produceInline(crossinline initiali
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array. Just like in numpy
*/ */
operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) = operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) =
ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) } ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) }
/* plus and minus */ /* plus and minus */
@ -61,10 +93,10 @@ operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) =
* Summation operation for [BufferNDElement] and single element * Summation operation for [BufferNDElement] and single element
*/ */
operator fun RealNDElement.plus(arg: Double) = operator fun RealNDElement.plus(arg: Double) =
context.produceInline { i -> buffer[i] + arg } context.produceInline { i -> buffer[i] + arg }
/** /**
* Subtraction operation between [BufferNDElement] and single element * Subtraction operation between [BufferNDElement] and single element
*/ */
operator fun RealNDElement.minus(arg: Double) = operator fun RealNDElement.minus(arg: Double) =
context.produceInline { i -> buffer[i] - arg } context.produceInline { i -> buffer[i] - arg }

View File

@ -6,10 +6,10 @@ import scientifik.kmath.operations.DoubleField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class FieldExpressionContextTest { class ExpressionFieldTest {
@Test @Test
fun testExpression() { fun testExpression() {
val context = FieldExpressionContext(DoubleField) val context = ExpressionField(DoubleField)
val expression = with(context) { val expression = with(context) {
val x = variable("x", 2.0) val x = variable("x", 2.0)
x * x + 2 * x + 1.0 x * x + 2 * x + 1.0
@ -20,10 +20,10 @@ class FieldExpressionContextTest {
@Test @Test
fun testComplex() { fun testComplex() {
val context = FieldExpressionContext(ComplexField) val context = ExpressionField(ComplexField)
val expression = with(context) { val expression = with(context) {
val x = variable("x", Complex(2.0, 0.0)) val x = variable("x", Complex(2.0, 0.0))
x * x + 2 * x + 1.0 x * x + 2 * x + one
} }
assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0)) assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0))
assertEquals(expression(), Complex(9.0, 0.0)) assertEquals(expression(), Complex(9.0, 0.0))
@ -31,23 +31,23 @@ class FieldExpressionContextTest {
@Test @Test
fun separateContext() { fun separateContext() {
fun <T> FieldExpressionContext<T>.expression(): Expression<T>{ fun <T> ExpressionField<T>.expression(): Expression<T> {
val x = variable("x") val x = variable("x")
return x * x + 2 * x + 1.0 return x * x + 2 * x + one
} }
val expression = FieldExpressionContext(DoubleField).expression() val expression = ExpressionField(DoubleField).expression()
assertEquals(expression("x" to 1.0), 4.0) assertEquals(expression("x" to 1.0), 4.0)
} }
@Test @Test
fun valueExpression() { fun valueExpression() {
val expressionBuilder: FieldExpressionContext<Double>.()->Expression<Double> = { val expressionBuilder: ExpressionField<Double>.() -> Expression<Double> = {
val x = variable("x") val x = variable("x")
x * x + 2 * x + 1.0 x * x + 2 * x + 1.0
} }
val expression = FieldExpressionContext(DoubleField).expressionBuilder() val expression = ExpressionField(DoubleField).expressionBuilder()
assertEquals(expression("x" to 1.0), 4.0) assertEquals(expression("x" to 1.0), 4.0)
} }
} }

View File

@ -11,8 +11,8 @@ class MultivariateHistogramTest {
@Test @Test
fun testSinglePutHistogram() { fun testSinglePutHistogram() {
val histogram = FastHistogram.fromRanges( val histogram = FastHistogram.fromRanges(
(-1.0..1.0), (-1.0..1.0),
(-1.0..1.0) (-1.0..1.0)
) )
histogram.put(0.55, 0.55) histogram.put(0.55, 0.55)
val bin = histogram.find { it.value.toInt() > 0 }!! val bin = histogram.find { it.value.toInt() > 0 }!!
@ -22,21 +22,21 @@ class MultivariateHistogramTest {
} }
@Test @Test
fun testSequentialPut(){ fun testSequentialPut() {
val histogram = FastHistogram.fromRanges( val histogram = FastHistogram.fromRanges(
(-1.0..1.0), (-1.0..1.0),
(-1.0..1.0), (-1.0..1.0),
(-1.0..1.0) (-1.0..1.0)
) )
val random = Random(1234) val random = Random(1234)
fun nextDouble() = random.nextDouble(-1.0,1.0) fun nextDouble() = random.nextDouble(-1.0, 1.0)
val n = 10000 val n = 10000
histogram.fill { histogram.fill {
repeat(n){ repeat(n) {
yield(Vector.ofReal(nextDouble(),nextDouble(),nextDouble())) yield(Vector.ofReal(nextDouble(), nextDouble(), nextDouble()))
} }
} }
assertEquals(n, histogram.sumBy { it.value.toInt() }) assertEquals(n, histogram.sumBy { it.value.toInt() })

View File

@ -21,8 +21,8 @@ class MatrixTest {
} }
@Test @Test
fun testTranspose(){ fun testTranspose() {
val matrix = MatrixSpace.real(3,3).one val matrix = MatrixSpace.real(3, 3).one
val transposed = matrix.transpose() val transposed = matrix.transpose()
assertEquals(matrix.context, transposed.context) assertEquals(matrix.context, transposed.context)
assertEquals((matrix as StructureMatrix).structure, (transposed as StructureMatrix).structure) assertEquals((matrix as StructureMatrix).structure, (transposed as StructureMatrix).structure)

View File

@ -6,9 +6,9 @@ import kotlin.test.assertEquals
class RealLUSolverTest { class RealLUSolverTest {
@Test @Test
fun testInvertOne() { fun testInvertOne() {
val matrix = MatrixSpace.real(2,2).one val matrix = MatrixSpace.real(2, 2).one
val inverted = RealLUSolver.inverse(matrix) val inverted = RealLUSolver.inverse(matrix)
assertEquals(matrix,inverted) assertEquals(matrix, inverted)
} }
// @Test // @Test

View File

@ -6,10 +6,9 @@ import kotlin.test.assertEquals
class RealFieldTest { class RealFieldTest {
@Test @Test
fun testSqrt() { fun testSqrt() {
//fails because KT-27586 val sqrt = with(DoubleField) {
val sqrt = with(RealField) { sqrt(25 * one)
sqrt( 25 * one)
} }
assertEquals(5.0, sqrt.value) assertEquals(5.0, sqrt)
} }
} }

View File

@ -1,6 +1,7 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import scientifik.kmath.operations.Norm import scientifik.kmath.operations.Norm
import scientifik.kmath.structures.NDElements.real2D
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.pow import kotlin.math.pow
import kotlin.test.Test import kotlin.test.Test

View File

@ -1,16 +1,33 @@
package scientifik.kmath.histogram package scientifik.kmath.histogram
actual class LongCounter{ actual class LongCounter {
private var sum: Long = 0 private var sum: Long = 0
actual fun decrement() {sum--} actual fun decrement() {
actual fun increment() {sum++} sum--
actual fun reset() {sum = 0} }
actual fun increment() {
sum++
}
actual fun reset() {
sum = 0
}
actual fun sum(): Long = sum actual fun sum(): Long = sum
actual fun add(l: Long) {sum+=l} actual fun add(l: Long) {
sum += l
}
} }
actual class DoubleCounter{
actual class DoubleCounter {
private var sum: Double = 0.0 private var sum: Double = 0.0
actual fun reset() {sum = 0.0} actual fun reset() {
sum = 0.0
}
actual fun sum(): Double = sum actual fun sum(): Double = sum
actual fun add(d: Double) {sum+=d} actual fun add(d: Double) {
sum += d
}
} }

View File

@ -18,7 +18,7 @@ class UnivariateBin(val position: Double, val size: Double, val counter: LongCou
override fun contains(vector: Buffer<out Double>): Boolean = contains(vector[0]) override fun contains(vector: Buffer<out Double>): Boolean = contains(vector[0])
internal operator fun inc() = this.also { counter.increment()} internal operator fun inc() = this.also { counter.increment() }
override val dimension: Int get() = 1 override val dimension: Int get() = 1
} }
@ -26,7 +26,8 @@ class UnivariateBin(val position: Double, val size: Double, val counter: LongCou
/** /**
* Univariate histogram with log(n) bin search speed * Univariate histogram with log(n) bin search speed
*/ */
class UnivariateHistogram private constructor(private val factory: (Double) -> UnivariateBin) : MutableHistogram<Double,UnivariateBin> { class UnivariateHistogram private constructor(private val factory: (Double) -> UnivariateBin) :
MutableHistogram<Double, UnivariateBin> {
private val bins: TreeMap<Double, UnivariateBin> = TreeMap() private val bins: TreeMap<Double, UnivariateBin> = TreeMap()

View File

@ -31,7 +31,7 @@ interface FixedSizeBufferSpec<T : Any> : BufferSpec<T> {
*/ */
fun ByteBuffer.readObject(index: Int): T { fun ByteBuffer.readObject(index: Int): T {
val dup = duplicate() val dup = duplicate()
dup.position(index*unitSize) dup.position(index * unitSize)
return dup.readObject() return dup.readObject()
} }
@ -49,7 +49,7 @@ interface FixedSizeBufferSpec<T : Any> : BufferSpec<T> {
*/ */
fun ByteBuffer.writeObject(index: Int, obj: T) { fun ByteBuffer.writeObject(index: Int, obj: T) {
val dup = duplicate() val dup = duplicate()
dup.position(index*unitSize) dup.position(index * unitSize)
dup.writeObject(obj) dup.writeObject(obj)
} }
} }

View File

@ -2,7 +2,8 @@ package scientifik.kmath.structures
import java.nio.ByteBuffer import java.nio.ByteBuffer
class ObjectBuffer<T : Any>(private val buffer: ByteBuffer, private val spec: FixedSizeBufferSpec<T>) : MutableBuffer<T> { class ObjectBuffer<T : Any>(private val buffer: ByteBuffer, private val spec: FixedSizeBufferSpec<T>) :
MutableBuffer<T> {
override val size: Int override val size: Int
get() = buffer.limit() / spec.unitSize get() = buffer.limit() / spec.unitSize
@ -23,6 +24,6 @@ class ObjectBuffer<T : Any>(private val buffer: ByteBuffer, private val spec: Fi
companion object { companion object {
fun <T : Any> create(spec: FixedSizeBufferSpec<T>, size: Int) = fun <T : Any> create(spec: FixedSizeBufferSpec<T>, size: Int) =
ObjectBuffer<T>(ByteBuffer.allocate(size * spec.unitSize), spec) ObjectBuffer<T>(ByteBuffer.allocate(size * spec.unitSize), spec)
} }
} }

View File

@ -6,6 +6,9 @@ import kotlinx.coroutines.Dispatchers
import kotlin.coroutines.CoroutineContext import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext import kotlin.coroutines.EmptyCoroutineContext
expect fun <R> runBlocking(context: CoroutineContext = EmptyCoroutineContext, function: suspend CoroutineScope.()->R): R expect fun <R> runBlocking(
context: CoroutineContext = EmptyCoroutineContext,
function: suspend CoroutineScope.() -> R
): R
val Dispatchers.Math: CoroutineDispatcher get() = Dispatchers.Default val Dispatchers.Math: CoroutineDispatcher get() = Dispatchers.Default

View File

@ -3,7 +3,8 @@ package scientifik.kmath.structures
import kotlinx.coroutines.* import kotlinx.coroutines.*
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: CoroutineScope = GlobalScope) : BufferNDField<T, F>(shape,field, ::boxingBuffer) { class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: CoroutineScope = GlobalScope) :
BufferNDField<T, F>(shape, field, ::boxingBuffer) {
override fun add(a: NDStructure<T>, b: NDStructure<T>): NDElements<T, F> { override fun add(a: NDStructure<T>, b: NDStructure<T>): NDElements<T, F> {
return LazyNDStructure(this) { index -> return LazyNDStructure(this) { index ->
@ -34,13 +35,23 @@ class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: Corouti
} }
} }
class LazyNDStructure<T, F : Field<T>>(override val context: LazyNDField<T, F>, val function: suspend F.(IntArray) -> T) : NDElements<T, F>, NDStructure<T> { class LazyNDStructure<T, F : Field<T>>(
override val context: LazyNDField<T, F>,
val function: suspend F.(IntArray) -> T
) : NDElements<T, F>, NDStructure<T> {
override val self: NDElements<T, F> get() = this override val self: NDElements<T, F> get() = this
override val shape: IntArray get() = context.shape override val shape: IntArray get() = context.shape
private val cache = HashMap<IntArray, Deferred<T>>() private val cache = HashMap<IntArray, Deferred<T>>()
fun deferred(index: IntArray) = cache.getOrPut(index) { context.scope.async(context = Dispatchers.Math) { function.invoke(context.field, index) } } fun deferred(index: IntArray) = cache.getOrPut(index) {
context.scope.async(context = Dispatchers.Math) {
function.invoke(
context.elementField,
index
)
}
}
suspend fun await(index: IntArray): T = deferred(index).await() suspend fun await(index: IntArray): T = deferred(index).await()
@ -54,9 +65,11 @@ class LazyNDStructure<T, F : Field<T>>(override val context: LazyNDField<T, F>,
} }
} }
fun <T> NDStructure<T>.deferred(index: IntArray) = if (this is LazyNDStructure<T, *>) this.deferred(index) else CompletableDeferred(get(index)) fun <T> NDStructure<T>.deferred(index: IntArray) =
if (this is LazyNDStructure<T, *>) this.deferred(index) else CompletableDeferred(get(index))
suspend fun <T> NDStructure<T>.await(index: IntArray) = if (this is LazyNDStructure<T, *>) this.await(index) else get(index) suspend fun <T> NDStructure<T>.await(index: IntArray) =
if (this is LazyNDStructure<T, *>) this.await(index) else get(index)
fun <T, F : Field<T>> NDElements<T, F>.lazy(scope: CoroutineScope = GlobalScope): LazyNDStructure<T, F> { fun <T, F : Field<T>> NDElements<T, F>.lazy(scope: CoroutineScope = GlobalScope): LazyNDStructure<T, F> {
return if (this is LazyNDStructure<T, F>) { return if (this is LazyNDStructure<T, F>) {
@ -67,10 +80,12 @@ fun <T, F : Field<T>> NDElements<T, F>.lazy(scope: CoroutineScope = GlobalScope)
} }
} }
inline fun <T, F : Field<T>> LazyNDStructure<T, F>.mapIndexed(crossinline action: suspend F.(IntArray, T) -> T) = LazyNDStructure(context) { index -> inline fun <T, F : Field<T>> LazyNDStructure<T, F>.mapIndexed(crossinline action: suspend F.(IntArray, T) -> T) =
action.invoke(this, index, await(index)) LazyNDStructure(context) { index ->
} action.invoke(this, index, await(index))
}
inline fun <T, F : Field<T>> LazyNDStructure<T, F>.map(crossinline action: suspend F.(T) -> T) = LazyNDStructure(context) { index -> inline fun <T, F : Field<T>> LazyNDStructure<T, F>.map(crossinline action: suspend F.(T) -> T) =
action.invoke(this, await(index)) LazyNDStructure(context) { index ->
} action.invoke(this, await(index))
}

View File

@ -14,7 +14,7 @@ class LazyNDFieldTest {
counter++ counter++
it * it it * it
} }
assertEquals(4, result[0,0,0]) assertEquals(4, result[0, 0, 0])
assertEquals(1, counter) assertEquals(1, counter)
} }
} }

View File

@ -3,4 +3,5 @@ package scientifik.kmath.structures
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlin.coroutines.CoroutineContext import kotlin.coroutines.CoroutineContext
actual fun <R> runBlocking(context: CoroutineContext, function: suspend CoroutineScope.() -> R): R = kotlinx.coroutines.runBlocking(context, function) actual fun <R> runBlocking(context: CoroutineContext, function: suspend CoroutineScope.() -> R): R =
kotlinx.coroutines.runBlocking(context, function)

View File

@ -9,8 +9,8 @@ pluginManagement {
rootProject.name = "kmath" rootProject.name = "kmath"
include( include(
":kmath-core", ":kmath-core",
":kmath-io", ":kmath-io",
":kmath-coroutines", ":kmath-coroutines",
":benchmarks" ":benchmarks"
) )