2021-01-18 19:02:01 +00:00
|
|
|
@file:Suppress("NOTHING_TO_INLINE")
|
|
|
|
|
2021-01-16 21:03:11 +00:00
|
|
|
package kscience.kmath.torch
|
|
|
|
|
2021-01-18 19:02:01 +00:00
|
|
|
import kotlin.test.assertEquals
|
|
|
|
|
2021-01-16 21:03:11 +00:00
|
|
|
/** Fixed random seed used by the reproducibility checks (e.g. `testingSetSeed`) in these tests. */
internal const val SEED = 987654
|
|
|
|
/**
 * Shared numeric tolerance for tests in this package.
 * NOTE(review): not used in this chunk; presumably an absolute tolerance for floating-point comparisons — confirm.
 */
internal const val TOLERANCE = 1e-6
|
|
|
|
|
2021-01-18 19:02:01 +00:00
|
|
|
/**
 * Runs [block] with this algebra as receiver, once per device under test.
 *
 * The block is always executed on [Device.CPU]. It is executed a second time on `Device.CUDA(0)`
 * only when `cudaAvailable()` returns true.
 *
 * @param block test body; receives the device it should allocate tensors on.
 */
internal inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
        TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
TorchTensorAlgebraType.withCuda(block: TorchTensorAlgebraType.(Device) -> Unit) {
    block(Device.CPU)
    // Without a CUDA device there is nothing more to run.
    if (!cudaAvailable()) return
    block(Device.CUDA(0))
}
|
|
|
|
|
|
|
|
/**
 * Checks that setting the thread count takes effect.
 *
 * Calls `setNumThreads(numThreads)` and asserts that `getNumThreads()` reports the same value.
 * The thread count is left at [numThreads] afterwards; the previous value is not restored.
 *
 * @param numThreads thread count to set; defaults to 2, the value this check always used.
 * @throws IllegalArgumentException if [numThreads] is not positive.
 */
internal inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
        TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
TorchTensorAlgebraType.testingSetNumThreads(numThreads: Int = 2) {
    // A non-positive thread count is never a meaningful configuration to test.
    require(numThreads > 0) { "numThreads must be positive, was $numThreads" }
    setNumThreads(numThreads)
    assertEquals(numThreads, getNumThreads())
}
|
|
|
|
|
|
|
|
/**
 * Checks that re-seeding makes random generation reproducible on [device].
 *
 * Each of two rounds first calls `setSeed(SEED)` and then draws one scalar (empty shape) value from
 * `randIntegral(0f, 100f, ...)`, then `randNormal`, then `randUniform`. The test asserts that both
 * rounds produced exactly equal values.
 *
 * @param device device the random tensors are allocated on; defaults to [Device.CPU].
 */
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
        TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
TorchTensorAlgebraType.testingSetSeed(device: Device = Device.CPU) {
    // An empty shape array requests a scalar tensor.
    val scalarShape = IntArray(0)
    val (firstRound, secondRound) = List(2) {
        setSeed(SEED)
        Triple(
            randIntegral(0f, 100f, scalarShape, device = device).value(),
            randNormal(scalarShape, device = device).value(),
            randUniform(scalarShape, device = device).value()
        )
    }
    assertEquals(firstRound, secondRound)
}
|