[WIP] Refactor NDStructures

This commit is contained in:
Alexander Nozik 2021-01-23 19:19:13 +03:00
parent 061398b009
commit 332c04b573
20 changed files with 120 additions and 83 deletions

View File

@ -1,17 +1,17 @@
package kscience.kmath.operations
import kscience.kmath.nd.NDField
import kscience.kmath.structures.NDElement
import kscience.kmath.nd.NDAlgebra
import kscience.kmath.nd.complex
fun main() {
// 2d element
val element = NDElement.complex(2, 2) { (i,j) ->
val element = NDAlgebra.complex(2, 2).produce { (i,j) ->
Complex(i.toDouble() - j.toDouble(), i.toDouble() + j.toDouble())
}
println(element)
// 1d element operation
val result = with(NDField.complex(8)) {
val result = with(NDAlgebra.complex(8)) {
val a = produce { (it) -> i * it - it.toDouble() }
val b = 3
val c = Complex(1.0, 1.0)

View File

@ -1,7 +1,10 @@
package kscience.kmath.structures
import kotlinx.coroutines.GlobalScope
import kscience.kmath.nd.*
import kscience.kmath.nd.NDAlgebra
import kscience.kmath.nd.NDStructure
import kscience.kmath.nd.field
import kscience.kmath.nd.real
import kscience.kmath.nd4j.Nd4jArrayField
import kscience.kmath.operations.RealField
import kscience.kmath.operations.invoke
@ -39,9 +42,11 @@ fun main() {
}
measureAndPrint("Element addition") {
var res: NDStructure<Double> = boxingField.one
boxingField {
var res: NDStructure<Double> = one
repeat(n) { res += 1.0 }
}
}
measureAndPrint("Specialized addition") {
specializedField {
@ -52,7 +57,7 @@ fun main() {
measureAndPrint("Nd4j specialized addition") {
nd4jField {
var res = one
var res:NDStructure<Double> = one
repeat(n) { res += 1.0 }
}
}

View File

@ -1,10 +1,7 @@
package kscience.kmath.commons.linear
import kscience.kmath.linear.DiagonalFeature
import kscience.kmath.linear.MatrixContext
import kscience.kmath.linear.Point
import kscience.kmath.linear.*
import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.nd.Matrix
import org.apache.commons.math3.linear.*
import kotlin.reflect.KClass
import kotlin.reflect.cast

View File

@ -1,7 +1,7 @@
package kscience.kmath.commons.linear
import kscience.kmath.linear.Matrix
import kscience.kmath.linear.Point
import kscience.kmath.nd.Matrix
import org.apache.commons.math3.linear.*
public enum class CMDecomposition {

View File

@ -86,7 +86,7 @@ public fun <T, A : Space<T>> NDAlgebra.Companion.space(
vararg shape: Int,
): BufferedNDSpace<T, A> = BufferedNDSpace(shape, space, bufferFactory)
public inline fun <T, A : Space<T>, R> A.nd(
public inline fun <T, A : Space<T>, R> A.ndSpace(
noinline bufferFactory: BufferFactory<T>,
vararg shape: Int,
action: BufferedNDSpace<T, A>.() -> R,
@ -102,7 +102,7 @@ public fun <T, A : Ring<T>> NDAlgebra.Companion.ring(
vararg shape: Int,
): BufferedNDRing<T, A> = BufferedNDRing(shape, ring, bufferFactory)
public inline fun <T, A : Ring<T>, R> A.nd(
public inline fun <T, A : Ring<T>, R> A.ndRing(
noinline bufferFactory: BufferFactory<T>,
vararg shape: Int,
action: BufferedNDRing<T, A>.() -> R,
@ -118,7 +118,7 @@ public fun <T, A : Field<T>> NDAlgebra.Companion.field(
vararg shape: Int,
): BufferedNDField<T, A> = BufferedNDField(shape, field, bufferFactory)
public inline fun <T, A : Field<T>, R> A.nd(
public inline fun <T, A : Field<T>, R> A.ndField(
noinline bufferFactory: BufferFactory<T>,
vararg shape: Int,
action: BufferedNDField<T, A>.() -> R,

View File

@ -94,9 +94,12 @@ public open class RealNDField(
/**
* Fast element production using function inlining
*/
public inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(Int) -> Double): NDBuffer<Double> {
public inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(IntArray) -> Double): NDBuffer<Double> {
contract { callsInPlace(initializer, InvocationKind.EXACTLY_ONCE) }
val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) }
val array = DoubleArray(strides.linearSize) { offset ->
val index = strides.index(offset)
RealField.initializer(index)
}
return NDBuffer(strides, RealBuffer(array))
}

View File

@ -52,7 +52,15 @@ public interface Structure2D<T> : NDStructure<T> {
for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j))
}
public companion object
public companion object {
public inline fun real(
rows: Int,
columns: Int,
crossinline init: (i: Int, j: Int) -> Double,
): NDBuffer<Double> = NDAlgebra.real(rows, columns).produceInline { (i, j) ->
init(i, j)
}
}
}
/**

View File

@ -1,10 +1,12 @@
package kscience.kmath.operations
import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.nd.BufferedNDRing
import kscience.kmath.nd.NDAlgebra
import kscience.kmath.operations.BigInt.Companion.BASE
import kscience.kmath.operations.BigInt.Companion.BASE_SIZE
import kscience.kmath.structures.*
import kscience.kmath.structures.Buffer
import kscience.kmath.structures.MutableBuffer
import kotlin.math.log2
import kotlin.math.max
import kotlin.math.min
@ -463,10 +465,5 @@ public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigIn
public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> =
boxing(size, initializer)
public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> =
BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt)
public fun NDElement.Companion.bigInt(
vararg shape: Int,
initializer: BigIntField.(IntArray) -> BigInt
): BufferedNDRingElement<BigInt, BigIntField> = NDAlgebra.bigInt(*shape).produce(initializer)
public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BufferedNDRing<BigInt, BigIntField> =
BufferedNDRing(shape, BigIntField, Buffer.Companion::bigInt)

View File

@ -43,7 +43,7 @@ public interface Buffer<T> {
* Checks content equality with another buffer.
*/
public fun contentEquals(other: Buffer<*>): Boolean =
kscience.kmath.nd.mapIndexed { index, value -> value == other[index] }.all { it }
asSequence().mapIndexed { index, value -> value == other[index] }.all { it }
public companion object {
/**

View File

@ -60,7 +60,7 @@ public class FlaggedRealBuffer(public val values: DoubleArray, public val flags:
override operator fun get(index: Int): Double? = if (isValid(index)) values[index] else null
override operator fun iterator(): Iterator<Double?> = kscience.kmath.nd.map {
override operator fun iterator(): Iterator<Double?> = values.indices.asSequence().map {
if (isValid(it)) values[it] else null
}.iterator()
}

View File

@ -15,7 +15,7 @@ public open class MemoryBuffer<T : Any>(protected val memory: Memory, protected
private val reader: MemoryReader = memory.reader()
override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
override operator fun iterator(): Iterator<T> = kscience.kmath.nd.map { get(it) }.iterator()
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
public companion object {
public fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =

View File

@ -1,6 +1,8 @@
package kscience.kmath.structures
import kscience.kmath.nd.NDField
import kscience.kmath.nd.NDAlgebra
import kscience.kmath.nd.get
import kscience.kmath.nd.real
import kscience.kmath.operations.internal.FieldVerifier
import kotlin.test.Test
import kotlin.test.assertEquals
@ -8,12 +10,12 @@ import kotlin.test.assertEquals
internal class NDFieldTest {
@Test
fun verify() {
NDField.real(12, 32).run { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) }
NDAlgebra.real(12, 32).run { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) }
}
@Test
fun testStrides() {
val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() }
val ndArray = NDAlgebra.real(10, 10).produce { (it[0] + it[1]).toDouble() }
assertEquals(ndArray[5, 5], 10.0)
}
}

View File

@ -1,9 +1,8 @@
package kscience.kmath.structures
import kscience.kmath.nd.NDField
import kscience.kmath.nd.NDStructure
import kscience.kmath.nd.*
import kscience.kmath.operations.Norm
import kscience.kmath.structures.NDElement.Companion.real2D
import kscience.kmath.operations.invoke
import kotlin.math.abs
import kotlin.math.pow
import kotlin.test.Test
@ -11,25 +10,30 @@ import kotlin.test.assertEquals
@Suppress("UNUSED_VARIABLE")
class NumberNDFieldTest {
val array1: RealNDElement = real2D(3, 3) { i, j -> (i + j).toDouble() }
val array2: RealNDElement = real2D(3, 3) { i, j -> (i - j).toDouble() }
val algebra = NDAlgebra.real(3,3)
val array1 = algebra.produceInline { (i, j) -> (i + j).toDouble() }
val array2 = algebra.produceInline { (i, j) -> (i - j).toDouble() }
@Test
fun testSum() {
algebra {
val sum = array1 + array2
assertEquals(4.0, sum[2, 2])
}
}
@Test
fun testProduct() {
algebra {
val product = array1 * array2
assertEquals(0.0, product[2, 2])
}
}
@Test
fun testGeneration() {
val array = real2D(3, 3) { i, j -> (i * 10 + j).toDouble() }
val array = Structure2D.real(3, 3) { i, j -> (i * 10 + j).toDouble() }
for (i in 0..2) {
for (j in 0..2) {
@ -41,17 +45,21 @@ class NumberNDFieldTest {
@Test
fun testExternalFunction() {
algebra {
val function: (Double) -> Double = { x -> x.pow(2) + 2 * x + 1 }
val result = function(array1) + 1.0
assertEquals(10.0, result[1, 1])
}
}
@Test
fun testLibraryFunction() {
algebra {
val abs: (Double) -> Double = ::abs
val result = abs(array2)
assertEquals(2.0, result[0, 2])
}
}
@Test
fun combineTest() {
@ -65,6 +73,8 @@ class NumberNDFieldTest {
@Test
fun testInternalContext() {
(NDField.real(*array1.shape)) { with(L2Norm) { 1 + norm(array1) + exp(array2) } }
algebra {
(NDAlgebra.real(*array1.shape)) { with(L2Norm) { 1 + norm(array1) + exp(array2) } }
}
}
}

View File

@ -21,7 +21,7 @@ public class LazyNDStructure<T>(
public override fun elements(): Sequence<Pair<IntArray, T>> {
val strides = DefaultStrides(shape)
val res = runBlocking { kscience.kmath.nd.map { index -> index to await(index) } }
val res = runBlocking { strides.indices().toList().map { index -> index to await(index) } }
return res.asSequence()
}

View File

@ -1,10 +1,7 @@
package kscience.kmath.ejml
import kscience.kmath.linear.InverseMatrixFeature
import kscience.kmath.linear.MatrixContext
import kscience.kmath.linear.Point
import kscience.kmath.linear.*
import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.nd.Matrix
import kscience.kmath.nd.getFeature
import org.ejml.simple.SimpleMatrix

View File

@ -1,11 +1,11 @@
package kscience.kmath.real
import kscience.kmath.linear.MatrixContext
import kscience.kmath.linear.real
import kscience.kmath.linear.*
import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.nd.Matrix
import kscience.kmath.structures.Buffer
import kscience.kmath.structures.RealBuffer
import kscience.kmath.structures.asIterable
import kotlin.math.pow
/*
* Functions for convenient "numpy-like" operations with Double matrices.

View File

@ -1,7 +1,9 @@
package kaceince.kmath.real
import kscience.kmath.nd.Matrix
import kscience.kmath.linear.Matrix
import kscience.kmath.linear.build
import kscience.kmath.real.*
import kscience.kmath.structures.contentEquals
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue

View File

@ -17,7 +17,7 @@ public data class BinDef<T : Comparable<T>>(
require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" }
val upper = space { center + sizes / 2.0 }
val lower = space { center - sizes / 2.0 }
return kscience.kmath.nd.mapIndexed { i, value -> value in lower[i]..upper[i] }.all { it }
return vector.asSequence().mapIndexed { i, value -> value in lower[i]..upper[i] }.all { it }
}
}
@ -70,7 +70,7 @@ public class RealHistogram(
public fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point))
private fun getDef(index: IntArray): BinDef<Double> {
val center = kscience.kmath.nd.mapIndexed { axis, i ->
val center = index.mapIndexed { axis, i ->
when (i) {
0 -> Double.NEGATIVE_INFINITY
strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY
@ -100,7 +100,7 @@ public class RealHistogram(
}
public override operator fun iterator(): Iterator<MultivariateBin<Double>> =
kscience.kmath.nd.map { (index, value) -> MultivariateBin(getDef(index), value.sum()) }
weights.elements().map { (index, value) -> MultivariateBin(getDef(index), value.sum()) }
.iterator()
/**

View File

@ -7,6 +7,13 @@ import kscience.kmath.structures.*
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
internal fun NDAlgebra<*, *>.checkShape(array: INDArray): INDArray {
val arrayShape = array.shape().toIntArray()
if (!shape.contentEquals(arrayShape)) throw ShapeMismatchException(shape, arrayShape)
return array
}
/**
* Represents [NDAlgebra] over [Nd4jArrayAlgebra].
*
@ -18,6 +25,7 @@ public interface Nd4jArrayAlgebra<T, C> : NDAlgebra<T, C> {
* Wraps [INDArray] to [N].
*/
public fun INDArray.wrap(): Nd4jArrayStructure<T>
public val NDStructure<T>.ndArray: INDArray
get() = when {
!shape.contentEquals(this@Nd4jArrayAlgebra.shape) -> throw ShapeMismatchException(
@ -213,7 +221,7 @@ public class RealNd4jArrayField(public override val shape: IntArray) : Nd4jArray
public override val elementContext: RealField
get() = RealField
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(asRealStructure())
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asRealStructure()
public override operator fun NDStructure<Double>.div(arg: Double): Nd4jArrayStructure<Double> {
return ndArray.div(arg).wrap()
@ -247,7 +255,7 @@ public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArra
public override val elementContext: FloatField
get() = FloatField
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(asFloatStructure())
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()
public override operator fun NDStructure<Float>.div(arg: Float): Nd4jArrayStructure<Float> {
return ndArray.div(arg).wrap()
@ -281,7 +289,7 @@ public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRi
public override val elementContext: IntRing
get() = IntRing
public override fun INDArray.wrap(): Nd4jArrayStructure<Int> = check(asIntStructure())
public override fun INDArray.wrap(): Nd4jArrayStructure<Int> = checkShape(this).asIntStructure()
public override operator fun NDStructure<Int>.plus(arg: Int): Nd4jArrayStructure<Int> {
return ndArray.add(arg).wrap()
@ -307,7 +315,7 @@ public class LongNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayR
public override val elementContext: LongRing
get() = LongRing
public override fun INDArray.wrap(): Nd4jArrayStructure<Long> = check(asLongStructure())
public override fun INDArray.wrap(): Nd4jArrayStructure<Long> = checkShape(this).asLongStructure()
public override operator fun NDStructure<Long>.plus(arg: Long): Nd4jArrayStructure<Long> {
return ndArray.add(arg).wrap()

View File

@ -1,9 +1,6 @@
package kscience.kmath.viktor
import kscience.kmath.nd.DefaultStrides
import kscience.kmath.nd.MutableNDStructure
import kscience.kmath.nd.NDField
import kscience.kmath.nd.Strides
import kscience.kmath.nd.*
import kscience.kmath.operations.RealField
import org.jetbrains.bio.viktor.F64Array
@ -24,14 +21,25 @@ public inline class ViktorNDStructure(public val f64Buffer: F64Array) : MutableN
public fun F64Array.asStructure(): ViktorNDStructure = ViktorNDStructure(this)
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public class ViktorNDField(public override val shape: IntArray) : NDField<Double, RealField, ViktorNDStructure> {
public class ViktorNDField(public override val shape: IntArray) : NDField<Double, RealField> {
public val NDStructure<Double>.f64Buffer: F64Array
get() = when {
!shape.contentEquals(this@ViktorNDField.shape) -> throw ShapeMismatchException(
this@ViktorNDField.shape,
shape
)
this is ViktorNDStructure && this.f64Buffer.shape.contentEquals(this@ViktorNDField.shape) -> this.f64Buffer
else -> produce { this@f64Buffer[it] }.f64Buffer
}
public override val zero: ViktorNDStructure
get() = F64Array.full(init = 0.0, shape = shape).asStructure()
public override val one: ViktorNDStructure
get() = F64Array.full(init = 1.0, shape = shape).asStructure()
public val strides: Strides = DefaultStrides(shape)
private val strides: Strides = DefaultStrides(shape)
public override val elementContext: RealField get() = RealField
@ -42,7 +50,7 @@ public class ViktorNDField(public override val shape: IntArray) : NDField<Double
}
}.asStructure()
public override fun map(arg: ViktorNDStructure, transform: RealField.(Double) -> Double): ViktorNDStructure =
public override fun map(arg: NDStructure<Double>, transform: RealField.(Double) -> Double): ViktorNDStructure =
F64Array(*shape).apply {
this@ViktorNDField.strides.indices().forEach { index ->
set(value = RealField.transform(arg[index]), indices = index)
@ -50,7 +58,7 @@ public class ViktorNDField(public override val shape: IntArray) : NDField<Double
}.asStructure()
public override fun mapIndexed(
arg: ViktorNDStructure,
arg: NDStructure<Double>,
transform: RealField.(index: IntArray, Double) -> Double
): ViktorNDStructure = F64Array(*shape).apply {
this@ViktorNDField.strides.indices().forEach { index ->
@ -59,8 +67,8 @@ public class ViktorNDField(public override val shape: IntArray) : NDField<Double
}.asStructure()
public override fun combine(
a: ViktorNDStructure,
b: ViktorNDStructure,
a: NDStructure<Double>,
b: NDStructure<Double>,
transform: RealField.(Double, Double) -> Double
): ViktorNDStructure = F64Array(*shape).apply {
this@ViktorNDField.strides.indices().forEach { index ->
@ -68,21 +76,21 @@ public class ViktorNDField(public override val shape: IntArray) : NDField<Double
}
}.asStructure()
public override fun add(a: ViktorNDStructure, b: ViktorNDStructure): ViktorNDStructure =
public override fun add(a: NDStructure<Double>, b: NDStructure<Double>): ViktorNDStructure =
(a.f64Buffer + b.f64Buffer).asStructure()
public override fun multiply(a: ViktorNDStructure, k: Number): ViktorNDStructure =
public override fun multiply(a: NDStructure<Double>, k: Number): ViktorNDStructure =
(a.f64Buffer * k.toDouble()).asStructure()
public override inline fun ViktorNDStructure.plus(b: ViktorNDStructure): ViktorNDStructure =
public override inline fun NDStructure<Double>.plus(b: NDStructure<Double>): ViktorNDStructure =
(f64Buffer + b.f64Buffer).asStructure()
public override inline fun ViktorNDStructure.minus(b: ViktorNDStructure): ViktorNDStructure =
public override inline fun NDStructure<Double>.minus(b: NDStructure<Double>): ViktorNDStructure =
(f64Buffer - b.f64Buffer).asStructure()
public override inline fun ViktorNDStructure.times(k: Number): ViktorNDStructure =
public override inline fun NDStructure<Double>.times(k: Number): ViktorNDStructure =
(f64Buffer * k.toDouble()).asStructure()
public override inline fun ViktorNDStructure.plus(arg: Double): ViktorNDStructure =
public override inline fun NDStructure<Double>.plus(arg: Double): ViktorNDStructure =
(f64Buffer.plus(arg)).asStructure()
}