testing seed setting
This commit is contained in:
parent
a33af9ec94
commit
d303c912d6
@ -35,8 +35,6 @@ val cmakeCmd = "$thirdPartyDir/cmake/$cmakeArchive/bin/cmake"
|
|||||||
val ninjaCmd = "$thirdPartyDir/ninja/ninja"
|
val ninjaCmd = "$thirdPartyDir/ninja/ninja"
|
||||||
|
|
||||||
val generateJNIHeader by tasks.registering {
|
val generateJNIHeader by tasks.registering {
|
||||||
println("Path:")
|
|
||||||
println(System.getProperty("java.library.path"))
|
|
||||||
doLast {
|
doLast {
|
||||||
exec {
|
exec {
|
||||||
workingDir(projectDir.resolve("src/main/java/space/kscience/kmath/noa"))
|
workingDir(projectDir.resolve("src/main/java/space/kscience/kmath/noa"))
|
||||||
|
@ -8,6 +8,31 @@ package space.kscience.kmath.noa
|
|||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
|
||||||
|
internal val SEED = 987654
|
||||||
|
internal val TOLERANCE = 1e-6
|
||||||
|
|
||||||
|
internal fun <T, ArrayT, TensorT : NoaTensor<T>, AlgebraT : NoaAlgebra<T, ArrayT, TensorT>>
|
||||||
|
AlgebraT.withCuda(block: AlgebraT.(Device) -> Unit): Unit {
|
||||||
|
this.block(Device.CPU)
|
||||||
|
if (cudaAvailable()) this.block(Device.CUDA(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun NoaFloat.testingSetSeed(device: Device = Device.CPU): Unit {
|
||||||
|
setSeed(SEED)
|
||||||
|
val integral = randDiscrete(0, 100, IntArray(0), device = device).value()
|
||||||
|
val normal = randNormal(IntArray(0), device = device).value()
|
||||||
|
val uniform = randUniform(IntArray(0), device = device).value()
|
||||||
|
setSeed(SEED)
|
||||||
|
val nextIntegral = randDiscrete(0, 100, IntArray(0), device = device).value()
|
||||||
|
val nextNormal = randNormal(IntArray(0), device = device).value()
|
||||||
|
val nextUniform = randUniform(IntArray(0), device = device).value()
|
||||||
|
assertEquals(normal, nextNormal)
|
||||||
|
assertEquals(uniform, nextUniform)
|
||||||
|
assertEquals(integral, nextIntegral)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class TestUtils {
|
class TestUtils {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -27,4 +52,13 @@ class TestUtils {
|
|||||||
assertEquals(numThreads, getNumThreads())
|
assertEquals(numThreads, getNumThreads())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSetSeed(): Unit = NoaFloat {
|
||||||
|
withCuda { device ->
|
||||||
|
testingSetSeed(device)
|
||||||
|
}
|
||||||
|
}!!
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user