Exact conversions from Long to Int, Int indexing of Dimension

This commit is contained in:
Iaroslav Postovalov 2021-06-20 17:57:33 +07:00 committed by Iaroslav Postovalov
parent ec8f14a6e9
commit 8b3298f7a8
14 changed files with 83 additions and 36 deletions

View File

@ -19,7 +19,7 @@ private fun DMatrixContext<Double, *>.simple() {
} }
private object D5 : Dimension { private object D5 : Dimension {
override val dim: UInt = 5u override val dim: Int = 5
} }
private fun DMatrixContext<Double, *>.custom() { private fun DMatrixContext<Double, *>.custom() {

View File

@ -8,6 +8,7 @@ package space.kscience.kmath.commons.random
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.samplers.GaussianSampler import space.kscience.kmath.samplers.GaussianSampler
import space.kscience.kmath.misc.toIntExact
import space.kscience.kmath.stat.RandomGenerator import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.next import space.kscience.kmath.stat.next
@ -28,7 +29,7 @@ public class CMRandomGeneratorWrapper(
} }
override fun setSeed(seed: Long) { override fun setSeed(seed: Long) {
setSeed(seed.toInt()) setSeed(seed.toIntExact())
} }
override fun nextBytes(bytes: ByteArray) { override fun nextBytes(bytes: ByteArray) {

View File

@ -0,0 +1,8 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.misc
public expect fun Long.toIntExact(): Int

View File

@ -0,0 +1,12 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.misc
public actual fun Long.toIntExact(): Int {
val i = toInt()
if (i.toLong() == this) throw ArithmeticException("integer overflow")
return i
}

View File

@ -0,0 +1,8 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.misc
public actual fun Long.toIntExact(): Int = Math.toIntExact(this)

View File

@ -0,0 +1,12 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.misc
public actual fun Long.toIntExact(): Int {
val i = toInt()
if (i.toLong() == this) throw ArithmeticException("integer overflow")
return i
}

View File

@ -8,47 +8,47 @@ package space.kscience.kmath.dimensions
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
* Represents a quantity of dimensions in certain structure. * Represents a quantity of dimensions in certain structure. **This interface must be implemented only by objects.**
* *
* @property dim The number of dimensions. * @property dim The number of dimensions.
*/ */
public interface Dimension { public interface Dimension {
public val dim: UInt public val dim: Int
public companion object public companion object
} }
public fun <D : Dimension> KClass<D>.dim(): UInt = Dimension.resolve(this).dim public fun <D : Dimension> KClass<D>.dim(): Int = Dimension.resolve(this).dim
public expect fun <D : Dimension> Dimension.Companion.resolve(type: KClass<D>): D public expect fun <D : Dimension> Dimension.Companion.resolve(type: KClass<D>): D
/** /**
* Finds or creates [Dimension] with [Dimension.dim] equal to [dim]. * Finds or creates [Dimension] with [Dimension.dim] equal to [dim].
*/ */
public expect fun Dimension.Companion.of(dim: UInt): Dimension public expect fun Dimension.Companion.of(dim: Int): Dimension
/** /**
* Finds [Dimension.dim] of given type [D]. * Finds [Dimension.dim] of given type [D].
*/ */
public inline fun <reified D : Dimension> Dimension.Companion.dim(): UInt = D::class.dim() public inline fun <reified D : Dimension> Dimension.Companion.dim(): Int = D::class.dim()
/** /**
* Type representing 1 dimension. * Type representing 1 dimension.
*/ */
public object D1 : Dimension { public object D1 : Dimension {
override val dim: UInt get() = 1U override val dim: Int get() = 1
} }
/** /**
* Type representing 2 dimensions. * Type representing 2 dimensions.
*/ */
public object D2 : Dimension { public object D2 : Dimension {
override val dim: UInt get() = 2U override val dim: Int get() = 2
} }
/** /**
* Type representing 3 dimensions. * Type representing 3 dimensions.
*/ */
public object D3 : Dimension { public object D3 : Dimension {
override val dim: UInt get() = 3U override val dim: Int get() = 3
} }

View File

@ -23,11 +23,11 @@ public interface DMatrix<out T, R : Dimension, C : Dimension> : Structure2D<T> {
* Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed * Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed
*/ */
public inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> { public inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> {
require(structure.rowNum == Dimension.dim<R>().toInt()) { require(structure.rowNum == Dimension.dim<R>()) {
"Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}" "Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}"
} }
require(structure.colNum == Dimension.dim<C>().toInt()) { require(structure.colNum == Dimension.dim<C>()) {
"Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}" "Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}"
} }
@ -61,7 +61,7 @@ public value class DMatrixWrapper<out T, R : Dimension, C : Dimension>(
public interface DPoint<out T, D : Dimension> : Point<T> { public interface DPoint<out T, D : Dimension> : Point<T> {
public companion object { public companion object {
public inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> { public inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> {
require(point.size == Dimension.dim<D>().toInt()) { require(point.size == Dimension.dim<D>()) {
"Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}" "Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}"
} }
@ -92,11 +92,11 @@ public value class DPointWrapper<out T, D : Dimension>(public val point: Point<T
@JvmInline @JvmInline
public value class DMatrixContext<T : Any, out A : Ring<T>>(public val context: LinearSpace<T, A>) { public value class DMatrixContext<T : Any, out A : Ring<T>>(public val context: LinearSpace<T, A>) {
public inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> { public inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
require(rowNum == Dimension.dim<R>().toInt()) { require(rowNum == Dimension.dim<R>()) {
"Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum" "Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum"
} }
require(colNum == Dimension.dim<C>().toInt()) { require(colNum == Dimension.dim<C>()) {
"Column number mismatch: expected ${Dimension.dim<C>()} but found $colNum" "Column number mismatch: expected ${Dimension.dim<C>()} but found $colNum"
} }
@ -111,7 +111,7 @@ public value class DMatrixContext<T : Any, out A : Ring<T>>(public val context:
): DMatrix<T, R, C> { ): DMatrix<T, R, C> {
val rows = Dimension.dim<R>() val rows = Dimension.dim<R>()
val cols = Dimension.dim<C>() val cols = Dimension.dim<C>()
return context.buildMatrix(rows.toInt(), cols.toInt(), initializer).coerce() return context.buildMatrix(rows, cols, initializer).coerce()
} }
public inline fun <reified D : Dimension> point(noinline initializer: A.(Int) -> T): DPoint<T, D> { public inline fun <reified D : Dimension> point(noinline initializer: A.(Int) -> T): DPoint<T, D> {
@ -119,7 +119,7 @@ public value class DMatrixContext<T : Any, out A : Ring<T>>(public val context:
return DPoint.coerceUnsafe( return DPoint.coerceUnsafe(
context.buildVector( context.buildVector(
size.toInt(), size,
initializer initializer
) )
) )

View File

@ -7,17 +7,17 @@ package space.kscience.kmath.dimensions
import kotlin.reflect.KClass import kotlin.reflect.KClass
private val dimensionMap: MutableMap<UInt, Dimension> = hashMapOf(1u to D1, 2u to D2, 3u to D3) private val dimensionMap: MutableMap<Int, Dimension> = hashMapOf(1 to D1, 2 to D2, 3 to D3)
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
public actual fun <D : Dimension> Dimension.Companion.resolve(type: KClass<D>): D = dimensionMap public actual fun <D : Dimension> Dimension.Companion.resolve(type: KClass<D>): D = dimensionMap
.entries .entries
.map(MutableMap.MutableEntry<UInt, Dimension>::value) .map(MutableMap.MutableEntry<Int, Dimension>::value)
.find { it::class == type } as? D .find { it::class == type } as? D
?: error("Can't resolve dimension $type") ?: error("Can't resolve dimension $type")
public actual fun Dimension.Companion.of(dim: UInt): Dimension = dimensionMap.getOrPut(dim) { public actual fun Dimension.Companion.of(dim: Int): Dimension = dimensionMap.getOrPut(dim) {
object : Dimension { object : Dimension {
override val dim: UInt get() = dim override val dim: Int get() = dim
} }
} }

View File

@ -3,6 +3,8 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/ */
@file:JvmName("DimensionJVM")
package space.kscience.kmath.dimensions package space.kscience.kmath.dimensions
import kotlin.reflect.KClass import kotlin.reflect.KClass
@ -10,12 +12,12 @@ import kotlin.reflect.KClass
public actual fun <D : Dimension> Dimension.Companion.resolve(type: KClass<D>): D = public actual fun <D : Dimension> Dimension.Companion.resolve(type: KClass<D>): D =
type.objectInstance ?: error("No object instance for dimension class") type.objectInstance ?: error("No object instance for dimension class")
public actual fun Dimension.Companion.of(dim: UInt): Dimension = when (dim) { public actual fun Dimension.Companion.of(dim: Int): Dimension = when (dim) {
1u -> D1 1 -> D1
2u -> D2 2 -> D2
3u -> D3 3 -> D3
else -> object : Dimension { else -> object : Dimension {
override val dim: UInt get() = dim override val dim: Int get() = dim
} }
} }

View File

@ -9,17 +9,17 @@ import kotlin.native.concurrent.ThreadLocal
import kotlin.reflect.KClass import kotlin.reflect.KClass
@ThreadLocal @ThreadLocal
private val dimensionMap: MutableMap<UInt, Dimension> = hashMapOf(1u to D1, 2u to D2, 3u to D3) private val dimensionMap: MutableMap<Int, Dimension> = hashMapOf(1 to D1, 2 to D2, 3 to D3)
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
public actual fun <D : Dimension> Dimension.Companion.resolve(type: KClass<D>): D = dimensionMap public actual fun <D : Dimension> Dimension.Companion.resolve(type: KClass<D>): D = dimensionMap
.entries .entries
.map(MutableMap.MutableEntry<UInt, Dimension>::value) .map(MutableMap.MutableEntry<Int, Dimension>::value)
.find { it::class == type } as? D .find { it::class == type } as? D
?: error("Can't resolve dimension $type") ?: error("Can't resolve dimension $type")
public actual fun Dimension.Companion.of(dim: UInt): Dimension = dimensionMap.getOrPut(dim) { public actual fun Dimension.Companion.of(dim: Int): Dimension = dimensionMap.getOrPut(dim) {
object : Dimension { object : Dimension {
override val dim: UInt get() = dim override val dim: Int get() = dim
} }
} }

View File

@ -5,4 +5,6 @@
package space.kscience.kmath.nd4j package space.kscience.kmath.nd4j
internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() } import space.kscience.kmath.misc.toIntExact
internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toIntExact() }

View File

@ -7,6 +7,7 @@ package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingIntChain import space.kscience.kmath.chains.BlockingIntChain
import space.kscience.kmath.internal.InternalUtils import space.kscience.kmath.internal.InternalUtils
import space.kscience.kmath.misc.toIntExact
import space.kscience.kmath.stat.RandomGenerator import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler import space.kscience.kmath.stat.Sampler
import space.kscience.kmath.structures.IntBuffer import space.kscience.kmath.structures.IntBuffer
@ -188,7 +189,7 @@ public class LargeMeanPoissonSampler(public val mean: Double) : Sampler<Int> {
} }
} }
return min(y2 + y.toLong(), Int.MAX_VALUE.toLong()).toInt() return min(y2 + y.toLong(), Int.MAX_VALUE.toLong()).toIntExact()
} }
override fun nextBufferBlocking(size: Int): IntBuffer = IntBuffer(size) { nextBlocking() } override fun nextBufferBlocking(size: Int): IntBuffer = IntBuffer(size) { nextBlocking() }

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.samplers package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingDoubleChain import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.misc.toIntExact
import space.kscience.kmath.stat.RandomGenerator import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
import kotlin.math.* import kotlin.math.*
@ -58,7 +59,7 @@ public object ZigguratNormalizedGaussianSampler : NormalizedGaussianSampler {
private fun sampleOne(generator: RandomGenerator): Double { private fun sampleOne(generator: RandomGenerator): Double {
val j = generator.nextLong() val j = generator.nextLong()
val i = (j and LAST.toLong()).toInt() val i = (j and LAST.toLong()).toIntExact()
return if (abs(j) < K[i]) j * W[i] else fix(generator, j, i) return if (abs(j) < K[i]) j * W[i] else fix(generator, j, i)
} }