Dev #11
@ -1,5 +1,5 @@
|
|||||||
buildscript {
|
buildscript {
|
||||||
ext.kotlin_version = '1.2.41'
|
ext.kotlin_version = '1.2.51'
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
|
@ -7,8 +7,10 @@ import scientifik.kmath.operations.SpaceElement
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* The space for linear elements. Supports scalar product alongside with standard linear operations.
|
* The space for linear elements. Supports scalar product alongside with standard linear operations.
|
||||||
|
* @param T type of individual element of the vector or matrix
|
||||||
|
* @param V the type of vector space element
|
||||||
*/
|
*/
|
||||||
abstract class LinearSpace<T, V: LinearStructure<out T>>(val rows: Int, val columns: Int, val field: Field<T>) : Space<V> {
|
abstract class LinearSpace<T : Any, V : LinearStructure<out T>>(val rows: Int, val columns: Int, val field: Field<T>) : Space<V> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Produce the element of this space
|
* Produce the element of this space
|
||||||
@ -18,7 +20,7 @@ abstract class LinearSpace<T, V: LinearStructure<out T>>(val rows: Int, val colu
|
|||||||
/**
|
/**
|
||||||
* Produce new linear space with given dimensions
|
* Produce new linear space with given dimensions
|
||||||
*/
|
*/
|
||||||
abstract fun produceSpace(rows: Int, columns: Int): LinearSpace<T,V>
|
abstract fun produceSpace(rows: Int, columns: Int): LinearSpace<T, V>
|
||||||
|
|
||||||
override val zero: V by lazy {
|
override val zero: V by lazy {
|
||||||
produce { _, _ -> field.zero }
|
produce { _, _ -> field.zero }
|
||||||
@ -37,9 +39,9 @@ abstract class LinearSpace<T, V: LinearStructure<out T>>(val rows: Int, val colu
|
|||||||
* Dot product
|
* Dot product
|
||||||
*/
|
*/
|
||||||
fun multiply(a: V, b: V): V {
|
fun multiply(a: V, b: V): V {
|
||||||
if (a.columns != b.rows) {
|
if (a.rows != b.columns) {
|
||||||
//TODO replace by specific exception
|
//TODO replace by specific exception
|
||||||
throw RuntimeException("Dimension mismatch in vector dot product")
|
error("Dimension mismatch in vector dot product")
|
||||||
}
|
}
|
||||||
return produceSpace(a.rows, b.columns).produce { i, j ->
|
return produceSpace(a.rows, b.columns).produce { i, j ->
|
||||||
(0..a.columns).asSequence().map { k -> field.multiply(a[i, k], b[k, j]) }.reduce { first, second -> field.add(first, second) }
|
(0..a.columns).asSequence().map { k -> field.multiply(a[i, k], b[k, j]) }.reduce { first, second -> field.add(first, second) }
|
||||||
@ -52,9 +54,10 @@ abstract class LinearSpace<T, V: LinearStructure<out T>>(val rows: Int, val colu
|
|||||||
/**
|
/**
|
||||||
* A matrix-like structure that is not dependent on specific space implementation
|
* A matrix-like structure that is not dependent on specific space implementation
|
||||||
*/
|
*/
|
||||||
interface LinearStructure<T> {
|
interface LinearStructure<T : Any> {
|
||||||
val rows: Int
|
val rows: Int
|
||||||
val columns: Int
|
val columns: Int
|
||||||
|
|
||||||
operator fun get(i: Int, j: Int): T
|
operator fun get(i: Int, j: Int): T
|
||||||
|
|
||||||
fun transpose(): LinearStructure<T> {
|
fun transpose(): LinearStructure<T> {
|
||||||
@ -66,52 +69,76 @@ interface LinearStructure<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class RealArraySpace(rows: Int, columns: Int) : LinearSpace<Double, RealArray>(rows, columns, DoubleField) {
|
interface Vector<T : Any> : LinearStructure<T> {
|
||||||
override fun produce(initializer: (Int, Int) -> Double): RealArray {
|
override val columns: Int
|
||||||
return RealArray(this, initializer)
|
get() = 1
|
||||||
}
|
|
||||||
|
|
||||||
override fun produceSpace(rows: Int, columns: Int): LinearSpace<Double, RealArray> {
|
operator fun get(i: Int) = get(i, 0)
|
||||||
return RealArraySpace(rows, columns)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class RealArray(override val context: RealArraySpace, initializer: (Int, Int) -> Double): LinearStructure<Double>, SpaceElement<RealArray> {
|
|
||||||
|
|
||||||
val array: Array<Array<Double>> = Array(context.rows) { i -> Array(context.columns) { j -> initializer(i, j) } }
|
|
||||||
|
|
||||||
override val rows: Int = context.rows
|
|
||||||
|
|
||||||
override val columns: Int = context.columns
|
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): Double {
|
|
||||||
return array[i][j]
|
|
||||||
}
|
|
||||||
|
|
||||||
override val self: RealArray = this
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
///**
|
/**
|
||||||
// * An element of linear algebra with fixed dimension. The linear space allows linear operations on objects of the same dimensions.
|
* DoubleArray-based implementation of vector space
|
||||||
// * Scalar product operations are performed outside space.
|
*/
|
||||||
// *
|
class ArraySpace<T : Any>(rows: Int, columns: Int, field: Field<T>) : LinearSpace<T, LinearStructure<out T>>(rows, columns, field) {
|
||||||
// * @param E the type of linear object element type.
|
|
||||||
// */
|
override fun produce(initializer: (Int, Int) -> T): LinearStructure<T> = ArrayMatrix<T>(this, initializer)
|
||||||
//interface LinearObject<E> : SpaceElement<LinearObject<E>> {
|
|
||||||
// override val context: LinearSpace<>
|
|
||||||
//
|
override fun produceSpace(rows: Int, columns: Int): LinearSpace<T, LinearStructure<out T>> {
|
||||||
// val rows: Int
|
return ArraySpace(rows, columns, field)
|
||||||
// val columns: Int
|
}
|
||||||
// operator fun get(i: Int, j: Int): E
|
}
|
||||||
//
|
|
||||||
// /**
|
/**
|
||||||
// * Get a transposed object with switched dimensions
|
* Member of [ArraySpace] which wraps 2-D array
|
||||||
// */
|
*/
|
||||||
// fun transpose(): LinearObject<E>
|
class ArrayMatrix<T : Any>(override val context: ArraySpace<T>, initializer: (Int, Int) -> T) : LinearStructure<T>, SpaceElement<LinearStructure<out T>, ArraySpace<T>> {
|
||||||
//
|
|
||||||
// /**
|
val list: List<List<T>> = (0 until rows).map { i -> (0 until columns).map { j -> initializer(i, j) } }
|
||||||
// * Perform scalar multiplication (dot) operation, checking dimensions. The argument object and result both could be outside initial space.
|
|
||||||
// */
|
override val rows: Int get() = context.rows
|
||||||
// operator fun times(other: LinearObject<E>): LinearObject<E>
|
|
||||||
//}
|
override val columns: Int get() = context.columns
|
||||||
|
|
||||||
|
override fun get(i: Int, j: Int): T {
|
||||||
|
return list[i][j]
|
||||||
|
}
|
||||||
|
|
||||||
|
override val self: ArrayMatrix<T> get() = this
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayVector<T : Any>(override val context: ArraySpace<T>, initializer: (Int) -> T) : Vector<T>, SpaceElement<LinearStructure<out T>, ArraySpace<T>> {
|
||||||
|
|
||||||
|
init {
|
||||||
|
if (context.columns != 1) {
|
||||||
|
error("Vector must have single column")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val list: List<T> = (0 until context.rows).map(initializer)
|
||||||
|
|
||||||
|
|
||||||
|
override val rows: Int get() = context.rows
|
||||||
|
|
||||||
|
override val columns: Int = 1
|
||||||
|
|
||||||
|
override fun get(i: Int, j: Int): T {
|
||||||
|
return list[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
override val self: ArrayVector<T> get() = this
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
fun <T : Any> vector(size: Int, field: Field<T>, initializer: (Int) -> T) = ArrayVector(ArraySpace(size, 1, field), initializer)
|
||||||
|
//TODO replace by primitive array version
|
||||||
|
fun realVector(size: Int, initializer: (Int) -> Double) = vector(size, DoubleField, initializer)
|
||||||
|
|
||||||
|
fun <T : Any> Array<T>.asVector(field: Field<T>) = vector(size, field) { this[it] }
|
||||||
|
//TODO add inferred field from field element
|
||||||
|
fun DoubleArray.asVector() = realVector(this.size) { this[it] }
|
||||||
|
|
||||||
|
fun <T : Any> matrix(rows: Int, columns: Int, field: Field<T>, initializer: (Int, Int) -> T) = ArrayMatrix<T>(ArraySpace(rows, columns, field), initializer)
|
||||||
|
fun realMatrix(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = matrix(rows, columns, DoubleField, initializer)
|
@ -3,6 +3,9 @@ package scientifik.kmath.structures
|
|||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.FieldElement
|
import scientifik.kmath.operations.FieldElement
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An exception is thrown when the expected ans actual shape of NDArray differs
|
||||||
|
*/
|
||||||
class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : RuntimeException()
|
class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : RuntimeException()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -11,9 +14,11 @@ class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : R
|
|||||||
* @param field - operations field defined on individual array element
|
* @param field - operations field defined on individual array element
|
||||||
* @param T the type of the element contained in NDArray
|
* @param T the type of the element contained in NDArray
|
||||||
*/
|
*/
|
||||||
abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDArray<T>> {
|
abstract class NDField<T>(val shape: List<Int>, private val field: Field<T>) : Field<NDArray<T>> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create new instance of NDArray using field shape and given initializer
|
* Create new instance of NDArray using field shape and given initializer
|
||||||
|
* The producer takes list of indices as argument and returns contained value
|
||||||
*/
|
*/
|
||||||
abstract fun produce(initializer: (List<Int>) -> T): NDArray<T>
|
abstract fun produce(initializer: (List<Int>) -> T): NDArray<T>
|
||||||
|
|
||||||
@ -21,6 +26,9 @@ abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDA
|
|||||||
produce { this.field.zero }
|
produce { this.field.zero }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check the shape of given NDArray and throw exception if it does not coincide with shape of the field
|
||||||
|
*/
|
||||||
private fun checkShape(vararg arrays: NDArray<T>) {
|
private fun checkShape(vararg arrays: NDArray<T>) {
|
||||||
arrays.forEach {
|
arrays.forEach {
|
||||||
if (shape != it.shape) {
|
if (shape != it.shape) {
|
||||||
@ -66,13 +74,13 @@ abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDA
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
interface NDArray<T> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> {
|
interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<Int>, T>> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The list of dimensions of this NDArray
|
* The list of dimensions of this NDArray
|
||||||
*/
|
*/
|
||||||
val shape: List<Int>
|
val shape: List<Int>
|
||||||
get() = (context as NDField<T>).shape
|
get() = context.shape
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The number of dimentsions for this array
|
* The number of dimentsions for this array
|
||||||
|
@ -69,7 +69,7 @@ private class RealNDField(shape: List<Int>) : NDField<Double>(shape, DoubleField
|
|||||||
//TODO generate fixed hash code for quick comparison?
|
//TODO generate fixed hash code for quick comparison?
|
||||||
|
|
||||||
|
|
||||||
override val self: NDArray<Double> = this
|
override val self: NDArray<Double> get() = this
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,26 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import org.junit.Assert.assertEquals
|
||||||
|
import org.junit.Test
|
||||||
|
|
||||||
|
class ArrayMatrixTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSum() {
|
||||||
|
val vector1 = realVector(5) { it.toDouble() }
|
||||||
|
val vector2 = realVector(5) { 5 - it.toDouble() }
|
||||||
|
val sum = vector1 + vector2
|
||||||
|
assertEquals(5.0, sum[2, 0], 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDot() {
|
||||||
|
val vector1 = realVector(5) { it.toDouble() }
|
||||||
|
val vector2 = realVector(5) { 5 - it.toDouble() }
|
||||||
|
val product = with(vector1.context) {
|
||||||
|
vector1 dot (vector2.transpose())
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(10.0, product[1, 0], 0.1)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user