testing copying
This commit is contained in:
parent
d303c912d6
commit
face60824d
@ -36,14 +36,14 @@ protected constructor(protected val scope: NoaScope) :
|
|||||||
|
|
||||||
override fun Tensor<T>.value(): T = tensor.item()
|
override fun Tensor<T>.value(): T = tensor.item()
|
||||||
|
|
||||||
public abstract fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): TensorType
|
public abstract fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device = Device.CPU): TensorType
|
||||||
|
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
public abstract fun Tensor<T>.copyToArray(): PrimitiveArray
|
public abstract fun Tensor<T>.copyToArray(): PrimitiveArray
|
||||||
|
|
||||||
public abstract fun copyFromArray(array: PrimitiveArray, shape: IntArray, device: Device): TensorType
|
public abstract fun copyFromArray(array: PrimitiveArray, shape: IntArray, device: Device = Device.CPU): TensorType
|
||||||
|
|
||||||
public abstract fun full(value: T, shape: IntArray, device: Device): TensorType
|
public abstract fun full(value: T, shape: IntArray, device: Device = Device.CPU): TensorType
|
||||||
|
|
||||||
override operator fun Tensor<T>.times(other: Tensor<T>): TensorType {
|
override operator fun Tensor<T>.times(other: Tensor<T>): TensorType {
|
||||||
return wrap(JNoa.timesTensor(tensor.tensorHandle, other.tensor.tensorHandle))
|
return wrap(JNoa.timesTensor(tensor.tensorHandle, other.tensor.tensorHandle))
|
||||||
@ -136,10 +136,10 @@ protected constructor(protected val scope: NoaScope) :
|
|||||||
public fun Tensor<T>.copy(): TensorType =
|
public fun Tensor<T>.copy(): TensorType =
|
||||||
wrap(JNoa.copyTensor(tensor.tensorHandle))
|
wrap(JNoa.copyTensor(tensor.tensorHandle))
|
||||||
|
|
||||||
public fun Tensor<T>.copyToDevice(device: Device): TensorType =
|
public fun Tensor<T>.copyToDevice(device: Device = Device.CPU): TensorType =
|
||||||
wrap(JNoa.copyToDevice(tensor.tensorHandle, device.toInt()))
|
wrap(JNoa.copyToDevice(tensor.tensorHandle, device.toInt()))
|
||||||
|
|
||||||
public abstract fun loadJitModule(path: String, device: Device): NoaJitModule
|
public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule
|
||||||
|
|
||||||
public fun NoaJitModule.forward(parameters: Tensor<T>): TensorType =
|
public fun NoaJitModule.forward(parameters: Tensor<T>): TensorType =
|
||||||
wrap(JNoa.forwardPass(jitModuleHandle, parameters.tensor.tensorHandle))
|
wrap(JNoa.forwardPass(jitModuleHandle, parameters.tensor.tensorHandle))
|
||||||
@ -193,9 +193,9 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): TensorType =
|
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): TensorType =
|
||||||
wrap(JNoa.varDimTensor(tensor.tensorHandle, dim, keepDim))
|
wrap(JNoa.varDimTensor(tensor.tensorHandle, dim, keepDim))
|
||||||
|
|
||||||
public abstract fun randNormal(shape: IntArray, device: Device): TensorType
|
public abstract fun randNormal(shape: IntArray, device: Device = Device.CPU): TensorType
|
||||||
|
|
||||||
public abstract fun randUniform(shape: IntArray, device: Device): TensorType
|
public abstract fun randUniform(shape: IntArray, device: Device = Device.CPU): TensorType
|
||||||
|
|
||||||
public fun Tensor<T>.randUniform(): TensorType =
|
public fun Tensor<T>.randUniform(): TensorType =
|
||||||
wrap(JNoa.randLike(tensor.tensorHandle))
|
wrap(JNoa.randLike(tensor.tensorHandle))
|
||||||
@ -301,7 +301,7 @@ protected constructor(scope: NoaScope) :
|
|||||||
return Pair(wrap(S), wrap(V))
|
return Pair(wrap(S), wrap(V))
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun TensorType.autoGradient(variable: TensorType, retainGraph: Boolean): TensorType =
|
public fun TensorType.autoGradient(variable: TensorType, retainGraph: Boolean = false): TensorType =
|
||||||
wrap(JNoa.autoGradTensor(tensorHandle, variable.tensorHandle, retainGraph))
|
wrap(JNoa.autoGradTensor(tensorHandle, variable.tensorHandle, retainGraph))
|
||||||
|
|
||||||
public fun TensorType.autoHessian(variable: TensorType): TensorType =
|
public fun TensorType.autoHessian(variable: TensorType): TensorType =
|
||||||
|
@ -0,0 +1,40 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.noa
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
|
|
||||||
|
internal fun NoaFloat.testingCopying(device: Device = Device.CPU): Unit {
|
||||||
|
val array = (1..24).map { 10f * it * it }.toFloatArray()
|
||||||
|
val shape = intArrayOf(2, 3, 4)
|
||||||
|
val tensor = copyFromArray(array, shape = shape, device = device)
|
||||||
|
val copyOfTensor = tensor.copy()
|
||||||
|
tensor[intArrayOf(1, 2, 3)] = 0.1f
|
||||||
|
assertTrue(copyOfTensor.copyToArray() contentEquals array)
|
||||||
|
assertEquals(0.1f, tensor[intArrayOf(1, 2, 3)])
|
||||||
|
if(device != Device.CPU){
|
||||||
|
val normalCpu = randNormal(intArrayOf(2, 3))
|
||||||
|
val normalGpu = normalCpu.copyToDevice(device)
|
||||||
|
assertTrue(normalCpu.copyToArray() contentEquals normalGpu.copyToArray())
|
||||||
|
|
||||||
|
val uniformGpu = randUniform(intArrayOf(3,2),device)
|
||||||
|
val uniformCpu = uniformGpu.copyToDevice(Device.CPU)
|
||||||
|
assertTrue(uniformGpu.copyToArray() contentEquals uniformCpu.copyToArray())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestTensor {
|
||||||
|
@Test
|
||||||
|
fun testCopying(): Unit = NoaFloat {
|
||||||
|
withCuda { device ->
|
||||||
|
testingCopying(device)
|
||||||
|
}
|
||||||
|
}!!
|
||||||
|
}
|
@ -46,7 +46,7 @@ class TestUtils {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSetNumThreads(){
|
fun testSetNumThreads() {
|
||||||
val numThreads = 2
|
val numThreads = 2
|
||||||
setNumThreads(numThreads)
|
setNumThreads(numThreads)
|
||||||
assertEquals(numThreads, getNumThreads())
|
assertEquals(numThreads, getNumThreads())
|
||||||
@ -59,6 +59,4 @@ class TestUtils {
|
|||||||
}
|
}
|
||||||
}!!
|
}!!
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user