Equlity half-fix for NDStructure

This commit is contained in:
Alexander Nozik 2020-05-06 10:50:08 +03:00
parent 898f082a0c
commit 646207e140
19 changed files with 172 additions and 138 deletions

View File

@ -1,8 +1,8 @@
plugins { plugins {
id("scientifik.publish") version "0.4.2" apply false id("scientifik.publish") apply false
} }
val kmathVersion by extra("0.1.4-dev-4") val kmathVersion by extra("0.1.4-dev-5")
val bintrayRepo by extra("scientifik") val bintrayRepo by extra("scientifik")
val githubProject by extra("kmath") val githubProject by extra("kmath")

View File

@ -57,6 +57,6 @@ benchmark {
tasks.withType<KotlinCompile> { tasks.withType<KotlinCompile> {
kotlinOptions { kotlinOptions {
jvmTarget = Scientifik.JVM_VERSION jvmTarget = Scientifik.JVM_TARGET.toString()
} }
} }

View File

@ -5,6 +5,7 @@ import org.apache.commons.math3.linear.RealMatrix
import org.apache.commons.math3.linear.RealVector import org.apache.commons.math3.linear.RealVector
import scientifik.kmath.linear.* import scientifik.kmath.linear.*
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.NDStructure
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) : class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) :
FeaturedMatrix<Double> { FeaturedMatrix<Double> {
@ -19,6 +20,16 @@ class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) :
CMMatrix(origin, this.features + features) CMMatrix(origin, this.features + features)
override fun get(i: Int, j: Int): Double = origin.getEntry(i, j) override fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
}
override fun hashCode(): Int {
var result = origin.hashCode()
result = 31 * result + features.hashCode()
return result
}
} }
fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) { fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {

View File

@ -2,6 +2,7 @@ package scientifik.kmath.operations
import scientifik.kmath.operations.BigInt.Companion.BASE import scientifik.kmath.operations.BigInt.Companion.BASE
import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE
import scientifik.kmath.structures.*
import kotlin.math.log2 import kotlin.math.log2
import kotlin.math.max import kotlin.math.max
import kotlin.math.min import kotlin.math.min
@ -482,3 +483,18 @@ fun String.parseBigInteger(): BigInt? {
} }
return res * sign return res * sign
} }
inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
boxing(size, initializer)
inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> =
boxing(size, initializer)
fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> =
BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt)
fun NDElement.Companion.bigInt(
vararg shape: Int,
initializer: BigIntField.(IntArray) -> BigInt
): BufferedNDRingElement<BigInt, BigIntField> =
NDAlgebra.bigInt(*shape).produce(initializer)

View File

@ -3,10 +3,10 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.* import scientifik.kmath.operations.*
/** /**
* Base interface for an element with context, containing strides * Base class for an element with context, containing strides
*/ */
interface BufferedNDElement<T, C> : NDBuffer<T>, NDElement<T, C, NDBuffer<T>> { abstract class BufferedNDElement<T, C> : NDBuffer<T>(), NDElement<T, C, NDBuffer<T>> {
override val context: BufferedNDAlgebra<T, C> abstract override val context: BufferedNDAlgebra<T, C>
override val strides get() = context.strides override val strides get() = context.strides
@ -16,7 +16,7 @@ interface BufferedNDElement<T, C> : NDBuffer<T>, NDElement<T, C, NDBuffer<T>> {
class BufferedNDSpaceElement<T, S : Space<T>>( class BufferedNDSpaceElement<T, S : Space<T>>(
override val context: BufferedNDSpace<T, S>, override val context: BufferedNDSpace<T, S>,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : BufferedNDElement<T, S>, SpaceElement<NDBuffer<T>, BufferedNDSpaceElement<T, S>, BufferedNDSpace<T, S>> { ) : BufferedNDElement<T, S>(), SpaceElement<NDBuffer<T>, BufferedNDSpaceElement<T, S>, BufferedNDSpace<T, S>> {
override fun unwrap(): NDBuffer<T> = this override fun unwrap(): NDBuffer<T> = this
@ -29,7 +29,7 @@ class BufferedNDSpaceElement<T, S : Space<T>>(
class BufferedNDRingElement<T, R : Ring<T>>( class BufferedNDRingElement<T, R : Ring<T>>(
override val context: BufferedNDRing<T, R>, override val context: BufferedNDRing<T, R>,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : BufferedNDElement<T, R>, RingElement<NDBuffer<T>, BufferedNDRingElement<T, R>, BufferedNDRing<T, R>> { ) : BufferedNDElement<T, R>(), RingElement<NDBuffer<T>, BufferedNDRingElement<T, R>, BufferedNDRing<T, R>> {
override fun unwrap(): NDBuffer<T> = this override fun unwrap(): NDBuffer<T> = this
@ -42,7 +42,7 @@ class BufferedNDRingElement<T, R : Ring<T>>(
class BufferedNDFieldElement<T, F : Field<T>>( class BufferedNDFieldElement<T, F : Field<T>>(
override val context: BufferedNDField<T, F>, override val context: BufferedNDField<T, F>,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : BufferedNDElement<T, F>, FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> { ) : BufferedNDElement<T, F>(), FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> {
override fun unwrap(): NDBuffer<T> = this override fun unwrap(): NDBuffer<T> = this

View File

@ -71,7 +71,7 @@ interface Buffer<T> {
fun <T> Buffer<T>.asSequence(): Sequence<T> = Sequence(::iterator) fun <T> Buffer<T>.asSequence(): Sequence<T> = Sequence(::iterator)
fun <T> Buffer<T>.asIterable(): Iterable<T> = asSequence().asIterable() fun <T> Buffer<T>.asIterable(): Iterable<T> = Iterable(::iterator)
val Buffer<*>.indices: IntRange get() = IntRange(0, size - 1) val Buffer<*>.indices: IntRange get() = IntRange(0, size - 1)

View File

@ -14,15 +14,21 @@ interface NDStructure<T> {
fun elements(): Sequence<Pair<IntArray, T>> fun elements(): Sequence<Pair<IntArray, T>>
override fun equals(other: Any?): Boolean
override fun hashCode(): Int
companion object { companion object {
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean { fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
return when { if(st1===st2) return true
st1 === st2 -> true
st1 is BufferNDStructure<*> && st2 is BufferNDStructure<*> && st1.strides == st2.strides -> st1.buffer.contentEquals( // fast comparison of buffers if possible
st2.buffer if(st1 is NDBuffer && st2 is NDBuffer && st1.strides == st2.strides){
) return st1.buffer.contentEquals(st2.buffer)
else -> st1.elements().all { (index, value) -> value == st2[index] }
} }
//element by element comparison if it could not be avoided
return st1.elements().all { (index, value) -> value == st2[index] }
} }
/** /**
@ -177,15 +183,25 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
} }
} }
interface NDBuffer<T> : NDStructure<T> { abstract class NDBuffer<T> : NDStructure<T> {
val buffer: Buffer<T> abstract val buffer: Buffer<T>
val strides: Strides abstract val strides: Strides
override fun get(index: IntArray): T = buffer[strides.offset(index)] override fun get(index: IntArray): T = buffer[strides.offset(index)]
override val shape: IntArray get() = strides.shape override val shape: IntArray get() = strides.shape
override fun elements() = strides.indices().map { it to this[it] } override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to this[it] }
override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
}
override fun hashCode(): Int {
var result = strides.hashCode()
result = 31 * result + buffer.hashCode()
return result
}
} }
/** /**
@ -194,34 +210,12 @@ interface NDBuffer<T> : NDStructure<T> {
class BufferNDStructure<T>( class BufferNDStructure<T>(
override val strides: Strides, override val strides: Strides,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : NDBuffer<T> { ) : NDBuffer<T>(){
init { init {
if (strides.linearSize != buffer.size) { if (strides.linearSize != buffer.size) {
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}") error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
} }
} }
override fun get(index: IntArray): T = buffer[strides.offset(index)]
override val shape: IntArray get() = strides.shape
override fun elements() = strides.indices().map { it to this[it] }
override fun equals(other: Any?): Boolean {
return when {
this === other -> true
other is BufferNDStructure<*> && this.strides == other.strides -> this.buffer.contentEquals(other.buffer)
other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] }
else -> false
}
}
override fun hashCode(): Int {
var result = strides.hashCode()
result = 31 * result + buffer.hashCode()
return result
}
} }
/** /**
@ -245,7 +239,7 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
class MutableBufferNDStructure<T>( class MutableBufferNDStructure<T>(
override val strides: Strides, override val strides: Strides,
override val buffer: MutableBuffer<T> override val buffer: MutableBuffer<T>
) : NDBuffer<T>, MutableNDStructure<T> { ) : NDBuffer<T>(), MutableNDStructure<T> {
init { init {
if (strides.linearSize != buffer.size) { if (strides.linearSize != buffer.size) {

View File

@ -14,7 +14,6 @@ interface Structure2D<T> : NDStructure<T> {
return get(index[0], index[1]) return get(index[0], index[1])
} }
val rows: Buffer<Buffer<T>> val rows: Buffer<Buffer<T>>
get() = VirtualBuffer(rowNum) { i -> get() = VirtualBuffer(rowNum) { i ->
VirtualBuffer(colNum) { j -> get(i, j) } VirtualBuffer(colNum) { j -> get(i, j) }
@ -58,22 +57,4 @@ fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
error("Can't create 2d-structure from ${shape.size}d-structure") error("Can't create 2d-structure from ${shape.size}d-structure")
} }
/**
* Represent this 2D structure as 1D if it has exactly one column. Throw error otherwise.
*/
fun <T> Structure2D<T>.as1D() = if (colNum == 1) {
object : Structure1D<T> {
override fun get(index: Int): T = get(index, 0)
override val shape: IntArray get() = intArrayOf(rowNum)
override fun elements(): Sequence<Pair<IntArray, T>> = elements()
override val size: Int get() = rowNum
}
} else {
error("Can't convert matrix with more than one column to vector")
}
typealias Matrix<T> = Structure2D<T> typealias Matrix<T> = Structure2D<T>

View File

@ -5,7 +5,7 @@ import java.math.BigDecimal
import java.math.BigInteger import java.math.BigInteger
import java.math.MathContext import java.math.MathContext
object BigIntegerRing : Ring<BigInteger> { object JBigIntegerField : Field<BigInteger> {
override val zero: BigInteger = BigInteger.ZERO override val zero: BigInteger = BigInteger.ZERO
override val one: BigInteger = BigInteger.ONE override val one: BigInteger = BigInteger.ONE
@ -14,9 +14,11 @@ object BigIntegerRing : Ring<BigInteger> {
override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger()) override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger())
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b) override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b)
} }
class BigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Field<BigDecimal> { class JBigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Field<BigDecimal> {
override val zero: BigDecimal = BigDecimal.ZERO override val zero: BigDecimal = BigDecimal.ZERO
override val one: BigDecimal = BigDecimal.ONE override val one: BigDecimal = BigDecimal.ONE
@ -28,18 +30,3 @@ class BigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Fi
override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext) override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext)
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext) override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext)
} }
inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInteger): Buffer<BigInteger> =
boxing(size, initializer)
inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInteger): MutableBuffer<BigInteger> =
boxing(size, initializer)
fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInteger, BigIntegerRing> =
BoxingNDRing(shape, BigIntegerRing, Buffer.Companion::bigInt)
fun NDElement.Companion.bigInt(
vararg shape: Int,
initializer: BigIntegerRing.(IntArray) -> BigInteger
): BufferedNDRingElement<BigInteger, BigIntegerRing> =
NDAlgebra.bigInt(*shape).produce(initializer)

View File

@ -68,6 +68,8 @@ class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val ge
private var value: R? = null private var value: R? = null
fun value() = value
override suspend fun next(): R { override suspend fun next(): R {
mutex.withLock { mutex.withLock {
val newValue = gen(value ?: seed()) val newValue = gen(value ?: seed())
@ -97,6 +99,8 @@ class StatefulChain<S, out R>(
private var value: R? = null private var value: R? = null
fun value() = value
override suspend fun next(): R { override suspend fun next(): R {
mutex.withLock { mutex.withLock {
val newValue = state.gen(value ?: state.seed()) val newValue = state.gen(value ?: state.seed())

View File

@ -30,6 +30,20 @@ class LazyNDStructure<T>(
} }
return res.asSequence() return res.asSequence()
} }
override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
}
override fun hashCode(): Int {
var result = scope.hashCode()
result = 31 * result + shape.contentHashCode()
result = 31 * result + function.hashCode()
result = 31 * result + cache.hashCode()
return result
}
} }
fun <T> NDStructure<T>.deferred(index: IntArray) = fun <T> NDStructure<T>.deferred(index: IntArray) =

View File

@ -1,48 +0,0 @@
package scientifik.kmath.linear
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Norm
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.SpaceElement
import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.DoubleBuffer
import scientifik.kmath.structures.asBuffer
import scientifik.kmath.structures.asSequence
fun DoubleArray.asVector() = RealVector(this.asBuffer())
fun List<Double>.asVector() = RealVector(this.asBuffer())
object VectorL2Norm : Norm<Point<out Number>, Double> {
override fun norm(arg: Point<out Number>): Double =
kotlin.math.sqrt(arg.asSequence().sumByDouble { it.toDouble() })
}
inline class RealVector(val point: Point<Double>) :
SpaceElement<Point<Double>, RealVector, VectorSpace<Double, RealField>>, Point<Double> {
override val context: VectorSpace<Double, RealField> get() = space(point.size)
override fun unwrap(): Point<Double> = point
override fun Point<Double>.wrap(): RealVector = RealVector(this)
override val size: Int get() = point.size
override fun get(index: Int): Double = point[index]
override fun iterator(): Iterator<Double> = point.iterator()
companion object {
private val spaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
inline operator fun invoke(dim:Int, initalizer: (Int)-> Double) = RealVector(DoubleBuffer(dim, initalizer))
operator fun invoke(vararg values: Double) = values.asVector()
fun space(dim: Int) =
spaceCache.getOrPut(dim) {
BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) }
}
}
}

View File

@ -0,0 +1,59 @@
package scientifik.kmath.real
import scientifik.kmath.linear.BufferVectorSpace
import scientifik.kmath.linear.Point
import scientifik.kmath.linear.VectorSpace
import scientifik.kmath.operations.Norm
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.SpaceElement
import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.DoubleBuffer
import scientifik.kmath.structures.asBuffer
import scientifik.kmath.structures.asIterable
import kotlin.math.sqrt
fun DoubleArray.asVector() = RealVector(this.asBuffer())
fun List<Double>.asVector() = RealVector(this.asBuffer())
object VectorL2Norm : Norm<Point<out Number>, Double> {
override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() })
}
inline class RealVector(private val point: Point<Double>) :
SpaceElement<Point<Double>, RealVector, VectorSpace<Double, RealField>>, Point<Double> {
override val context: VectorSpace<Double, RealField>
get() = space(
point.size
)
override fun unwrap(): Point<Double> = point
override fun Point<Double>.wrap(): RealVector =
RealVector(this)
override val size: Int get() = point.size
override fun get(index: Int): Double = point[index]
override fun iterator(): Iterator<Double> = point.iterator()
companion object {
private val spaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
inline operator fun invoke(dim: Int, initializer: (Int) -> Double) =
RealVector(DoubleBuffer(dim, initializer))
operator fun invoke(vararg values: Double): RealVector = values.asVector()
fun space(dim: Int): BufferVectorSpace<Double, RealField> =
spaceCache.getOrPut(dim) {
BufferVectorSpace(
dim,
RealField
) { size, init -> Buffer.real(size, init) }
}
}
}

View File

@ -1,5 +1,6 @@
package scientifik.kmath.linear package scientifik.kmath.linear
import scientifik.kmath.real.RealVector
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals

View File

@ -1,7 +1,7 @@
package scientifik.kmath.histogram package scientifik.kmath.histogram
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
import scientifik.kmath.linear.asVector import scientifik.kmath.real.asVector
import scientifik.kmath.operations.SpaceOperations import scientifik.kmath.operations.SpaceOperations
import scientifik.kmath.structures.* import scientifik.kmath.structures.*
import kotlin.math.floor import kotlin.math.floor

View File

@ -3,7 +3,7 @@ package scietifik.kmath.histogram
import scientifik.kmath.histogram.RealHistogram import scientifik.kmath.histogram.RealHistogram
import scientifik.kmath.histogram.fill import scientifik.kmath.histogram.fill
import scientifik.kmath.histogram.put import scientifik.kmath.histogram.put
import scientifik.kmath.linear.RealVector import scientifik.kmath.real.RealVector
import kotlin.random.Random import kotlin.random.Random
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals

View File

@ -1,7 +1,7 @@
package scientifik.kmath.histogram package scientifik.kmath.histogram
import scientifik.kmath.linear.RealVector import scientifik.kmath.real.RealVector
import scientifik.kmath.linear.asVector import scientifik.kmath.real.asVector
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import java.util.* import java.util.*
import kotlin.math.floor import kotlin.math.floor

View File

@ -4,6 +4,7 @@ import koma.extensions.fill
import koma.matrix.MatrixFactory import koma.matrix.MatrixFactory
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.NDStructure
class KomaMatrixContext<T : Any>( class KomaMatrixContext<T : Any>(
private val factory: MatrixFactory<koma.matrix.Matrix<T>>, private val factory: MatrixFactory<koma.matrix.Matrix<T>>,
@ -85,6 +86,18 @@ class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<Matri
KomaMatrix(this.origin, this.features + features) KomaMatrix(this.origin, this.features + features)
override fun get(i: Int, j: Int): T = origin.getGeneric(i, j) override fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
}
override fun hashCode(): Int {
var result = origin.hashCode()
result = 31 * result + features.hashCode()
return result
}
} }
class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> { class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> {

View File

@ -1,10 +1,12 @@
pluginManagement { pluginManagement {
val toolsVersion = "0.5.0"
plugins { plugins {
id("scientifik.mpp") version "0.4.1" id("scientifik.mpp") version toolsVersion
id("scientifik.jvm") version "0.4.1" id("scientifik.jvm") version toolsVersion
id("scientifik.atomic") version "0.4.1" id("scientifik.atomic") version toolsVersion
id("scientifik.publish") version "0.4.1" id("scientifik.publish") version toolsVersion
} }
repositories { repositories {
@ -20,7 +22,7 @@ pluginManagement {
resolutionStrategy { resolutionStrategy {
eachPlugin { eachPlugin {
when (requested.id.id) { when (requested.id.id) {
"scientifik.mpp", "scientifik.jvm", "scientifik.publish" -> useModule("scientifik:gradle-tools:${requested.version}") "scientifik.mpp", "scientifik.jvm", "scientifik.publish" -> useModule("scientifik:gradle-tools:$toolsVersion")
} }
} }
} }