forked from kscience/kmath
Abstract linear algebra and real value array implementation.
This commit is contained in:
parent
af446d2d2c
commit
eb4c9a4b94
@ -1,5 +1,5 @@
|
||||
buildscript {
|
||||
ext.kotlin_version = '1.2.41'
|
||||
ext.kotlin_version = '1.2.51'
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
|
@ -7,8 +7,10 @@ import scientifik.kmath.operations.SpaceElement
|
||||
|
||||
/**
|
||||
* The space for linear elements. Supports scalar product alongside with standard linear operations.
|
||||
* @param T type of individual element of the vector or matrix
|
||||
* @param V the type of vector space element
|
||||
*/
|
||||
abstract class LinearSpace<T, V: LinearStructure<out T>>(val rows: Int, val columns: Int, val field: Field<T>) : Space<V> {
|
||||
abstract class LinearSpace<T : Any, V : LinearStructure<out T>>(val rows: Int, val columns: Int, val field: Field<T>) : Space<V> {
|
||||
|
||||
/**
|
||||
* Produce the element of this space
|
||||
@ -18,7 +20,7 @@ abstract class LinearSpace<T, V: LinearStructure<out T>>(val rows: Int, val colu
|
||||
/**
|
||||
* Produce new linear space with given dimensions
|
||||
*/
|
||||
abstract fun produceSpace(rows: Int, columns: Int): LinearSpace<T,V>
|
||||
abstract fun produceSpace(rows: Int, columns: Int): LinearSpace<T, V>
|
||||
|
||||
override val zero: V by lazy {
|
||||
produce { _, _ -> field.zero }
|
||||
@ -37,9 +39,9 @@ abstract class LinearSpace<T, V: LinearStructure<out T>>(val rows: Int, val colu
|
||||
* Dot product
|
||||
*/
|
||||
fun multiply(a: V, b: V): V {
|
||||
if (a.columns != b.rows) {
|
||||
if (a.rows != b.columns) {
|
||||
//TODO replace by specific exception
|
||||
throw RuntimeException("Dimension mismatch in vector dot product")
|
||||
error("Dimension mismatch in vector dot product")
|
||||
}
|
||||
return produceSpace(a.rows, b.columns).produce { i, j ->
|
||||
(0..a.columns).asSequence().map { k -> field.multiply(a[i, k], b[k, j]) }.reduce { first, second -> field.add(first, second) }
|
||||
@ -52,9 +54,10 @@ abstract class LinearSpace<T, V: LinearStructure<out T>>(val rows: Int, val colu
|
||||
/**
|
||||
* A matrix-like structure that is not dependent on specific space implementation
|
||||
*/
|
||||
interface LinearStructure<T> {
|
||||
interface LinearStructure<T : Any> {
|
||||
val rows: Int
|
||||
val columns: Int
|
||||
|
||||
operator fun get(i: Int, j: Int): T
|
||||
|
||||
fun transpose(): LinearStructure<T> {
|
||||
@ -66,52 +69,76 @@ interface LinearStructure<T> {
|
||||
}
|
||||
}
|
||||
|
||||
class RealArraySpace(rows: Int, columns: Int) : LinearSpace<Double, RealArray>(rows, columns, DoubleField) {
|
||||
override fun produce(initializer: (Int, Int) -> Double): RealArray {
|
||||
return RealArray(this, initializer)
|
||||
}
|
||||
interface Vector<T : Any> : LinearStructure<T> {
|
||||
override val columns: Int
|
||||
get() = 1
|
||||
|
||||
override fun produceSpace(rows: Int, columns: Int): LinearSpace<Double, RealArray> {
|
||||
return RealArraySpace(rows, columns)
|
||||
}
|
||||
}
|
||||
|
||||
class RealArray(override val context: RealArraySpace, initializer: (Int, Int) -> Double): LinearStructure<Double>, SpaceElement<RealArray> {
|
||||
|
||||
val array: Array<Array<Double>> = Array(context.rows) { i -> Array(context.columns) { j -> initializer(i, j) } }
|
||||
|
||||
override val rows: Int = context.rows
|
||||
|
||||
override val columns: Int = context.columns
|
||||
|
||||
override fun get(i: Int, j: Int): Double {
|
||||
return array[i][j]
|
||||
}
|
||||
|
||||
override val self: RealArray = this
|
||||
operator fun get(i: Int) = get(i, 0)
|
||||
}
|
||||
|
||||
|
||||
///**
|
||||
// * An element of linear algebra with fixed dimension. The linear space allows linear operations on objects of the same dimensions.
|
||||
// * Scalar product operations are performed outside space.
|
||||
// *
|
||||
// * @param E the type of linear object element type.
|
||||
// */
|
||||
//interface LinearObject<E> : SpaceElement<LinearObject<E>> {
|
||||
// override val context: LinearSpace<>
|
||||
//
|
||||
// val rows: Int
|
||||
// val columns: Int
|
||||
// operator fun get(i: Int, j: Int): E
|
||||
//
|
||||
// /**
|
||||
// * Get a transposed object with switched dimensions
|
||||
// */
|
||||
// fun transpose(): LinearObject<E>
|
||||
//
|
||||
// /**
|
||||
// * Perform scalar multiplication (dot) operation, checking dimensions. The argument object and result both could be outside initial space.
|
||||
// */
|
||||
// operator fun times(other: LinearObject<E>): LinearObject<E>
|
||||
//}
|
||||
/**
|
||||
* DoubleArray-based implementation of vector space
|
||||
*/
|
||||
class ArraySpace<T : Any>(rows: Int, columns: Int, field: Field<T>) : LinearSpace<T, LinearStructure<out T>>(rows, columns, field) {
|
||||
|
||||
override fun produce(initializer: (Int, Int) -> T): LinearStructure<T> = ArrayMatrix<T>(this, initializer)
|
||||
|
||||
|
||||
override fun produceSpace(rows: Int, columns: Int): LinearSpace<T, LinearStructure<out T>> {
|
||||
return ArraySpace(rows, columns, field)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Member of [ArraySpace] which wraps 2-D array
|
||||
*/
|
||||
class ArrayMatrix<T : Any>(override val context: ArraySpace<T>, initializer: (Int, Int) -> T) : LinearStructure<T>, SpaceElement<LinearStructure<out T>, ArraySpace<T>> {
|
||||
|
||||
val list: List<List<T>> = (0 until rows).map { i -> (0 until columns).map { j -> initializer(i, j) } }
|
||||
|
||||
override val rows: Int get() = context.rows
|
||||
|
||||
override val columns: Int get() = context.columns
|
||||
|
||||
override fun get(i: Int, j: Int): T {
|
||||
return list[i][j]
|
||||
}
|
||||
|
||||
override val self: ArrayMatrix<T> get() = this
|
||||
}
|
||||
|
||||
|
||||
class ArrayVector<T : Any>(override val context: ArraySpace<T>, initializer: (Int) -> T) : Vector<T>, SpaceElement<LinearStructure<out T>, ArraySpace<T>> {
|
||||
|
||||
init {
|
||||
if (context.columns != 1) {
|
||||
error("Vector must have single column")
|
||||
}
|
||||
}
|
||||
|
||||
val list: List<T> = (0 until context.rows).map(initializer)
|
||||
|
||||
|
||||
override val rows: Int get() = context.rows
|
||||
|
||||
override val columns: Int = 1
|
||||
|
||||
override fun get(i: Int, j: Int): T {
|
||||
return list[i]
|
||||
}
|
||||
|
||||
override val self: ArrayVector<T> get() = this
|
||||
|
||||
}
|
||||
|
||||
fun <T : Any> vector(size: Int, field: Field<T>, initializer: (Int) -> T) = ArrayVector(ArraySpace(size, 1, field), initializer)
|
||||
//TODO replace by primitive array version
|
||||
fun realVector(size: Int, initializer: (Int) -> Double) = vector(size, DoubleField, initializer)
|
||||
|
||||
fun <T : Any> Array<T>.asVector(field: Field<T>) = vector(size, field) { this[it] }
|
||||
//TODO add inferred field from field element
|
||||
fun DoubleArray.asVector() = realVector(this.size) { this[it] }
|
||||
|
||||
fun <T : Any> matrix(rows: Int, columns: Int, field: Field<T>, initializer: (Int, Int) -> T) = ArrayMatrix<T>(ArraySpace(rows, columns, field), initializer)
|
||||
fun realMatrix(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = matrix(rows, columns, DoubleField, initializer)
|
@ -3,6 +3,9 @@ package scientifik.kmath.structures
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.FieldElement
|
||||
|
||||
/**
|
||||
* An exception is thrown when the expected ans actual shape of NDArray differs
|
||||
*/
|
||||
class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : RuntimeException()
|
||||
|
||||
/**
|
||||
@ -11,9 +14,11 @@ class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : R
|
||||
* @param field - operations field defined on individual array element
|
||||
* @param T the type of the element contained in NDArray
|
||||
*/
|
||||
abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDArray<T>> {
|
||||
abstract class NDField<T>(val shape: List<Int>, private val field: Field<T>) : Field<NDArray<T>> {
|
||||
|
||||
/**
|
||||
* Create new instance of NDArray using field shape and given initializer
|
||||
* The producer takes list of indices as argument and returns contained value
|
||||
*/
|
||||
abstract fun produce(initializer: (List<Int>) -> T): NDArray<T>
|
||||
|
||||
@ -21,6 +26,9 @@ abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDA
|
||||
produce { this.field.zero }
|
||||
}
|
||||
|
||||
/**
|
||||
* Check the shape of given NDArray and throw exception if it does not coincide with shape of the field
|
||||
*/
|
||||
private fun checkShape(vararg arrays: NDArray<T>) {
|
||||
arrays.forEach {
|
||||
if (shape != it.shape) {
|
||||
@ -66,13 +74,13 @@ abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDA
|
||||
}
|
||||
|
||||
|
||||
interface NDArray<T> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> {
|
||||
interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<Int>, T>> {
|
||||
|
||||
/**
|
||||
* The list of dimensions of this NDArray
|
||||
*/
|
||||
val shape: List<Int>
|
||||
get() = (context as NDField<T>).shape
|
||||
get() = context.shape
|
||||
|
||||
/**
|
||||
* The number of dimentsions for this array
|
||||
|
@ -69,7 +69,7 @@ private class RealNDField(shape: List<Int>) : NDField<Double>(shape, DoubleField
|
||||
//TODO generate fixed hash code for quick comparison?
|
||||
|
||||
|
||||
override val self: NDArray<Double> = this
|
||||
override val self: NDArray<Double> get() = this
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,26 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import org.junit.Assert.assertEquals
|
||||
import org.junit.Test
|
||||
|
||||
class ArrayMatrixTest {
|
||||
|
||||
@Test
|
||||
fun testSum() {
|
||||
val vector1 = realVector(5) { it.toDouble() }
|
||||
val vector2 = realVector(5) { 5 - it.toDouble() }
|
||||
val sum = vector1 + vector2
|
||||
assertEquals(5.0, sum[2, 0], 0.1)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDot() {
|
||||
val vector1 = realVector(5) { it.toDouble() }
|
||||
val vector2 = realVector(5) { 5 - it.toDouble() }
|
||||
val product = with(vector1.context) {
|
||||
vector1 dot (vector2.transpose())
|
||||
}
|
||||
|
||||
assertEquals(10.0, product[1, 0], 0.1)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user