testing copying

This commit is contained in:
Roland Grinis 2021-07-13 13:08:00 +01:00
parent d303c912d6
commit face60824d
3 changed files with 49 additions and 11 deletions

View File

@ -36,14 +36,14 @@ protected constructor(protected val scope: NoaScope) :
override fun Tensor<T>.value(): T = tensor.item()
public abstract fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): TensorType
public abstract fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device = Device.CPU): TensorType
@PerformancePitfall
public abstract fun Tensor<T>.copyToArray(): PrimitiveArray
public abstract fun copyFromArray(array: PrimitiveArray, shape: IntArray, device: Device): TensorType
public abstract fun copyFromArray(array: PrimitiveArray, shape: IntArray, device: Device = Device.CPU): TensorType
public abstract fun full(value: T, shape: IntArray, device: Device): TensorType
public abstract fun full(value: T, shape: IntArray, device: Device = Device.CPU): TensorType
override operator fun Tensor<T>.times(other: Tensor<T>): TensorType {
return wrap(JNoa.timesTensor(tensor.tensorHandle, other.tensor.tensorHandle))
@ -136,10 +136,10 @@ protected constructor(protected val scope: NoaScope) :
public fun Tensor<T>.copy(): TensorType =
wrap(JNoa.copyTensor(tensor.tensorHandle))
public fun Tensor<T>.copyToDevice(device: Device): TensorType =
public fun Tensor<T>.copyToDevice(device: Device = Device.CPU): TensorType =
wrap(JNoa.copyToDevice(tensor.tensorHandle, device.toInt()))
public abstract fun loadJitModule(path: String, device: Device): NoaJitModule
public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule
public fun NoaJitModule.forward(parameters: Tensor<T>): TensorType =
wrap(JNoa.forwardPass(jitModuleHandle, parameters.tensor.tensorHandle))
@ -193,9 +193,9 @@ protected constructor(scope: NoaScope) :
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): TensorType =
wrap(JNoa.varDimTensor(tensor.tensorHandle, dim, keepDim))
public abstract fun randNormal(shape: IntArray, device: Device): TensorType
public abstract fun randNormal(shape: IntArray, device: Device = Device.CPU): TensorType
public abstract fun randUniform(shape: IntArray, device: Device): TensorType
public abstract fun randUniform(shape: IntArray, device: Device = Device.CPU): TensorType
public fun Tensor<T>.randUniform(): TensorType =
wrap(JNoa.randLike(tensor.tensorHandle))
@ -301,7 +301,7 @@ protected constructor(scope: NoaScope) :
return Pair(wrap(S), wrap(V))
}
public fun TensorType.autoGradient(variable: TensorType, retainGraph: Boolean): TensorType =
public fun TensorType.autoGradient(variable: TensorType, retainGraph: Boolean = false): TensorType =
wrap(JNoa.autoGradTensor(tensorHandle, variable.tensorHandle, retainGraph))
public fun TensorType.autoHessian(variable: TensorType): TensorType =

View File

@ -0,0 +1,40 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.noa
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal fun NoaFloat.testingCopying(device: Device = Device.CPU): Unit {
val array = (1..24).map { 10f * it * it }.toFloatArray()
val shape = intArrayOf(2, 3, 4)
val tensor = copyFromArray(array, shape = shape, device = device)
val copyOfTensor = tensor.copy()
tensor[intArrayOf(1, 2, 3)] = 0.1f
assertTrue(copyOfTensor.copyToArray() contentEquals array)
assertEquals(0.1f, tensor[intArrayOf(1, 2, 3)])
if(device != Device.CPU){
val normalCpu = randNormal(intArrayOf(2, 3))
val normalGpu = normalCpu.copyToDevice(device)
assertTrue(normalCpu.copyToArray() contentEquals normalGpu.copyToArray())
val uniformGpu = randUniform(intArrayOf(3,2),device)
val uniformCpu = uniformGpu.copyToDevice(Device.CPU)
assertTrue(uniformGpu.copyToArray() contentEquals uniformCpu.copyToArray())
}
}
class TestTensor {
@Test
fun testCopying(): Unit = NoaFloat {
withCuda { device ->
testingCopying(device)
}
}!!
}

View File

@ -59,6 +59,4 @@ class TestUtils {
}
}!!
}