Merge pull request #11 from altavir/dev
Dev
This commit is contained in:
commit
258d689430
33
README.md
33
README.md
@ -1,2 +1,31 @@
|
|||||||
# kmath
|
# KMath
|
||||||
Kotlin mathematics extensions library
|
Kotlin MATHematics library is intended as a kotlin based analog of numpy python library. Contrary to `numpy`
|
||||||
|
and `scipy` it is modular and has a lightweight core.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
* **Algebra**
|
||||||
|
* Mathematical operation entities like rings, spaces and fields with (**TODO** add example to wiki)
|
||||||
|
* Basic linear algebra operations (summs products, etc) backed by `Space` API.
|
||||||
|
* [In progress] advanced linear algebra operations like matrix inversions.
|
||||||
|
* **Array-like structures** Full support of numpy-like ndarray including mixed ariphmetic operations and function operations
|
||||||
|
on arrays and numbers just like it works in python (with benefit of static type checking).
|
||||||
|
|
||||||
|
## Multi-platform support
|
||||||
|
KMath is developed as a multi-platform library, which means that most of interfaces are declared in common module.
|
||||||
|
Implementation is also done in common module wherever it is possible. In some cases features are delegated to
|
||||||
|
platform even if they could be done in common module because of platform performance optimization.
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
The calculation performance is one of major goals of KMath in the future, but in some cases it is not possible to achieve
|
||||||
|
both performance and flexibility. We expect to firstly focus on creating convenient universal API and then work on
|
||||||
|
increasing performance for specific cases. We expect the worst KMath performance still be better than natural python,
|
||||||
|
but worse than optimized native/scipy (mostly due to boxing operations on primitive numbers). The best performance
|
||||||
|
of optimized parts should be better than scipy.
|
||||||
|
|
||||||
|
## Releases
|
||||||
|
The project is currently in pre-release stage. Work builds could be obtained with
|
||||||
|
[![](https://jitpack.io/v/altavir/kmath.svg)](https://jitpack.io/#altavir/kmath).
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
The project requires a lot of additional work. Please fill free to contribute in any way and propose new features.
|
@ -1,5 +1,5 @@
|
|||||||
buildscript {
|
buildscript {
|
||||||
ext.kotlin_version = '1.2.41'
|
ext.kotlin_version = '1.2.60'
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
|
@ -1,5 +1,16 @@
|
|||||||
package scientifik.kmath.operations
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The generic mathematics elements which is able to store its context
|
||||||
|
*/
|
||||||
|
interface MathElement<T, S>{
|
||||||
|
/**
|
||||||
|
* The context this element belongs to
|
||||||
|
*/
|
||||||
|
val context: S
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A general interface representing linear context of some kind.
|
* A general interface representing linear context of some kind.
|
||||||
* The context defines sum operation for its elements and multiplication by real value.
|
* The context defines sum operation for its elements and multiplication by real value.
|
||||||
@ -37,23 +48,20 @@ interface Space<T> {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* The element of linear context
|
* The element of linear context
|
||||||
* @param S self type of the element. Needed for static type checking
|
* @param T self type of the element. Needed for static type checking
|
||||||
|
* @param S the type of space
|
||||||
*/
|
*/
|
||||||
interface SpaceElement<S : SpaceElement<S>> {
|
interface SpaceElement<T, S : Space<T>>: MathElement<T,S> {
|
||||||
/**
|
|
||||||
* The context this element belongs to
|
|
||||||
*/
|
|
||||||
val context: Space<S>
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Self value. Needed for static type checking. Needed to avoid type erasure on JVM.
|
* Self value. Needed for static type checking. Needed to avoid type erasure on JVM.
|
||||||
*/
|
*/
|
||||||
val self: S
|
val self: T
|
||||||
|
|
||||||
operator fun plus(b: S): S = context.add(self, b)
|
operator fun plus(b: T): T = context.add(self, b)
|
||||||
operator fun minus(b: S): S = context.add(self, context.multiply(b, -1.0))
|
operator fun minus(b: T): T = context.add(self, context.multiply(b, -1.0))
|
||||||
operator fun times(k: Number): S = context.multiply(self, k.toDouble())
|
operator fun times(k: Number): T = context.multiply(self, k.toDouble())
|
||||||
operator fun div(k: Number): S = context.multiply(self, 1.0 / k.toDouble())
|
operator fun div(k: Number): T = context.multiply(self, 1.0 / k.toDouble())
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -77,10 +85,10 @@ interface Ring<T> : Space<T> {
|
|||||||
/**
|
/**
|
||||||
* Ring element
|
* Ring element
|
||||||
*/
|
*/
|
||||||
interface RingElement<S : RingElement<S>> : SpaceElement<S> {
|
interface RingElement<T, S : Ring<T>> : SpaceElement<T, S> {
|
||||||
override val context: Ring<S>
|
override val context: S
|
||||||
|
|
||||||
operator fun times(b: S): S = context.multiply(self, b)
|
operator fun times(b: T): T = context.multiply(self, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -96,8 +104,8 @@ interface Field<T> : Ring<T> {
|
|||||||
/**
|
/**
|
||||||
* Field element
|
* Field element
|
||||||
*/
|
*/
|
||||||
interface FieldElement<S : FieldElement<S>> : RingElement<S> {
|
interface FieldElement<T, S : Field<T>> : RingElement<T, S> {
|
||||||
override val context: Field<S>
|
override val context: S
|
||||||
|
|
||||||
operator fun div(b: S): S = context.divide(self, b)
|
operator fun div(b: T): T = context.divide(self, b)
|
||||||
}
|
}
|
@ -1,23 +1,39 @@
|
|||||||
package scientifik.kmath.operations
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import kotlin.math.pow
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Field for real values
|
* Field for real values
|
||||||
*/
|
*/
|
||||||
object RealField : Field<Real> {
|
object RealField : Field<Real>, TrigonometricOperations<Real>, PowerOperations<Real>, ExponentialOperations<Real> {
|
||||||
override val zero: Real = Real(0.0)
|
override val zero: Real = Real(0.0)
|
||||||
override fun add(a: Real, b: Real): Real = Real(a.value + b.value)
|
override fun add(a: Real, b: Real): Real = Real(a.value + b.value)
|
||||||
override val one: Real = Real(1.0)
|
override val one: Real = Real(1.0)
|
||||||
override fun multiply(a: Real, b: Real): Real = Real(a.value * b.value)
|
override fun multiply(a: Real, b: Real): Real = Real(a.value * b.value)
|
||||||
override fun multiply(a: Real, k: Double): Real = Real(a.value * k)
|
override fun multiply(a: Real, k: Double): Real = Real(a.value * k)
|
||||||
override fun divide(a: Real, b: Real): Real = Real(a.value / b.value)
|
override fun divide(a: Real, b: Real): Real = Real(a.value / b.value)
|
||||||
|
|
||||||
|
override fun sin(arg: Real): Real = Real(kotlin.math.sin(arg.value))
|
||||||
|
override fun cos(arg: Real): Real = Real(kotlin.math.cos(arg.value))
|
||||||
|
|
||||||
|
override fun power(arg: Real, pow: Double): Real = Real(arg.value.pow(pow))
|
||||||
|
|
||||||
|
override fun exp(arg: Real): Real = Real(kotlin.math.exp(arg.value))
|
||||||
|
|
||||||
|
override fun ln(arg: Real): Real = Real(kotlin.math.ln(arg.value))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Real field element wrapping double
|
* Real field element wrapping double.
|
||||||
|
*
|
||||||
|
* TODO could be replaced by inline class in kotlin 1.3 if it would allow to avoid boxing
|
||||||
*/
|
*/
|
||||||
class Real(val value: Double) : FieldElement<Real>, Number() {
|
class Real(val value: Double) : Number(), FieldElement<Real, RealField> {
|
||||||
|
/*
|
||||||
|
* The class uses composition instead of inheritance since Double is final
|
||||||
|
*/
|
||||||
|
|
||||||
override fun toByte(): Byte = value.toByte()
|
override fun toByte(): Byte = value.toByte()
|
||||||
override fun toChar(): Char = value.toChar()
|
override fun toChar(): Char = value.toChar()
|
||||||
override fun toDouble(): Double = value
|
override fun toDouble(): Double = value
|
||||||
@ -29,8 +45,10 @@ class Real(val value: Double) : FieldElement<Real>, Number() {
|
|||||||
//values are dynamically calculated to save memory
|
//values are dynamically calculated to save memory
|
||||||
override val self
|
override val self
|
||||||
get() = this
|
get() = this
|
||||||
|
|
||||||
override val context
|
override val context
|
||||||
get() = RealField
|
get() = RealField
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -54,10 +72,9 @@ object ComplexField : Field<Complex> {
|
|||||||
/**
|
/**
|
||||||
* Complex number class
|
* Complex number class
|
||||||
*/
|
*/
|
||||||
data class Complex(val re: Double, val im: Double) : FieldElement<Complex> {
|
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, ComplexField> {
|
||||||
override val self: Complex
|
override val self: Complex get() = this
|
||||||
get() = this
|
override val context: ComplexField
|
||||||
override val context: Field<Complex>
|
|
||||||
get() = ComplexField
|
get() = ComplexField
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -72,15 +89,15 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex> {
|
|||||||
val module: Double
|
val module: Double
|
||||||
get() = sqrt(square)
|
get() = sqrt(square)
|
||||||
|
|
||||||
|
|
||||||
//TODO is it convenient?
|
|
||||||
operator fun not() = conjugate
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A field for double without boxing. Does not produce appropriate field element
|
||||||
|
*/
|
||||||
object DoubleField : Field<Double> {
|
object DoubleField : Field<Double> {
|
||||||
override val zero: Double = 0.0
|
override val zero: Double = 0.0
|
||||||
override fun add(a: Double, b: Double): Double = a + b
|
override fun add(a: Double, b: Double): Double = a + b
|
||||||
override fun multiply(a: Double, b: Double): Double = a * b
|
override fun multiply(a: Double, @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") b: Double): Double = a * b
|
||||||
override val one: Double = 1.0
|
override val one: Double = 1.0
|
||||||
override fun divide(a: Double, b: Double): Double = a / b
|
override fun divide(a: Double, b: Double): Double = a / b
|
||||||
}
|
}
|
@ -0,0 +1,48 @@
|
|||||||
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
|
||||||
|
/* Trigonometric operations */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A container for trigonometric operations for specific type. Trigonometric operations are limited to fields.
|
||||||
|
*
|
||||||
|
* The operations are not exposed to class directly to avoid method bloat but instead are declared in the field.
|
||||||
|
* It also allows to override behavior for optional operations
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
interface TrigonometricOperations<T>: Field<T> {
|
||||||
|
fun sin(arg: T): T
|
||||||
|
fun cos(arg: T): T
|
||||||
|
|
||||||
|
fun tg(arg: T): T = sin(arg) / cos(arg)
|
||||||
|
|
||||||
|
fun ctg(arg: T): T = cos(arg) / sin(arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun <T : FieldElement<T, out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
|
||||||
|
fun <T : FieldElement<T, out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
|
||||||
|
fun <T : FieldElement<T, out TrigonometricOperations<T>>> tg(arg: T): T = arg.context.tg(arg)
|
||||||
|
fun <T : FieldElement<T, out TrigonometricOperations<T>>> ctg(arg: T): T = arg.context.ctg(arg)
|
||||||
|
|
||||||
|
/* Power and roots */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A context extension to include power operations like square roots, etc
|
||||||
|
*/
|
||||||
|
interface PowerOperations<T> {
|
||||||
|
fun power(arg: T, pow: Double): T
|
||||||
|
}
|
||||||
|
|
||||||
|
infix fun <T : MathElement<T, out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
|
||||||
|
fun <T : MathElement<T, out PowerOperations<T>>> sqrt(arg: T): T = arg pow 0.5
|
||||||
|
fun <T : MathElement<T, out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
|
||||||
|
|
||||||
|
/* Exponential */
|
||||||
|
|
||||||
|
interface ExponentialOperations<T>{
|
||||||
|
fun exp(arg: T): T
|
||||||
|
fun ln(arg: T): T
|
||||||
|
}
|
||||||
|
|
||||||
|
fun <T: MathElement<T, out ExponentialOperations<T>>> exp(arg:T): T = arg.context.exp(arg)
|
||||||
|
fun <T: MathElement<T, out ExponentialOperations<T>>> ln(arg:T): T = arg.context.ln(arg)
|
@ -0,0 +1,144 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.DoubleField
|
||||||
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.Space
|
||||||
|
import scientifik.kmath.operations.SpaceElement
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The space for linear elements. Supports scalar product alongside with standard linear operations.
|
||||||
|
* @param T type of individual element of the vector or matrix
|
||||||
|
* @param V the type of vector space element
|
||||||
|
*/
|
||||||
|
abstract class LinearSpace<T : Any, V : LinearStructure<out T>>(val rows: Int, val columns: Int, val field: Field<T>) : Space<V> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Produce the element of this space
|
||||||
|
*/
|
||||||
|
abstract fun produce(initializer: (Int, Int) -> T): V
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Produce new linear space with given dimensions
|
||||||
|
*/
|
||||||
|
abstract fun produceSpace(rows: Int, columns: Int): LinearSpace<T, V>
|
||||||
|
|
||||||
|
override val zero: V by lazy {
|
||||||
|
produce { _, _ -> field.zero }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun add(a: V, b: V): V {
|
||||||
|
return produce { i, j -> with(field) { a[i, j] + b[i, j] } }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: V, k: Double): V {
|
||||||
|
//TODO it is possible to implement scalable linear elements which normed values and adjustable scale to save memory and processing poser
|
||||||
|
return produce { i, j -> with(field) { a[i, j] * k } }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dot product
|
||||||
|
*/
|
||||||
|
fun multiply(a: V, b: V): V {
|
||||||
|
if (a.rows != b.columns) {
|
||||||
|
//TODO replace by specific exception
|
||||||
|
error("Dimension mismatch in vector dot product")
|
||||||
|
}
|
||||||
|
return produceSpace(a.rows, b.columns).produce { i, j ->
|
||||||
|
(0..a.columns).asSequence().map { k -> field.multiply(a[i, k], b[k, j]) }.reduce { first, second -> field.add(first, second) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
infix fun V.dot(b: V): V = multiply(this, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A matrix-like structure that is not dependent on specific space implementation
|
||||||
|
*/
|
||||||
|
interface LinearStructure<T : Any> {
|
||||||
|
val rows: Int
|
||||||
|
val columns: Int
|
||||||
|
|
||||||
|
operator fun get(i: Int, j: Int): T
|
||||||
|
|
||||||
|
fun transpose(): LinearStructure<T> {
|
||||||
|
return object : LinearStructure<T> {
|
||||||
|
override val rows: Int = this@LinearStructure.columns
|
||||||
|
override val columns: Int = this@LinearStructure.rows
|
||||||
|
override fun get(i: Int, j: Int): T = this@LinearStructure.get(j, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interface Vector<T : Any> : LinearStructure<T> {
|
||||||
|
override val columns: Int
|
||||||
|
get() = 1
|
||||||
|
|
||||||
|
operator fun get(i: Int) = get(i, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* DoubleArray-based implementation of vector space
|
||||||
|
*/
|
||||||
|
class ArraySpace<T : Any>(rows: Int, columns: Int, field: Field<T>) : LinearSpace<T, LinearStructure<out T>>(rows, columns, field) {
|
||||||
|
|
||||||
|
override fun produce(initializer: (Int, Int) -> T): LinearStructure<T> = ArrayMatrix<T>(this, initializer)
|
||||||
|
|
||||||
|
|
||||||
|
override fun produceSpace(rows: Int, columns: Int): LinearSpace<T, LinearStructure<out T>> {
|
||||||
|
return ArraySpace(rows, columns, field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Member of [ArraySpace] which wraps 2-D array
|
||||||
|
*/
|
||||||
|
class ArrayMatrix<T : Any>(override val context: ArraySpace<T>, initializer: (Int, Int) -> T) : LinearStructure<T>, SpaceElement<LinearStructure<out T>, ArraySpace<T>> {
|
||||||
|
|
||||||
|
val list: List<List<T>> = (0 until rows).map { i -> (0 until columns).map { j -> initializer(i, j) } }
|
||||||
|
|
||||||
|
override val rows: Int get() = context.rows
|
||||||
|
|
||||||
|
override val columns: Int get() = context.columns
|
||||||
|
|
||||||
|
override fun get(i: Int, j: Int): T {
|
||||||
|
return list[i][j]
|
||||||
|
}
|
||||||
|
|
||||||
|
override val self: ArrayMatrix<T> get() = this
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayVector<T : Any>(override val context: ArraySpace<T>, initializer: (Int) -> T) : Vector<T>, SpaceElement<LinearStructure<out T>, ArraySpace<T>> {
|
||||||
|
|
||||||
|
init {
|
||||||
|
if (context.columns != 1) {
|
||||||
|
error("Vector must have single column")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val list: List<T> = (0 until context.rows).map(initializer)
|
||||||
|
|
||||||
|
|
||||||
|
override val rows: Int get() = context.rows
|
||||||
|
|
||||||
|
override val columns: Int = 1
|
||||||
|
|
||||||
|
override fun get(i: Int, j: Int): T {
|
||||||
|
return list[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
override val self: ArrayVector<T> get() = this
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
fun <T : Any> vector(size: Int, field: Field<T>, initializer: (Int) -> T) = ArrayVector(ArraySpace(size, 1, field), initializer)
|
||||||
|
//TODO replace by primitive array version
|
||||||
|
fun realVector(size: Int, initializer: (Int) -> Double) = vector(size, DoubleField, initializer)
|
||||||
|
|
||||||
|
fun <T : Any> Array<T>.asVector(field: Field<T>) = vector(size, field) { this[it] }
|
||||||
|
//TODO add inferred field from field element
|
||||||
|
fun DoubleArray.asVector() = realVector(this.size) { this[it] }
|
||||||
|
|
||||||
|
fun <T : Any> matrix(rows: Int, columns: Int, field: Field<T>, initializer: (Int, Int) -> T) = ArrayMatrix<T>(ArraySpace(rows, columns, field), initializer)
|
||||||
|
fun realMatrix(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = matrix(rows, columns, DoubleField, initializer)
|
@ -3,22 +3,32 @@ package scientifik.kmath.structures
|
|||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.FieldElement
|
import scientifik.kmath.operations.FieldElement
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An exception is thrown when the expected ans actual shape of NDArray differs
|
||||||
|
*/
|
||||||
class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : RuntimeException()
|
class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : RuntimeException()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Field for n-dimensional arrays.
|
* Field for n-dimensional arrays.
|
||||||
* @param shape - the list of dimensions of the array
|
* @param shape - the list of dimensions of the array
|
||||||
* @param field - operations field defined on individual array element
|
* @param field - operations field defined on individual array element
|
||||||
|
* @param T the type of the element contained in NDArray
|
||||||
*/
|
*/
|
||||||
abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDArray<T>> {
|
abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDArray<T>> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create new instance of NDArray using field shape and given initializer
|
* Create new instance of NDArray using field shape and given initializer
|
||||||
|
* The producer takes list of indices as argument and returns contained value
|
||||||
*/
|
*/
|
||||||
abstract fun produce(initializer: (List<Int>) -> T): NDArray<T>
|
abstract fun produce(initializer: (List<Int>) -> T): NDArray<T>
|
||||||
|
|
||||||
override val zero: NDArray<T>
|
override val zero: NDArray<T> by lazy {
|
||||||
get() = produce { this.field.zero }
|
produce { this.field.zero }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check the shape of given NDArray and throw exception if it does not coincide with shape of the field
|
||||||
|
*/
|
||||||
private fun checkShape(vararg arrays: NDArray<T>) {
|
private fun checkShape(vararg arrays: NDArray<T>) {
|
||||||
arrays.forEach {
|
arrays.forEach {
|
||||||
if (shape != it.shape) {
|
if (shape != it.shape) {
|
||||||
@ -40,7 +50,7 @@ abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDA
|
|||||||
*/
|
*/
|
||||||
override fun multiply(a: NDArray<T>, k: Double): NDArray<T> {
|
override fun multiply(a: NDArray<T>, k: Double): NDArray<T> {
|
||||||
checkShape(a)
|
checkShape(a)
|
||||||
return produce { with(field) {a[it] * k} }
|
return produce { with(field) { a[it] * k } }
|
||||||
}
|
}
|
||||||
|
|
||||||
override val one: NDArray<T>
|
override val one: NDArray<T>
|
||||||
@ -51,7 +61,7 @@ abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDA
|
|||||||
*/
|
*/
|
||||||
override fun multiply(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
override fun multiply(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
||||||
checkShape(a)
|
checkShape(a)
|
||||||
return produce { with(field) {a[it] * b[it]} }
|
return produce { with(field) { a[it] * b[it] } }
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -59,18 +69,18 @@ abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDA
|
|||||||
*/
|
*/
|
||||||
override fun divide(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
override fun divide(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
||||||
checkShape(a)
|
checkShape(a)
|
||||||
return produce { with(field) {a[it] / b[it]} }
|
return produce { with(field) { a[it] / b[it] } }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
interface NDArray<T> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> {
|
interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The list of dimensions of this NDArray
|
* The list of dimensions of this NDArray
|
||||||
*/
|
*/
|
||||||
val shape: List<Int>
|
val shape: List<Int>
|
||||||
get() = (context as NDField<T>).shape
|
get() = context.shape
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The number of dimentsions for this array
|
* The number of dimentsions for this array
|
||||||
@ -87,14 +97,14 @@ interface NDArray<T> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> {
|
|||||||
return get(*index.toIntArray())
|
return get(*index.toIntArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun iterator(): Iterator<Pair<List<Int>, T>> {
|
operator fun iterator(): Iterator<Pair<List<Int>, T>> {
|
||||||
return iterateIndexes(shape).map { Pair(it, this[it]) }.iterator()
|
return iterateIndexes(shape).map { Pair(it, this[it]) }.iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate new NDArray, using given transformation for each element
|
* Generate new NDArray, using given transformation for each element
|
||||||
*/
|
*/
|
||||||
fun transform(action: (List<Int>, T) -> T): NDArray<T> = (context as NDField<T>).produce { action(it, this[it]) }
|
fun transform(action: (List<Int>, T) -> T): NDArray<T> = context.produce { action(it, this[it]) }
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
/**
|
/**
|
||||||
@ -115,6 +125,79 @@ interface NDArray<T> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||||
|
*/
|
||||||
|
operator fun <T> Function1<T, T>.invoke(ndArray: NDArray<T>): NDArray<T> = ndArray.transform { _, value -> this(value) }
|
||||||
|
|
||||||
|
/* plus and minus */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Summation operation for [NDArray] and single element
|
||||||
|
*/
|
||||||
|
operator fun <T> NDArray<T>.plus(arg: T): NDArray<T> = transform { _, value ->
|
||||||
|
with(context.field){
|
||||||
|
arg + value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reverse sum operation
|
||||||
|
*/
|
||||||
|
operator fun <T> T.plus(arg: NDArray<T>): NDArray<T> = arg + this
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Subtraction operation between [NDArray] and single element
|
||||||
|
*/
|
||||||
|
operator fun <T> NDArray<T>.minus(arg: T): NDArray<T> = transform { _, value ->
|
||||||
|
with(context.field){
|
||||||
|
arg - value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reverse minus operation
|
||||||
|
*/
|
||||||
|
operator fun <T> T.minus(arg: NDArray<T>): NDArray<T> = arg.transform { _, value ->
|
||||||
|
with(arg.context.field){
|
||||||
|
this@minus - value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* prod and div */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Product operation for [NDArray] and single element
|
||||||
|
*/
|
||||||
|
operator fun <T> NDArray<T>.times(arg: T): NDArray<T> = transform { _, value ->
|
||||||
|
with(context.field){
|
||||||
|
arg * value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reverse product operation
|
||||||
|
*/
|
||||||
|
operator fun <T> T.times(arg: NDArray<T>): NDArray<T> = arg * this
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Division operation between [NDArray] and single element
|
||||||
|
*/
|
||||||
|
operator fun <T> NDArray<T>.div(arg: T): NDArray<T> = transform { _, value ->
|
||||||
|
with(context.field){
|
||||||
|
arg / value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reverse division operation
|
||||||
|
*/
|
||||||
|
operator fun <T> T.div(arg: NDArray<T>): NDArray<T> = arg.transform { _, value ->
|
||||||
|
with(arg.context.field){
|
||||||
|
this@div/ value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a platform-specific NDArray of doubles
|
* Create a platform-specific NDArray of doubles
|
||||||
*/
|
*/
|
@ -0,0 +1,14 @@
|
|||||||
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class RealFieldTest {
|
||||||
|
@Test
|
||||||
|
fun testSqrt() {
|
||||||
|
val sqrt = with(RealField) {
|
||||||
|
sqrt(25 * one)
|
||||||
|
}
|
||||||
|
assertEquals(5.0, sqrt.toDouble())
|
||||||
|
}
|
||||||
|
}
|
@ -5,7 +5,7 @@ repositories {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
expectedBy project(":common")
|
expectedBy project(":kmath-common")
|
||||||
compile "org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlin_version"
|
compile "org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlin_version"
|
||||||
testCompile "junit:junit:4.12"
|
testCompile "junit:junit:4.12"
|
||||||
testCompile "org.jetbrains.kotlin:kotlin-test-junit:$kotlin_version"
|
testCompile "org.jetbrains.kotlin:kotlin-test-junit:$kotlin_version"
|
@ -35,6 +35,7 @@ private class RealNDField(shape: List<Int>) : NDField<Double>(shape, DoubleField
|
|||||||
override fun produce(initializer: (List<Int>) -> Double): NDArray<Double> {
|
override fun produce(initializer: (List<Int>) -> Double): NDArray<Double> {
|
||||||
//TODO use sparse arrays for large capacities
|
//TODO use sparse arrays for large capacities
|
||||||
val buffer = DoubleBuffer.allocate(capacity)
|
val buffer = DoubleBuffer.allocate(capacity)
|
||||||
|
//FIXME there could be performance degradation due to iteration procedure. Replace by straight iteration
|
||||||
NDArray.iterateIndexes(shape).forEach {
|
NDArray.iterateIndexes(shape).forEach {
|
||||||
buffer.put(offset(it), initializer(it))
|
buffer.put(offset(it), initializer(it))
|
||||||
}
|
}
|
||||||
@ -68,7 +69,7 @@ private class RealNDField(shape: List<Int>) : NDField<Double>(shape, DoubleField
|
|||||||
//TODO generate fixed hash code for quick comparison?
|
//TODO generate fixed hash code for quick comparison?
|
||||||
|
|
||||||
|
|
||||||
override val self: NDArray<Double> = this
|
override val self: NDArray<Double> get() = this
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -0,0 +1,26 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import org.junit.Assert.assertEquals
|
||||||
|
import org.junit.Test
|
||||||
|
|
||||||
|
class ArrayMatrixTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSum() {
|
||||||
|
val vector1 = realVector(5) { it.toDouble() }
|
||||||
|
val vector2 = realVector(5) { 5 - it.toDouble() }
|
||||||
|
val sum = vector1 + vector2
|
||||||
|
assertEquals(5.0, sum[2, 0], 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDot() {
|
||||||
|
val vector1 = realVector(5) { it.toDouble() }
|
||||||
|
val vector2 = realVector(5) { 5 - it.toDouble() }
|
||||||
|
val product = with(vector1.context) {
|
||||||
|
vector1 dot (vector2.transpose())
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(10.0, product[1, 0], 0.1)
|
||||||
|
}
|
||||||
|
}
|
@ -1,6 +1,7 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import org.junit.Assert.assertEquals
|
import org.junit.Assert.assertEquals
|
||||||
|
import kotlin.math.pow
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
|
|
||||||
class RealNDFieldTest {
|
class RealNDFieldTest {
|
||||||
@ -14,8 +15,8 @@ class RealNDFieldTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testProduct(){
|
fun testProduct() {
|
||||||
val product = array1*array2
|
val product = array1 * array2
|
||||||
assertEquals(0.0, product[2, 2], 0.1)
|
assertEquals(0.0, product[2, 2], 0.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,11 +25,18 @@ class RealNDFieldTest {
|
|||||||
|
|
||||||
val array = real2DArray(3, 3) { i, j -> (i * 10 + j).toDouble() }
|
val array = real2DArray(3, 3) { i, j -> (i * 10 + j).toDouble() }
|
||||||
|
|
||||||
for(i in 0..2){
|
for (i in 0..2) {
|
||||||
for(j in 0..2){
|
for (j in 0..2) {
|
||||||
val expected= (i * 10 + j).toDouble()
|
val expected = (i * 10 + j).toDouble()
|
||||||
assertEquals("Error at index [$i, $j]", expected, array[i,j], 0.1)
|
assertEquals("Error at index [$i, $j]", expected, array[i, j], 0.1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testExternalFunction() {
|
||||||
|
val function: (Double) -> Double = { x -> x.pow(2) + 2 * x + 1 }
|
||||||
|
val result = function(array1) + 1.0
|
||||||
|
assertEquals(10.0, result[1,1],0.01)
|
||||||
|
}
|
||||||
}
|
}
|
@ -1,4 +1,4 @@
|
|||||||
rootProject.name = 'kmath'
|
rootProject.name = 'kmath'
|
||||||
include 'common'
|
include 'kmath-common'
|
||||||
include 'jvm'
|
include 'kmath-jvm'
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user