39 lines
1.7 KiB
Kotlin
Raw Normal View History

2021-01-18 19:02:01 +00:00
@file:Suppress("NOTHING_TO_INLINE")
2021-01-16 21:03:11 +00:00
package kscience.kmath.torch
2021-01-18 19:02:01 +00:00
import kotlin.test.assertEquals
2021-01-16 21:03:11 +00:00
/** Fixed seed passed to `setSeed` by the seeding test helper in this file. */
internal const val SEED = 987654

/**
 * Numeric tolerance constant for tests.
 * NOTE(review): not referenced in the visible part of this file; presumably used by tests elsewhere — confirm.
 */
internal const val TOLERANCE = 1e-6
2021-01-18 19:02:01 +00:00
/**
 * Runs [block] with this algebra as receiver, first on [Device.CPU] and then,
 * if `cudaAvailable()` returns true, once more on `Device.CUDA(0)`.
 *
 * @param block body to execute; it receives the device it should target.
 */
internal inline fun <
    T,
    PrimitiveArrayType,
    TorchTensorType : TorchTensor<T>,
    TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>
> TorchTensorAlgebraType.withCuda(
    block: TorchTensorAlgebraType.(Device) -> Unit
): Unit {
    // The CPU pass is unconditional.
    block(this, Device.CPU)

    // The CUDA pass only happens when CUDA is reported as available.
    if (cudaAvailable()) {
        block(this, Device.CUDA(0))
    }
}
/**
 * Requests a thread count of 2 through `setNumThreads` and asserts that
 * `getNumThreads()` reports the same value afterwards.
 */
internal inline fun <
    T,
    PrimitiveArrayType,
    TorchTensorType : TorchTensor<T>,
    TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>
> TorchTensorAlgebraType.testingSetNumThreads(): Unit {
    val requestedThreads = 2
    setNumThreads(requestedThreads)

    // Read the setting back and compare it with what was requested.
    val reportedThreads = getNumThreads()
    assertEquals(requestedThreads, reportedThreads)
}
/**
 * Checks that re-seeding with [SEED] reproduces the same random draws.
 *
 * After `setSeed(SEED)` one value each is drawn from `randIntegral`,
 * `randNormal` and `randUniform` (in that order); the seed is then reset and
 * the same three draws are repeated. Each repeated value must equal its
 * counterpart from the first round.
 *
 * @param device device on which the random tensors are created; defaults to [Device.CPU].
 */
internal inline fun <
    TorchTensorType : TorchTensorOverField<Float>,
    TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>
> TorchTensorAlgebraType.testingSetSeed(device: Device = Device.CPU): Unit {
    // First round of draws right after seeding.
    setSeed(SEED)
    val firstIntegral = randIntegral(0f, 100f, intArrayOf(), device = device).value()
    val firstNormal = randNormal(intArrayOf(), device = device).value()
    val firstUniform = randUniform(intArrayOf(), device = device).value()

    // Second round: same seed, same draw order.
    setSeed(SEED)
    val replayedIntegral = randIntegral(0f, 100f, intArrayOf(), device = device).value()
    val replayedNormal = randNormal(intArrayOf(), device = device).value()
    val replayedUniform = randUniform(intArrayOf(), device = device).value()

    assertEquals(firstNormal, replayedNormal)
    assertEquals(firstUniform, replayedUniform)
    assertEquals(firstIntegral, replayedIntegral)
}