forked from kscience/kmath
Memory management refactored
This commit is contained in:
parent
d97f8857a0
commit
fb9d612081
@ -15,13 +15,20 @@ public abstract class TorchTensor<T, out TorchTensorBufferImpl : TorchTensorBuff
|
||||
)!!
|
||||
return TorchTensorFloat(populateStridesFromNative(tensorHandle, rawShape = shape), scope, tensorHandle)
|
||||
}
|
||||
|
||||
public fun copyFromIntArray(scope: DeferScope, array: IntArray, shape: IntArray): TorchTensorInt {
|
||||
val tensorHandle: COpaquePointer = copy_from_blob_int(
|
||||
array.toCValues(), shape.toCValues(), shape.size
|
||||
)!!
|
||||
return TorchTensorInt(populateStridesFromNative(tensorHandle, rawShape = shape), scope, tensorHandle)
|
||||
}
|
||||
public fun copyFromFloatArrayToGPU(scope: DeferScope, array: FloatArray, shape: IntArray, device: Int): TorchTensorFloatGPU {
|
||||
|
||||
public fun copyFromFloatArrayToGPU(
|
||||
scope: DeferScope,
|
||||
array: FloatArray,
|
||||
shape: IntArray,
|
||||
device: Int
|
||||
): TorchTensorFloatGPU {
|
||||
val tensorHandle: COpaquePointer = copy_from_blob_to_gpu_float(
|
||||
array.toCValues(), shape.toCValues(), shape.size, device
|
||||
)!!
|
||||
@ -30,35 +37,59 @@ public abstract class TorchTensor<T, out TorchTensorBufferImpl : TorchTensorBuff
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(buffer.tensorHandle)!!
|
||||
val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(buffer.tensorHandle!!)!!
|
||||
val stringRepresentation = nativeStringRepresentation.toKString()
|
||||
dispose_char(nativeStringRepresentation)
|
||||
return stringRepresentation
|
||||
}
|
||||
|
||||
internal abstract fun wrap(
|
||||
outStrides: TorchTensorStrides,
|
||||
outScope: DeferScope,
|
||||
outTensorHandle: COpaquePointer
|
||||
): TorchTensor<T, TorchTensorBufferImpl>
|
||||
|
||||
public fun copy(): TorchTensor<T, TorchTensorBufferImpl> = wrap(
|
||||
outStrides = strides,
|
||||
outScope = buffer.scope,
|
||||
outTensorHandle = copy_tensor(buffer.tensorHandle!!)!!
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
public class TorchTensorFloat internal constructor(
|
||||
override val strides: TorchTensorStrides,
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
): TorchTensor<Float, TorchTensorBufferFloat>() {
|
||||
) : TorchTensor<Float, TorchTensorBufferFloat>() {
|
||||
override val buffer: TorchTensorBufferFloat = TorchTensorBufferFloat(scope, tensorHandle)
|
||||
override fun wrap(outStrides: TorchTensorStrides, outScope: DeferScope, outTensorHandle: COpaquePointer) =
|
||||
TorchTensorFloat(
|
||||
strides = outStrides, scope = outScope, tensorHandle = outTensorHandle
|
||||
)
|
||||
}
|
||||
|
||||
public class TorchTensorInt internal constructor(
|
||||
override val strides: TorchTensorStrides,
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
): TorchTensor<Int, TorchTensorBufferInt>() {
|
||||
) : TorchTensor<Int, TorchTensorBufferInt>() {
|
||||
override val buffer: TorchTensorBufferInt = TorchTensorBufferInt(scope, tensorHandle)
|
||||
override fun wrap(outStrides: TorchTensorStrides, outScope: DeferScope, outTensorHandle: COpaquePointer) =
|
||||
TorchTensorInt(
|
||||
strides = outStrides, scope = outScope, tensorHandle = outTensorHandle
|
||||
)
|
||||
}
|
||||
|
||||
public class TorchTensorFloatGPU internal constructor(
|
||||
override val strides: TorchTensorStrides,
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
): TorchTensor<Float, TorchTensorBufferFloatGPU>() {
|
||||
) : TorchTensor<Float, TorchTensorBufferFloatGPU>() {
|
||||
override val buffer: TorchTensorBufferFloatGPU = TorchTensorBufferFloatGPU(scope, tensorHandle)
|
||||
override fun wrap(outStrides: TorchTensorStrides, outScope: DeferScope, outTensorHandle: COpaquePointer) =
|
||||
TorchTensorFloatGPU(
|
||||
strides = outStrides, scope = outScope, tensorHandle = outTensorHandle
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -7,10 +7,13 @@ import ctorch.*
|
||||
|
||||
public abstract class TorchTensorBuffer<T> internal constructor(
|
||||
internal val scope: DeferScope,
|
||||
internal val tensorHandle: COpaquePointer
|
||||
internal var tensorHandle: COpaquePointer?
|
||||
) : MutableBuffer<T> {
|
||||
|
||||
override val size: Int = get_numel(tensorHandle)
|
||||
override val size: Int
|
||||
get(){
|
||||
return get_numel(tensorHandle!!)
|
||||
}
|
||||
|
||||
init {
|
||||
scope.defer(::close)
|
||||
@ -18,13 +21,14 @@ public abstract class TorchTensorBuffer<T> internal constructor(
|
||||
|
||||
protected fun close() {
|
||||
dispose_tensor(tensorHandle)
|
||||
tensorHandle = null
|
||||
}
|
||||
|
||||
internal abstract fun wrap(outScope: DeferScope, outTensorHandle: COpaquePointer): TorchTensorBuffer<T>
|
||||
|
||||
override fun copy(): TorchTensorBuffer<T> = wrap(
|
||||
outScope = scope,
|
||||
outTensorHandle = copy_tensor(tensorHandle)!!
|
||||
outTensorHandle = copy_tensor(tensorHandle!!)!!
|
||||
)
|
||||
}
|
||||
|
||||
@ -33,7 +37,10 @@ public class TorchTensorBufferFloat internal constructor(
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensorBuffer<Float>(scope, tensorHandle) {
|
||||
|
||||
private val tensorData: CPointer<FloatVar> = get_data_float(tensorHandle)!!
|
||||
private val tensorData: CPointer<FloatVar>
|
||||
get(){
|
||||
return get_data_float(tensorHandle!!)!!
|
||||
}
|
||||
|
||||
override operator fun get(index: Int): Float = tensorData[index]
|
||||
|
||||
@ -55,7 +62,10 @@ public class TorchTensorBufferInt internal constructor(
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensorBuffer<Int>(scope, tensorHandle) {
|
||||
|
||||
private val tensorData: CPointer<IntVar> = get_data_int(tensorHandle)!!
|
||||
private val tensorData: CPointer<IntVar>
|
||||
get(){
|
||||
return get_data_int(tensorHandle!!)!!
|
||||
}
|
||||
|
||||
override operator fun get(index: Int): Int = tensorData[index]
|
||||
|
||||
@ -76,14 +86,14 @@ public class TorchTensorBufferFloatGPU internal constructor(
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensorBuffer<Float>(scope, tensorHandle) {
|
||||
|
||||
override operator fun get(index: Int): Float = get_at_offset_float(tensorHandle, index)
|
||||
override operator fun get(index: Int): Float = get_at_offset_float(tensorHandle!!, index)
|
||||
|
||||
override operator fun set(index: Int, value: Float) {
|
||||
set_at_offset_float(tensorHandle, index, value)
|
||||
set_at_offset_float(tensorHandle!!, index, value)
|
||||
}
|
||||
|
||||
override operator fun iterator(): Iterator<Float> {
|
||||
val cpuCopy = copy_to_cpu(tensorHandle)!!
|
||||
val cpuCopy = copy_to_cpu(tensorHandle!!)!!
|
||||
val tensorCpuData = get_data_float(cpuCopy)!!
|
||||
val iteratorResult = (1..size).map { tensorCpuData[it - 1] }.iterator()
|
||||
dispose_tensor(cpuCopy)
|
||||
|
@ -11,7 +11,7 @@ internal class TestTorchTensor {
|
||||
@Test
|
||||
fun intTensorLayout() = memScoped {
|
||||
val array = (1..24).toList().toIntArray()
|
||||
val shape = intArrayOf(3, 2, 4)
|
||||
val shape = intArrayOf(4, 6)
|
||||
val tensor = TorchTensor.copyFromIntArray(scope = this, array = array, shape = shape)
|
||||
tensor.elements().forEach {
|
||||
assertEquals(tensor[it.first], it.second)
|
||||
@ -30,4 +30,16 @@ internal class TestTorchTensor {
|
||||
assertTrue(tensor.buffer.contentEquals(array.asBuffer()))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun mutableStructure() = memScoped {
|
||||
val array = (1..10).map { 1f * it }.toList().toFloatArray()
|
||||
val shape = intArrayOf(10)
|
||||
val tensor = TorchTensor.copyFromFloatArray(this, array, shape)
|
||||
val tensorCopy = tensor.copy()
|
||||
|
||||
tensor[intArrayOf(0)] = 99f
|
||||
assertEquals(99f, tensor[intArrayOf(0)])
|
||||
assertEquals(1f, tensorCopy[intArrayOf(0)])
|
||||
}
|
||||
|
||||
}
|
@ -1,14 +1,12 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Ignore
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
|
||||
internal class TestUtils {
|
||||
@Test
|
||||
fun settingTorchThreadsCount(){
|
||||
fun settingTorchThreadsCount() {
|
||||
val numThreads = 2
|
||||
setNumThreads(numThreads)
|
||||
assertEquals(numThreads, getNumThreads())
|
||||
|
Loading…
Reference in New Issue
Block a user