Exact conversions from Long to Int, Int indexing of Dimension

This commit is contained in:
Iaroslav Postovalov 2021-06-20 17:57:33 +07:00 committed by Iaroslav Postovalov
parent ec8f14a6e9
commit 8b3298f7a8
14 changed files with 83 additions and 36 deletions

View File

@ -19,7 +19,7 @@ private fun DMatrixContext<Double, *>.simple() {
}
private object D5 : Dimension {
override val dim: UInt = 5u
override val dim: Int = 5
}
private fun DMatrixContext<Double, *>.custom() {

View File

@ -8,6 +8,7 @@ package space.kscience.kmath.commons.random
import kotlinx.coroutines.runBlocking
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.samplers.GaussianSampler
import space.kscience.kmath.misc.toIntExact
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.next
@ -28,7 +29,7 @@ public class CMRandomGeneratorWrapper(
}
override fun setSeed(seed: Long) {
setSeed(seed.toInt())
setSeed(seed.toIntExact())
}
override fun nextBytes(bytes: ByteArray) {

View File

@ -0,0 +1,8 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.misc
public expect fun Long.toIntExact(): Int

View File

@ -0,0 +1,12 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.misc
public actual fun Long.toIntExact(): Int {
val i = toInt()
if (i.toLong() == this) throw ArithmeticException("integer overflow")
return i
}

View File

@ -0,0 +1,8 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.misc
public actual fun Long.toIntExact(): Int = Math.toIntExact(this)

View File

@ -0,0 +1,12 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.misc
public actual fun Long.toIntExact(): Int {
val i = toInt()
if (i.toLong() == this) throw ArithmeticException("integer overflow")
return i
}

View File

@ -8,47 +8,47 @@ package space.kscience.kmath.dimensions
import kotlin.reflect.KClass
/**
* Represents a quantity of dimensions in certain structure.
* Represents a quantity of dimensions in certain structure. **This interface must be implemented only by objects.**
*
* @property dim The number of dimensions.
*/
public interface Dimension {
public val dim: UInt
public val dim: Int
public companion object
}
public fun <D : Dimension> KClass<D>.dim(): UInt = Dimension.resolve(this).dim
public fun <D : Dimension> KClass<D>.dim(): Int = Dimension.resolve(this).dim
public expect fun <D : Dimension> Dimension.Companion.resolve(type: KClass<D>): D
/**
* Finds or creates [Dimension] with [Dimension.dim] equal to [dim].
*/
public expect fun Dimension.Companion.of(dim: UInt): Dimension
public expect fun Dimension.Companion.of(dim: Int): Dimension
/**
* Finds [Dimension.dim] of given type [D].
*/
public inline fun <reified D : Dimension> Dimension.Companion.dim(): UInt = D::class.dim()
public inline fun <reified D : Dimension> Dimension.Companion.dim(): Int = D::class.dim()
/**
* Type representing 1 dimension.
*/
public object D1 : Dimension {
override val dim: UInt get() = 1U
override val dim: Int get() = 1
}
/**
* Type representing 2 dimensions.
*/
public object D2 : Dimension {
override val dim: UInt get() = 2U
override val dim: Int get() = 2
}
/**
* Type representing 3 dimensions.
*/
public object D3 : Dimension {
override val dim: UInt get() = 3U
override val dim: Int get() = 3
}

View File

@ -23,11 +23,11 @@ public interface DMatrix<out T, R : Dimension, C : Dimension> : Structure2D<T> {
* Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed
*/
public inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> {
require(structure.rowNum == Dimension.dim<R>().toInt()) {
require(structure.rowNum == Dimension.dim<R>()) {
"Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}"
}
require(structure.colNum == Dimension.dim<C>().toInt()) {
require(structure.colNum == Dimension.dim<C>()) {
"Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}"
}
@ -61,7 +61,7 @@ public value class DMatrixWrapper<out T, R : Dimension, C : Dimension>(
public interface DPoint<out T, D : Dimension> : Point<T> {
public companion object {
public inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> {
require(point.size == Dimension.dim<D>().toInt()) {
require(point.size == Dimension.dim<D>()) {
"Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}"
}
@ -92,11 +92,11 @@ public value class DPointWrapper<out T, D : Dimension>(public val point: Point<T
@JvmInline
public value class DMatrixContext<T : Any, out A : Ring<T>>(public val context: LinearSpace<T, A>) {
public inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
require(rowNum == Dimension.dim<R>().toInt()) {
require(rowNum == Dimension.dim<R>()) {
"Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum"
}
require(colNum == Dimension.dim<C>().toInt()) {
require(colNum == Dimension.dim<C>()) {
"Column number mismatch: expected ${Dimension.dim<C>()} but found $colNum"
}
@ -111,7 +111,7 @@ public value class DMatrixContext<T : Any, out A : Ring<T>>(public val context:
): DMatrix<T, R, C> {
val rows = Dimension.dim<R>()
val cols = Dimension.dim<C>()
return context.buildMatrix(rows.toInt(), cols.toInt(), initializer).coerce()
return context.buildMatrix(rows, cols, initializer).coerce()
}
public inline fun <reified D : Dimension> point(noinline initializer: A.(Int) -> T): DPoint<T, D> {
@ -119,7 +119,7 @@ public value class DMatrixContext<T : Any, out A : Ring<T>>(public val context:
return DPoint.coerceUnsafe(
context.buildVector(
size.toInt(),
size,
initializer
)
)

View File

@ -7,17 +7,17 @@ package space.kscience.kmath.dimensions
import kotlin.reflect.KClass
private val dimensionMap: MutableMap<UInt, Dimension> = hashMapOf(1u to D1, 2u to D2, 3u to D3)
private val dimensionMap: MutableMap<Int, Dimension> = hashMapOf(1 to D1, 2 to D2, 3 to D3)
@Suppress("UNCHECKED_CAST")
public actual fun <D : Dimension> Dimension.Companion.resolve(type: KClass<D>): D = dimensionMap
.entries
.map(MutableMap.MutableEntry<UInt, Dimension>::value)
.map(MutableMap.MutableEntry<Int, Dimension>::value)
.find { it::class == type } as? D
?: error("Can't resolve dimension $type")
public actual fun Dimension.Companion.of(dim: UInt): Dimension = dimensionMap.getOrPut(dim) {
public actual fun Dimension.Companion.of(dim: Int): Dimension = dimensionMap.getOrPut(dim) {
object : Dimension {
override val dim: UInt get() = dim
override val dim: Int get() = dim
}
}

View File

@ -3,6 +3,8 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
@file:JvmName("DimensionJVM")
package space.kscience.kmath.dimensions
import kotlin.reflect.KClass
@ -10,12 +12,12 @@ import kotlin.reflect.KClass
public actual fun <D : Dimension> Dimension.Companion.resolve(type: KClass<D>): D =
type.objectInstance ?: error("No object instance for dimension class")
public actual fun Dimension.Companion.of(dim: UInt): Dimension = when (dim) {
1u -> D1
2u -> D2
3u -> D3
public actual fun Dimension.Companion.of(dim: Int): Dimension = when (dim) {
1 -> D1
2 -> D2
3 -> D3
else -> object : Dimension {
override val dim: UInt get() = dim
override val dim: Int get() = dim
}
}

View File

@ -9,17 +9,17 @@ import kotlin.native.concurrent.ThreadLocal
import kotlin.reflect.KClass
@ThreadLocal
private val dimensionMap: MutableMap<UInt, Dimension> = hashMapOf(1u to D1, 2u to D2, 3u to D3)
private val dimensionMap: MutableMap<Int, Dimension> = hashMapOf(1 to D1, 2 to D2, 3 to D3)
@Suppress("UNCHECKED_CAST")
public actual fun <D : Dimension> Dimension.Companion.resolve(type: KClass<D>): D = dimensionMap
.entries
.map(MutableMap.MutableEntry<UInt, Dimension>::value)
.map(MutableMap.MutableEntry<Int, Dimension>::value)
.find { it::class == type } as? D
?: error("Can't resolve dimension $type")
public actual fun Dimension.Companion.of(dim: UInt): Dimension = dimensionMap.getOrPut(dim) {
public actual fun Dimension.Companion.of(dim: Int): Dimension = dimensionMap.getOrPut(dim) {
object : Dimension {
override val dim: UInt get() = dim
override val dim: Int get() = dim
}
}

View File

@ -5,4 +5,6 @@
package space.kscience.kmath.nd4j
internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() }
import space.kscience.kmath.misc.toIntExact
internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toIntExact() }

View File

@ -7,6 +7,7 @@ package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingIntChain
import space.kscience.kmath.internal.InternalUtils
import space.kscience.kmath.misc.toIntExact
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler
import space.kscience.kmath.structures.IntBuffer
@ -119,7 +120,7 @@ public class LargeMeanPoissonSampler(public val mean: Double) : Sampler<Int> {
val gaussian = ZigguratNormalizedGaussianSampler.sample(generator)
val smallMeanPoissonSampler = if (mean - lambda < Double.MIN_VALUE) {
null
null
} else {
KempSmallMeanPoissonSampler(mean - lambda).sample(generator)
}
@ -188,7 +189,7 @@ public class LargeMeanPoissonSampler(public val mean: Double) : Sampler<Int> {
}
}
return min(y2 + y.toLong(), Int.MAX_VALUE.toLong()).toInt()
return min(y2 + y.toLong(), Int.MAX_VALUE.toLong()).toIntExact()
}
override fun nextBufferBlocking(size: Int): IntBuffer = IntBuffer(size) { nextBlocking() }

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.misc.toIntExact
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.structures.DoubleBuffer
import kotlin.math.*
@ -58,7 +59,7 @@ public object ZigguratNormalizedGaussianSampler : NormalizedGaussianSampler {
private fun sampleOne(generator: RandomGenerator): Double {
val j = generator.nextLong()
val i = (j and LAST.toLong()).toInt()
val i = (j and LAST.toLong()).toIntExact()
return if (abs(j) < K[i]) j * W[i] else fix(generator, j, i)
}