Abstract linear algebra and real value array implementation.

This commit is contained in:
Alexander Nozik 2018-07-18 12:51:21 +03:00
parent af446d2d2c
commit eb4c9a4b94
5 changed files with 116 additions and 55 deletions

View File

@ -1,5 +1,5 @@
buildscript {
ext.kotlin_version = '1.2.41'
ext.kotlin_version = '1.2.51'
repositories {
mavenCentral()

View File

@ -7,8 +7,10 @@ import scientifik.kmath.operations.SpaceElement
/**
* The space for linear elements. Supports scalar product alongside with standard linear operations.
* @param T type of individual element of the vector or matrix
* @param V the type of vector space element
*/
abstract class LinearSpace<T, V: LinearStructure<out T>>(val rows: Int, val columns: Int, val field: Field<T>) : Space<V> {
abstract class LinearSpace<T : Any, V : LinearStructure<out T>>(val rows: Int, val columns: Int, val field: Field<T>) : Space<V> {
/**
* Produce the element of this space
@ -18,7 +20,7 @@ abstract class LinearSpace<T, V: LinearStructure<out T>>(val rows: Int, val colu
/**
* Produce new linear space with given dimensions
*/
abstract fun produceSpace(rows: Int, columns: Int): LinearSpace<T,V>
abstract fun produceSpace(rows: Int, columns: Int): LinearSpace<T, V>
override val zero: V by lazy {
produce { _, _ -> field.zero }
@ -37,9 +39,9 @@ abstract class LinearSpace<T, V: LinearStructure<out T>>(val rows: Int, val colu
* Dot product
*/
fun multiply(a: V, b: V): V {
if (a.columns != b.rows) {
if (a.rows != b.columns) {
//TODO replace by specific exception
throw RuntimeException("Dimension mismatch in vector dot product")
error("Dimension mismatch in vector dot product")
}
return produceSpace(a.rows, b.columns).produce { i, j ->
(0..a.columns).asSequence().map { k -> field.multiply(a[i, k], b[k, j]) }.reduce { first, second -> field.add(first, second) }
@ -52,9 +54,10 @@ abstract class LinearSpace<T, V: LinearStructure<out T>>(val rows: Int, val colu
/**
* A matrix-like structure that is not dependent on specific space implementation
*/
interface LinearStructure<T> {
interface LinearStructure<T : Any> {
val rows: Int
val columns: Int
operator fun get(i: Int, j: Int): T
fun transpose(): LinearStructure<T> {
@ -66,52 +69,76 @@ interface LinearStructure<T> {
}
}
class RealArraySpace(rows: Int, columns: Int) : LinearSpace<Double, RealArray>(rows, columns, DoubleField) {
override fun produce(initializer: (Int, Int) -> Double): RealArray {
return RealArray(this, initializer)
}
interface Vector<T : Any> : LinearStructure<T> {
override val columns: Int
get() = 1
override fun produceSpace(rows: Int, columns: Int): LinearSpace<Double, RealArray> {
return RealArraySpace(rows, columns)
}
}
class RealArray(override val context: RealArraySpace, initializer: (Int, Int) -> Double): LinearStructure<Double>, SpaceElement<RealArray> {
val array: Array<Array<Double>> = Array(context.rows) { i -> Array(context.columns) { j -> initializer(i, j) } }
override val rows: Int = context.rows
override val columns: Int = context.columns
override fun get(i: Int, j: Int): Double {
return array[i][j]
}
override val self: RealArray = this
operator fun get(i: Int) = get(i, 0)
}
///**
// * An element of linear algebra with fixed dimension. The linear space allows linear operations on objects of the same dimensions.
// * Scalar product operations are performed outside space.
// *
// * @param E the type of linear object element type.
// */
//interface LinearObject<E> : SpaceElement<LinearObject<E>> {
// override val context: LinearSpace<>
//
// val rows: Int
// val columns: Int
// operator fun get(i: Int, j: Int): E
//
// /**
// * Get a transposed object with switched dimensions
// */
// fun transpose(): LinearObject<E>
//
// /**
// * Perform scalar multiplication (dot) operation, checking dimensions. The argument object and result both could be outside initial space.
// */
// operator fun times(other: LinearObject<E>): LinearObject<E>
//}
/**
* DoubleArray-based implementation of vector space
*/
class ArraySpace<T : Any>(rows: Int, columns: Int, field: Field<T>) : LinearSpace<T, LinearStructure<out T>>(rows, columns, field) {
override fun produce(initializer: (Int, Int) -> T): LinearStructure<T> = ArrayMatrix<T>(this, initializer)
override fun produceSpace(rows: Int, columns: Int): LinearSpace<T, LinearStructure<out T>> {
return ArraySpace(rows, columns, field)
}
}
/**
* Member of [ArraySpace] which wraps 2-D array
*/
class ArrayMatrix<T : Any>(override val context: ArraySpace<T>, initializer: (Int, Int) -> T) : LinearStructure<T>, SpaceElement<LinearStructure<out T>, ArraySpace<T>> {
val list: List<List<T>> = (0 until rows).map { i -> (0 until columns).map { j -> initializer(i, j) } }
override val rows: Int get() = context.rows
override val columns: Int get() = context.columns
override fun get(i: Int, j: Int): T {
return list[i][j]
}
override val self: ArrayMatrix<T> get() = this
}
class ArrayVector<T : Any>(override val context: ArraySpace<T>, initializer: (Int) -> T) : Vector<T>, SpaceElement<LinearStructure<out T>, ArraySpace<T>> {
init {
if (context.columns != 1) {
error("Vector must have single column")
}
}
val list: List<T> = (0 until context.rows).map(initializer)
override val rows: Int get() = context.rows
override val columns: Int = 1
override fun get(i: Int, j: Int): T {
return list[i]
}
override val self: ArrayVector<T> get() = this
}
fun <T : Any> vector(size: Int, field: Field<T>, initializer: (Int) -> T) = ArrayVector(ArraySpace(size, 1, field), initializer)
//TODO replace by primitive array version
fun realVector(size: Int, initializer: (Int) -> Double) = vector(size, DoubleField, initializer)
fun <T : Any> Array<T>.asVector(field: Field<T>) = vector(size, field) { this[it] }
//TODO add inferred field from field element
fun DoubleArray.asVector() = realVector(this.size) { this[it] }
fun <T : Any> matrix(rows: Int, columns: Int, field: Field<T>, initializer: (Int, Int) -> T) = ArrayMatrix<T>(ArraySpace(rows, columns, field), initializer)
fun realMatrix(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = matrix(rows, columns, DoubleField, initializer)

View File

@ -3,6 +3,9 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.FieldElement
/**
* An exception is thrown when the expected ans actual shape of NDArray differs
*/
class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : RuntimeException()
/**
@ -11,9 +14,11 @@ class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : R
* @param field - operations field defined on individual array element
* @param T the type of the element contained in NDArray
*/
abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDArray<T>> {
abstract class NDField<T>(val shape: List<Int>, private val field: Field<T>) : Field<NDArray<T>> {
/**
* Create new instance of NDArray using field shape and given initializer
* The producer takes list of indices as argument and returns contained value
*/
abstract fun produce(initializer: (List<Int>) -> T): NDArray<T>
@ -21,6 +26,9 @@ abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDA
produce { this.field.zero }
}
/**
* Check the shape of given NDArray and throw exception if it does not coincide with shape of the field
*/
private fun checkShape(vararg arrays: NDArray<T>) {
arrays.forEach {
if (shape != it.shape) {
@ -66,13 +74,13 @@ abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDA
}
interface NDArray<T> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> {
interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<Int>, T>> {
/**
* The list of dimensions of this NDArray
*/
val shape: List<Int>
get() = (context as NDField<T>).shape
get() = context.shape
/**
* The number of dimentsions for this array

View File

@ -69,7 +69,7 @@ private class RealNDField(shape: List<Int>) : NDField<Double>(shape, DoubleField
//TODO generate fixed hash code for quick comparison?
override val self: NDArray<Double> = this
override val self: NDArray<Double> get() = this
}
}

View File

@ -0,0 +1,26 @@
package scientifik.kmath.structures
import org.junit.Assert.assertEquals
import org.junit.Test
class ArrayMatrixTest {
@Test
fun testSum() {
val vector1 = realVector(5) { it.toDouble() }
val vector2 = realVector(5) { 5 - it.toDouble() }
val sum = vector1 + vector2
assertEquals(5.0, sum[2, 0], 0.1)
}
@Test
fun testDot() {
val vector1 = realVector(5) { it.toDouble() }
val vector2 = realVector(5) { 5 - it.toDouble() }
val product = with(vector1.context) {
vector1 dot (vector2.transpose())
}
assertEquals(10.0, product[1, 0], 0.1)
}
}