Memory management refactored

rgrit91 2021-01-07 20:52:43 +00:00
parent d97f8857a0
commit fb9d612081
4 changed files with 68 additions and 17 deletions

View File

@@ -15,13 +15,20 @@ public abstract class TorchTensor<T, out TorchTensorBufferImpl : TorchTensorBuff
)!!
return TorchTensorFloat(populateStridesFromNative(tensorHandle, rawShape = shape), scope, tensorHandle)
}
public fun copyFromIntArray(scope: DeferScope, array: IntArray, shape: IntArray): TorchTensorInt {
val tensorHandle: COpaquePointer = copy_from_blob_int(
array.toCValues(), shape.toCValues(), shape.size
)!!
return TorchTensorInt(populateStridesFromNative(tensorHandle, rawShape = shape), scope, tensorHandle)
}
public fun copyFromFloatArrayToGPU(scope: DeferScope, array: FloatArray, shape: IntArray, device: Int): TorchTensorFloatGPU {
public fun copyFromFloatArrayToGPU(
scope: DeferScope,
array: FloatArray,
shape: IntArray,
device: Int
): TorchTensorFloatGPU {
val tensorHandle: COpaquePointer = copy_from_blob_to_gpu_float(
array.toCValues(), shape.toCValues(), shape.size, device
)!!
@@ -30,35 +37,59 @@ public abstract class TorchTensor<T, out TorchTensorBufferImpl : TorchTensorBuff
}
override fun toString(): String {
val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(buffer.tensorHandle)!!
val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(buffer.tensorHandle!!)!!
val stringRepresentation = nativeStringRepresentation.toKString()
dispose_char(nativeStringRepresentation)
return stringRepresentation
}
internal abstract fun wrap(
outStrides: TorchTensorStrides,
outScope: DeferScope,
outTensorHandle: COpaquePointer
): TorchTensor<T, TorchTensorBufferImpl>
public fun copy(): TorchTensor<T, TorchTensorBufferImpl> = wrap(
outStrides = strides,
outScope = buffer.scope,
outTensorHandle = copy_tensor(buffer.tensorHandle!!)!!
)
}
public class TorchTensorFloat internal constructor(
override val strides: TorchTensorStrides,
scope: DeferScope,
tensorHandle: COpaquePointer
): TorchTensor<Float, TorchTensorBufferFloat>() {
) : TorchTensor<Float, TorchTensorBufferFloat>() {
override val buffer: TorchTensorBufferFloat = TorchTensorBufferFloat(scope, tensorHandle)
override fun wrap(outStrides: TorchTensorStrides, outScope: DeferScope, outTensorHandle: COpaquePointer) =
TorchTensorFloat(
strides = outStrides, scope = outScope, tensorHandle = outTensorHandle
)
}
public class TorchTensorInt internal constructor(
override val strides: TorchTensorStrides,
scope: DeferScope,
tensorHandle: COpaquePointer
): TorchTensor<Int, TorchTensorBufferInt>() {
) : TorchTensor<Int, TorchTensorBufferInt>() {
override val buffer: TorchTensorBufferInt = TorchTensorBufferInt(scope, tensorHandle)
override fun wrap(outStrides: TorchTensorStrides, outScope: DeferScope, outTensorHandle: COpaquePointer) =
TorchTensorInt(
strides = outStrides, scope = outScope, tensorHandle = outTensorHandle
)
}
public class TorchTensorFloatGPU internal constructor(
override val strides: TorchTensorStrides,
scope: DeferScope,
tensorHandle: COpaquePointer
): TorchTensor<Float, TorchTensorBufferFloatGPU>() {
) : TorchTensor<Float, TorchTensorBufferFloatGPU>() {
override val buffer: TorchTensorBufferFloatGPU = TorchTensorBufferFloatGPU(scope, tensorHandle)
override fun wrap(outStrides: TorchTensorStrides, outScope: DeferScope, outTensorHandle: COpaquePointer) =
TorchTensorFloatGPU(
strides = outStrides, scope = outScope, tensorHandle = outTensorHandle
)
}
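
For orientation, a minimal usage sketch (not part of this commit) of the copy() added above: it reuses only calls shown in this diff (copyFromFloatArray, copy, IntArray indexing), the values are illustrative, and it assumes a Kotlin/Native target with the ctorch bindings on the linker path.

import kotlinx.cinterop.memScoped
import kscience.kmath.torch.TorchTensor

fun main() = memScoped {
    // Wrap a host array as a flat 4-element float tensor tied to this DeferScope.
    val tensor = TorchTensor.copyFromFloatArray(
        scope = this,
        array = floatArrayOf(1f, 2f, 3f, 4f),
        shape = intArrayOf(4)
    )
    // copy() goes through wrap() and copy_tensor(), so the clone owns its own native handle.
    val snapshot = tensor.copy()
    tensor[intArrayOf(0)] = 42f
    println(tensor[intArrayOf(0)])    // 42.0
    println(snapshot[intArrayOf(0)])  // still 1.0: mutating the original leaves the copy untouched
}

Both tensors are released when the memScoped block unwinds, since copy() reuses buffer.scope and every buffer registers close() on that same DeferScope.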

View File

@@ -7,10 +7,13 @@ import ctorch.*
public abstract class TorchTensorBuffer<T> internal constructor(
internal val scope: DeferScope,
internal val tensorHandle: COpaquePointer
internal var tensorHandle: COpaquePointer?
) : MutableBuffer<T> {
override val size: Int = get_numel(tensorHandle)
override val size: Int
get(){
return get_numel(tensorHandle!!)
}
init {
scope.defer(::close)
@@ -18,13 +21,14 @@ public abstract class TorchTensorBuffer<T> internal constructor(
protected fun close() {
dispose_tensor(tensorHandle)
tensorHandle = null
}
internal abstract fun wrap(outScope: DeferScope, outTensorHandle: COpaquePointer): TorchTensorBuffer<T>
override fun copy(): TorchTensorBuffer<T> = wrap(
outScope = scope,
outTensorHandle = copy_tensor(tensorHandle)!!
outTensorHandle = copy_tensor(tensorHandle!!)!!
)
}
@@ -33,7 +37,10 @@ public class TorchTensorBufferFloat internal constructor(
tensorHandle: COpaquePointer
) : TorchTensorBuffer<Float>(scope, tensorHandle) {
private val tensorData: CPointer<FloatVar> = get_data_float(tensorHandle)!!
private val tensorData: CPointer<FloatVar>
get(){
return get_data_float(tensorHandle!!)!!
}
override operator fun get(index: Int): Float = tensorData[index]
@@ -55,7 +62,10 @@ public class TorchTensorBufferInt internal constructor(
tensorHandle: COpaquePointer
) : TorchTensorBuffer<Int>(scope, tensorHandle) {
private val tensorData: CPointer<IntVar> = get_data_int(tensorHandle)!!
private val tensorData: CPointer<IntVar>
get(){
return get_data_int(tensorHandle!!)!!
}
override operator fun get(index: Int): Int = tensorData[index]
@@ -76,14 +86,14 @@ public class TorchTensorBufferFloatGPU internal constructor(
tensorHandle: COpaquePointer
) : TorchTensorBuffer<Float>(scope, tensorHandle) {
override operator fun get(index: Int): Float = get_at_offset_float(tensorHandle, index)
override operator fun get(index: Int): Float = get_at_offset_float(tensorHandle!!, index)
override operator fun set(index: Int, value: Float) {
set_at_offset_float(tensorHandle, index, value)
set_at_offset_float(tensorHandle!!, index, value)
}
override operator fun iterator(): Iterator<Float> {
val cpuCopy = copy_to_cpu(tensorHandle)!!
val cpuCopy = copy_to_cpu(tensorHandle!!)!!
val tensorCpuData = get_data_float(cpuCopy)!!
val iteratorResult = (1..size).map { tensorCpuData[it - 1] }.iterator()
dispose_tensor(cpuCopy)
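
The buffer refactor above makes tensorHandle a nullable var so the deferred close() can both dispose the native tensor and invalidate the handle; later accesses then fail fast on !! instead of reading freed memory. Below is a standalone sketch of that ownership pattern (not part of the commit), using a plain nativeHeap allocation in place of a ctorch tensor handle; NativeIntHolder is a hypothetical illustration.

import kotlinx.cinterop.*

class NativeIntHolder(scope: DeferScope) {
    // var + nullable, mirroring TorchTensorBuffer: close() frees the allocation and nulls the pointer.
    private var handle: CPointer<IntVar>? = nativeHeap.alloc<IntVar>().ptr

    init {
        scope.defer(::close)              // tie the native lifetime to the enclosing DeferScope
    }

    var value: Int
        get() = handle!!.pointed.value    // !! fails fast if the holder is used after close()
        set(v) { handle!!.pointed.value = v }

    private fun close() {
        handle?.let { nativeHeap.free(it) }
        handle = null
    }
}

fun demo() = memScoped {
    val holder = NativeIntHolder(this)
    holder.value = 7
    println(holder.value)                 // 7
}                                         // leaving the scope runs close(): memory freed, handle nulled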

View File

@@ -11,7 +11,7 @@ internal class TestTorchTensor {
@Test
fun intTensorLayout() = memScoped {
val array = (1..24).toList().toIntArray()
val shape = intArrayOf(3, 2, 4)
val shape = intArrayOf(4, 6)
val tensor = TorchTensor.copyFromIntArray(scope = this, array = array, shape = shape)
tensor.elements().forEach {
assertEquals(tensor[it.first], it.second)
@@ -30,4 +30,16 @@ internal class TestTorchTensor {
assertTrue(tensor.buffer.contentEquals(array.asBuffer()))
}
@Test
fun mutableStructure() = memScoped {
val array = (1..10).map { 1f * it }.toList().toFloatArray()
val shape = intArrayOf(10)
val tensor = TorchTensor.copyFromFloatArray(this, array, shape)
val tensorCopy = tensor.copy()
tensor[intArrayOf(0)] = 99f
assertEquals(99f, tensor[intArrayOf(0)])
assertEquals(1f, tensorCopy[intArrayOf(0)])
}
}
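
A hedged companion sketch, not part of this commit: the same deep-copy check applied to the TorchTensorFloatGPU path refactored above. It would sit inside TestTorchTensor next to mutableStructure, reuses only API shown in this diff (copyFromFloatArrayToGPU, copy, IntArray indexing), and assumes a CUDA-enabled build with device 0 available.

@Test
fun mutableStructureOnGPU() = memScoped {
    // Assumption: device index 0 refers to an available CUDA device.
    val array = (1..10).map { 1f * it }.toFloatArray()
    val shape = intArrayOf(10)
    val tensor = TorchTensor.copyFromFloatArrayToGPU(this, array, shape, 0)
    val tensorCopy = tensor.copy()
    tensor[intArrayOf(0)] = 99f
    assertEquals(99f, tensor[intArrayOf(0)])
    assertEquals(1f, tensorCopy[intArrayOf(0)])
}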

View File

@@ -1,14 +1,12 @@
package kscience.kmath.torch
import kotlin.test.Ignore
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal class TestUtils {
@Test
fun settingTorchThreadsCount(){
fun settingTorchThreadsCount() {
val numThreads = 2
setNumThreads(numThreads)
assertEquals(numThreads, getNumThreads())