forked from kscience/kmath
Equlity half-fix for NDStructure
This commit is contained in:
parent
898f082a0c
commit
646207e140
@ -1,8 +1,8 @@
|
|||||||
plugins {
|
plugins {
|
||||||
id("scientifik.publish") version "0.4.2" apply false
|
id("scientifik.publish") apply false
|
||||||
}
|
}
|
||||||
|
|
||||||
val kmathVersion by extra("0.1.4-dev-4")
|
val kmathVersion by extra("0.1.4-dev-5")
|
||||||
|
|
||||||
val bintrayRepo by extra("scientifik")
|
val bintrayRepo by extra("scientifik")
|
||||||
val githubProject by extra("kmath")
|
val githubProject by extra("kmath")
|
||||||
|
@ -57,6 +57,6 @@ benchmark {
|
|||||||
|
|
||||||
tasks.withType<KotlinCompile> {
|
tasks.withType<KotlinCompile> {
|
||||||
kotlinOptions {
|
kotlinOptions {
|
||||||
jvmTarget = Scientifik.JVM_VERSION
|
jvmTarget = Scientifik.JVM_TARGET.toString()
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -5,6 +5,7 @@ import org.apache.commons.math3.linear.RealMatrix
|
|||||||
import org.apache.commons.math3.linear.RealVector
|
import org.apache.commons.math3.linear.RealVector
|
||||||
import scientifik.kmath.linear.*
|
import scientifik.kmath.linear.*
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
|
import scientifik.kmath.structures.NDStructure
|
||||||
|
|
||||||
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) :
|
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) :
|
||||||
FeaturedMatrix<Double> {
|
FeaturedMatrix<Double> {
|
||||||
@ -19,6 +20,16 @@ class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) :
|
|||||||
CMMatrix(origin, this.features + features)
|
CMMatrix(origin, this.features + features)
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
|
override fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = origin.hashCode()
|
||||||
|
result = 31 * result + features.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
|
fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
|
||||||
|
@ -2,6 +2,7 @@ package scientifik.kmath.operations
|
|||||||
|
|
||||||
import scientifik.kmath.operations.BigInt.Companion.BASE
|
import scientifik.kmath.operations.BigInt.Companion.BASE
|
||||||
import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE
|
import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE
|
||||||
|
import scientifik.kmath.structures.*
|
||||||
import kotlin.math.log2
|
import kotlin.math.log2
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
@ -482,3 +483,18 @@ fun String.parseBigInteger(): BigInt? {
|
|||||||
}
|
}
|
||||||
return res * sign
|
return res * sign
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
|
||||||
|
boxing(size, initializer)
|
||||||
|
|
||||||
|
inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> =
|
||||||
|
boxing(size, initializer)
|
||||||
|
|
||||||
|
fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> =
|
||||||
|
BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt)
|
||||||
|
|
||||||
|
fun NDElement.Companion.bigInt(
|
||||||
|
vararg shape: Int,
|
||||||
|
initializer: BigIntField.(IntArray) -> BigInt
|
||||||
|
): BufferedNDRingElement<BigInt, BigIntField> =
|
||||||
|
NDAlgebra.bigInt(*shape).produce(initializer)
|
@ -3,10 +3,10 @@ package scientifik.kmath.structures
|
|||||||
import scientifik.kmath.operations.*
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base interface for an element with context, containing strides
|
* Base class for an element with context, containing strides
|
||||||
*/
|
*/
|
||||||
interface BufferedNDElement<T, C> : NDBuffer<T>, NDElement<T, C, NDBuffer<T>> {
|
abstract class BufferedNDElement<T, C> : NDBuffer<T>(), NDElement<T, C, NDBuffer<T>> {
|
||||||
override val context: BufferedNDAlgebra<T, C>
|
abstract override val context: BufferedNDAlgebra<T, C>
|
||||||
|
|
||||||
override val strides get() = context.strides
|
override val strides get() = context.strides
|
||||||
|
|
||||||
@ -16,7 +16,7 @@ interface BufferedNDElement<T, C> : NDBuffer<T>, NDElement<T, C, NDBuffer<T>> {
|
|||||||
class BufferedNDSpaceElement<T, S : Space<T>>(
|
class BufferedNDSpaceElement<T, S : Space<T>>(
|
||||||
override val context: BufferedNDSpace<T, S>,
|
override val context: BufferedNDSpace<T, S>,
|
||||||
override val buffer: Buffer<T>
|
override val buffer: Buffer<T>
|
||||||
) : BufferedNDElement<T, S>, SpaceElement<NDBuffer<T>, BufferedNDSpaceElement<T, S>, BufferedNDSpace<T, S>> {
|
) : BufferedNDElement<T, S>(), SpaceElement<NDBuffer<T>, BufferedNDSpaceElement<T, S>, BufferedNDSpace<T, S>> {
|
||||||
|
|
||||||
override fun unwrap(): NDBuffer<T> = this
|
override fun unwrap(): NDBuffer<T> = this
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ class BufferedNDSpaceElement<T, S : Space<T>>(
|
|||||||
class BufferedNDRingElement<T, R : Ring<T>>(
|
class BufferedNDRingElement<T, R : Ring<T>>(
|
||||||
override val context: BufferedNDRing<T, R>,
|
override val context: BufferedNDRing<T, R>,
|
||||||
override val buffer: Buffer<T>
|
override val buffer: Buffer<T>
|
||||||
) : BufferedNDElement<T, R>, RingElement<NDBuffer<T>, BufferedNDRingElement<T, R>, BufferedNDRing<T, R>> {
|
) : BufferedNDElement<T, R>(), RingElement<NDBuffer<T>, BufferedNDRingElement<T, R>, BufferedNDRing<T, R>> {
|
||||||
|
|
||||||
override fun unwrap(): NDBuffer<T> = this
|
override fun unwrap(): NDBuffer<T> = this
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ class BufferedNDRingElement<T, R : Ring<T>>(
|
|||||||
class BufferedNDFieldElement<T, F : Field<T>>(
|
class BufferedNDFieldElement<T, F : Field<T>>(
|
||||||
override val context: BufferedNDField<T, F>,
|
override val context: BufferedNDField<T, F>,
|
||||||
override val buffer: Buffer<T>
|
override val buffer: Buffer<T>
|
||||||
) : BufferedNDElement<T, F>, FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> {
|
) : BufferedNDElement<T, F>(), FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> {
|
||||||
|
|
||||||
override fun unwrap(): NDBuffer<T> = this
|
override fun unwrap(): NDBuffer<T> = this
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ interface Buffer<T> {
|
|||||||
|
|
||||||
fun <T> Buffer<T>.asSequence(): Sequence<T> = Sequence(::iterator)
|
fun <T> Buffer<T>.asSequence(): Sequence<T> = Sequence(::iterator)
|
||||||
|
|
||||||
fun <T> Buffer<T>.asIterable(): Iterable<T> = asSequence().asIterable()
|
fun <T> Buffer<T>.asIterable(): Iterable<T> = Iterable(::iterator)
|
||||||
|
|
||||||
val Buffer<*>.indices: IntRange get() = IntRange(0, size - 1)
|
val Buffer<*>.indices: IntRange get() = IntRange(0, size - 1)
|
||||||
|
|
||||||
|
@ -14,15 +14,21 @@ interface NDStructure<T> {
|
|||||||
|
|
||||||
fun elements(): Sequence<Pair<IntArray, T>>
|
fun elements(): Sequence<Pair<IntArray, T>>
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean
|
||||||
|
|
||||||
|
override fun hashCode(): Int
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
||||||
return when {
|
if(st1===st2) return true
|
||||||
st1 === st2 -> true
|
|
||||||
st1 is BufferNDStructure<*> && st2 is BufferNDStructure<*> && st1.strides == st2.strides -> st1.buffer.contentEquals(
|
// fast comparison of buffers if possible
|
||||||
st2.buffer
|
if(st1 is NDBuffer && st2 is NDBuffer && st1.strides == st2.strides){
|
||||||
)
|
return st1.buffer.contentEquals(st2.buffer)
|
||||||
else -> st1.elements().all { (index, value) -> value == st2[index] }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//element by element comparison if it could not be avoided
|
||||||
|
return st1.elements().all { (index, value) -> value == st2[index] }
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -177,15 +183,25 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
interface NDBuffer<T> : NDStructure<T> {
|
abstract class NDBuffer<T> : NDStructure<T> {
|
||||||
val buffer: Buffer<T>
|
abstract val buffer: Buffer<T>
|
||||||
val strides: Strides
|
abstract val strides: Strides
|
||||||
|
|
||||||
override fun get(index: IntArray): T = buffer[strides.offset(index)]
|
override fun get(index: IntArray): T = buffer[strides.offset(index)]
|
||||||
|
|
||||||
override val shape: IntArray get() = strides.shape
|
override val shape: IntArray get() = strides.shape
|
||||||
|
|
||||||
override fun elements() = strides.indices().map { it to this[it] }
|
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to this[it] }
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = strides.hashCode()
|
||||||
|
result = 31 * result + buffer.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -194,34 +210,12 @@ interface NDBuffer<T> : NDStructure<T> {
|
|||||||
class BufferNDStructure<T>(
|
class BufferNDStructure<T>(
|
||||||
override val strides: Strides,
|
override val strides: Strides,
|
||||||
override val buffer: Buffer<T>
|
override val buffer: Buffer<T>
|
||||||
) : NDBuffer<T> {
|
) : NDBuffer<T>(){
|
||||||
|
|
||||||
init {
|
init {
|
||||||
if (strides.linearSize != buffer.size) {
|
if (strides.linearSize != buffer.size) {
|
||||||
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
|
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun get(index: IntArray): T = buffer[strides.offset(index)]
|
|
||||||
|
|
||||||
override val shape: IntArray get() = strides.shape
|
|
||||||
|
|
||||||
override fun elements() = strides.indices().map { it to this[it] }
|
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
|
||||||
return when {
|
|
||||||
this === other -> true
|
|
||||||
other is BufferNDStructure<*> && this.strides == other.strides -> this.buffer.contentEquals(other.buffer)
|
|
||||||
other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] }
|
|
||||||
else -> false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun hashCode(): Int {
|
|
||||||
var result = strides.hashCode()
|
|
||||||
result = 31 * result + buffer.hashCode()
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -245,7 +239,7 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
|||||||
class MutableBufferNDStructure<T>(
|
class MutableBufferNDStructure<T>(
|
||||||
override val strides: Strides,
|
override val strides: Strides,
|
||||||
override val buffer: MutableBuffer<T>
|
override val buffer: MutableBuffer<T>
|
||||||
) : NDBuffer<T>, MutableNDStructure<T> {
|
) : NDBuffer<T>(), MutableNDStructure<T> {
|
||||||
|
|
||||||
init {
|
init {
|
||||||
if (strides.linearSize != buffer.size) {
|
if (strides.linearSize != buffer.size) {
|
||||||
|
@ -14,7 +14,6 @@ interface Structure2D<T> : NDStructure<T> {
|
|||||||
return get(index[0], index[1])
|
return get(index[0], index[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
val rows: Buffer<Buffer<T>>
|
val rows: Buffer<Buffer<T>>
|
||||||
get() = VirtualBuffer(rowNum) { i ->
|
get() = VirtualBuffer(rowNum) { i ->
|
||||||
VirtualBuffer(colNum) { j -> get(i, j) }
|
VirtualBuffer(colNum) { j -> get(i, j) }
|
||||||
@ -58,22 +57,4 @@ fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
|
|||||||
error("Can't create 2d-structure from ${shape.size}d-structure")
|
error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Represent this 2D structure as 1D if it has exactly one column. Throw error otherwise.
|
|
||||||
*/
|
|
||||||
fun <T> Structure2D<T>.as1D() = if (colNum == 1) {
|
|
||||||
object : Structure1D<T> {
|
|
||||||
override fun get(index: Int): T = get(index, 0)
|
|
||||||
|
|
||||||
override val shape: IntArray get() = intArrayOf(rowNum)
|
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = elements()
|
|
||||||
|
|
||||||
override val size: Int get() = rowNum
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
error("Can't convert matrix with more than one column to vector")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
typealias Matrix<T> = Structure2D<T>
|
typealias Matrix<T> = Structure2D<T>
|
@ -5,7 +5,7 @@ import java.math.BigDecimal
|
|||||||
import java.math.BigInteger
|
import java.math.BigInteger
|
||||||
import java.math.MathContext
|
import java.math.MathContext
|
||||||
|
|
||||||
object BigIntegerRing : Ring<BigInteger> {
|
object JBigIntegerField : Field<BigInteger> {
|
||||||
override val zero: BigInteger = BigInteger.ZERO
|
override val zero: BigInteger = BigInteger.ZERO
|
||||||
override val one: BigInteger = BigInteger.ONE
|
override val one: BigInteger = BigInteger.ONE
|
||||||
|
|
||||||
@ -14,9 +14,11 @@ object BigIntegerRing : Ring<BigInteger> {
|
|||||||
override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger())
|
override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger())
|
||||||
|
|
||||||
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
|
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
|
||||||
|
|
||||||
|
override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
class BigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Field<BigDecimal> {
|
class JBigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Field<BigDecimal> {
|
||||||
override val zero: BigDecimal = BigDecimal.ZERO
|
override val zero: BigDecimal = BigDecimal.ZERO
|
||||||
override val one: BigDecimal = BigDecimal.ONE
|
override val one: BigDecimal = BigDecimal.ONE
|
||||||
|
|
||||||
@ -28,18 +30,3 @@ class BigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Fi
|
|||||||
override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext)
|
override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext)
|
||||||
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext)
|
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext)
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInteger): Buffer<BigInteger> =
|
|
||||||
boxing(size, initializer)
|
|
||||||
|
|
||||||
inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInteger): MutableBuffer<BigInteger> =
|
|
||||||
boxing(size, initializer)
|
|
||||||
|
|
||||||
fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInteger, BigIntegerRing> =
|
|
||||||
BoxingNDRing(shape, BigIntegerRing, Buffer.Companion::bigInt)
|
|
||||||
|
|
||||||
fun NDElement.Companion.bigInt(
|
|
||||||
vararg shape: Int,
|
|
||||||
initializer: BigIntegerRing.(IntArray) -> BigInteger
|
|
||||||
): BufferedNDRingElement<BigInteger, BigIntegerRing> =
|
|
||||||
NDAlgebra.bigInt(*shape).produce(initializer)
|
|
@ -68,6 +68,8 @@ class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val ge
|
|||||||
|
|
||||||
private var value: R? = null
|
private var value: R? = null
|
||||||
|
|
||||||
|
fun value() = value
|
||||||
|
|
||||||
override suspend fun next(): R {
|
override suspend fun next(): R {
|
||||||
mutex.withLock {
|
mutex.withLock {
|
||||||
val newValue = gen(value ?: seed())
|
val newValue = gen(value ?: seed())
|
||||||
@ -97,6 +99,8 @@ class StatefulChain<S, out R>(
|
|||||||
|
|
||||||
private var value: R? = null
|
private var value: R? = null
|
||||||
|
|
||||||
|
fun value() = value
|
||||||
|
|
||||||
override suspend fun next(): R {
|
override suspend fun next(): R {
|
||||||
mutex.withLock {
|
mutex.withLock {
|
||||||
val newValue = state.gen(value ?: state.seed())
|
val newValue = state.gen(value ?: state.seed())
|
||||||
|
@ -30,6 +30,20 @@ class LazyNDStructure<T>(
|
|||||||
}
|
}
|
||||||
return res.asSequence()
|
return res.asSequence()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = scope.hashCode()
|
||||||
|
result = 31 * result + shape.contentHashCode()
|
||||||
|
result = 31 * result + function.hashCode()
|
||||||
|
result = 31 * result + cache.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T> NDStructure<T>.deferred(index: IntArray) =
|
fun <T> NDStructure<T>.deferred(index: IntArray) =
|
||||||
|
@ -1,48 +0,0 @@
|
|||||||
package scientifik.kmath.linear
|
|
||||||
|
|
||||||
import scientifik.kmath.operations.Field
|
|
||||||
import scientifik.kmath.operations.Norm
|
|
||||||
import scientifik.kmath.operations.RealField
|
|
||||||
import scientifik.kmath.operations.SpaceElement
|
|
||||||
import scientifik.kmath.structures.Buffer
|
|
||||||
import scientifik.kmath.structures.DoubleBuffer
|
|
||||||
import scientifik.kmath.structures.asBuffer
|
|
||||||
import scientifik.kmath.structures.asSequence
|
|
||||||
|
|
||||||
fun DoubleArray.asVector() = RealVector(this.asBuffer())
|
|
||||||
fun List<Double>.asVector() = RealVector(this.asBuffer())
|
|
||||||
|
|
||||||
|
|
||||||
object VectorL2Norm : Norm<Point<out Number>, Double> {
|
|
||||||
override fun norm(arg: Point<out Number>): Double =
|
|
||||||
kotlin.math.sqrt(arg.asSequence().sumByDouble { it.toDouble() })
|
|
||||||
}
|
|
||||||
|
|
||||||
inline class RealVector(val point: Point<Double>) :
|
|
||||||
SpaceElement<Point<Double>, RealVector, VectorSpace<Double, RealField>>, Point<Double> {
|
|
||||||
override val context: VectorSpace<Double, RealField> get() = space(point.size)
|
|
||||||
|
|
||||||
override fun unwrap(): Point<Double> = point
|
|
||||||
|
|
||||||
override fun Point<Double>.wrap(): RealVector = RealVector(this)
|
|
||||||
|
|
||||||
override val size: Int get() = point.size
|
|
||||||
|
|
||||||
override fun get(index: Int): Double = point[index]
|
|
||||||
|
|
||||||
override fun iterator(): Iterator<Double> = point.iterator()
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
|
|
||||||
private val spaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
|
|
||||||
|
|
||||||
inline operator fun invoke(dim:Int, initalizer: (Int)-> Double) = RealVector(DoubleBuffer(dim, initalizer))
|
|
||||||
|
|
||||||
operator fun invoke(vararg values: Double) = values.asVector()
|
|
||||||
|
|
||||||
fun space(dim: Int) =
|
|
||||||
spaceCache.getOrPut(dim) {
|
|
||||||
BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -0,0 +1,59 @@
|
|||||||
|
package scientifik.kmath.real
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.BufferVectorSpace
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
import scientifik.kmath.linear.VectorSpace
|
||||||
|
import scientifik.kmath.operations.Norm
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import scientifik.kmath.operations.SpaceElement
|
||||||
|
import scientifik.kmath.structures.Buffer
|
||||||
|
import scientifik.kmath.structures.DoubleBuffer
|
||||||
|
import scientifik.kmath.structures.asBuffer
|
||||||
|
import scientifik.kmath.structures.asIterable
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
fun DoubleArray.asVector() = RealVector(this.asBuffer())
|
||||||
|
fun List<Double>.asVector() = RealVector(this.asBuffer())
|
||||||
|
|
||||||
|
|
||||||
|
object VectorL2Norm : Norm<Point<out Number>, Double> {
|
||||||
|
override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() })
|
||||||
|
}
|
||||||
|
|
||||||
|
inline class RealVector(private val point: Point<Double>) :
|
||||||
|
SpaceElement<Point<Double>, RealVector, VectorSpace<Double, RealField>>, Point<Double> {
|
||||||
|
|
||||||
|
override val context: VectorSpace<Double, RealField>
|
||||||
|
get() = space(
|
||||||
|
point.size
|
||||||
|
)
|
||||||
|
|
||||||
|
override fun unwrap(): Point<Double> = point
|
||||||
|
|
||||||
|
override fun Point<Double>.wrap(): RealVector =
|
||||||
|
RealVector(this)
|
||||||
|
|
||||||
|
override val size: Int get() = point.size
|
||||||
|
|
||||||
|
override fun get(index: Int): Double = point[index]
|
||||||
|
|
||||||
|
override fun iterator(): Iterator<Double> = point.iterator()
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
|
||||||
|
private val spaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
|
||||||
|
|
||||||
|
inline operator fun invoke(dim: Int, initializer: (Int) -> Double) =
|
||||||
|
RealVector(DoubleBuffer(dim, initializer))
|
||||||
|
|
||||||
|
operator fun invoke(vararg values: Double): RealVector = values.asVector()
|
||||||
|
|
||||||
|
fun space(dim: Int): BufferVectorSpace<Double, RealField> =
|
||||||
|
spaceCache.getOrPut(dim) {
|
||||||
|
BufferVectorSpace(
|
||||||
|
dim,
|
||||||
|
RealField
|
||||||
|
) { size, init -> Buffer.real(size, init) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,5 +1,6 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.real.RealVector
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package scientifik.kmath.histogram
|
package scientifik.kmath.histogram
|
||||||
|
|
||||||
import scientifik.kmath.linear.Point
|
import scientifik.kmath.linear.Point
|
||||||
import scientifik.kmath.linear.asVector
|
import scientifik.kmath.real.asVector
|
||||||
import scientifik.kmath.operations.SpaceOperations
|
import scientifik.kmath.operations.SpaceOperations
|
||||||
import scientifik.kmath.structures.*
|
import scientifik.kmath.structures.*
|
||||||
import kotlin.math.floor
|
import kotlin.math.floor
|
||||||
|
@ -3,7 +3,7 @@ package scietifik.kmath.histogram
|
|||||||
import scientifik.kmath.histogram.RealHistogram
|
import scientifik.kmath.histogram.RealHistogram
|
||||||
import scientifik.kmath.histogram.fill
|
import scientifik.kmath.histogram.fill
|
||||||
import scientifik.kmath.histogram.put
|
import scientifik.kmath.histogram.put
|
||||||
import scientifik.kmath.linear.RealVector
|
import scientifik.kmath.real.RealVector
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package scientifik.kmath.histogram
|
package scientifik.kmath.histogram
|
||||||
|
|
||||||
import scientifik.kmath.linear.RealVector
|
import scientifik.kmath.real.RealVector
|
||||||
import scientifik.kmath.linear.asVector
|
import scientifik.kmath.real.asVector
|
||||||
import scientifik.kmath.structures.Buffer
|
import scientifik.kmath.structures.Buffer
|
||||||
import java.util.*
|
import java.util.*
|
||||||
import kotlin.math.floor
|
import kotlin.math.floor
|
||||||
|
@ -4,6 +4,7 @@ import koma.extensions.fill
|
|||||||
import koma.matrix.MatrixFactory
|
import koma.matrix.MatrixFactory
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
|
import scientifik.kmath.structures.NDStructure
|
||||||
|
|
||||||
class KomaMatrixContext<T : Any>(
|
class KomaMatrixContext<T : Any>(
|
||||||
private val factory: MatrixFactory<koma.matrix.Matrix<T>>,
|
private val factory: MatrixFactory<koma.matrix.Matrix<T>>,
|
||||||
@ -85,6 +86,18 @@ class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<Matri
|
|||||||
KomaMatrix(this.origin, this.features + features)
|
KomaMatrix(this.origin, this.features + features)
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
|
override fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = origin.hashCode()
|
||||||
|
result = 31 * result + features.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> {
|
class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> {
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
pluginManagement {
|
pluginManagement {
|
||||||
|
|
||||||
|
val toolsVersion = "0.5.0"
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
id("scientifik.mpp") version "0.4.1"
|
id("scientifik.mpp") version toolsVersion
|
||||||
id("scientifik.jvm") version "0.4.1"
|
id("scientifik.jvm") version toolsVersion
|
||||||
id("scientifik.atomic") version "0.4.1"
|
id("scientifik.atomic") version toolsVersion
|
||||||
id("scientifik.publish") version "0.4.1"
|
id("scientifik.publish") version toolsVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
@ -20,7 +22,7 @@ pluginManagement {
|
|||||||
resolutionStrategy {
|
resolutionStrategy {
|
||||||
eachPlugin {
|
eachPlugin {
|
||||||
when (requested.id.id) {
|
when (requested.id.id) {
|
||||||
"scientifik.mpp", "scientifik.jvm", "scientifik.publish" -> useModule("scientifik:gradle-tools:${requested.version}")
|
"scientifik.mpp", "scientifik.jvm", "scientifik.publish" -> useModule("scientifik:gradle-tools:$toolsVersion")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user