diff --git a/kmath-torch/src/nativeMain/kotlin/kscience/kmath/torch/TorchTensor.kt b/kmath-torch/src/nativeMain/kotlin/kscience/kmath/torch/TorchTensor.kt index 73ead8d84..6749d4319 100644 --- a/kmath-torch/src/nativeMain/kotlin/kscience/kmath/torch/TorchTensor.kt +++ b/kmath-torch/src/nativeMain/kotlin/kscience/kmath/torch/TorchTensor.kt @@ -15,13 +15,20 @@ public abstract class TorchTensor = tensor_to_string(buffer.tensorHandle)!! + val nativeStringRepresentation: CPointer = tensor_to_string(buffer.tensorHandle!!)!! val stringRepresentation = nativeStringRepresentation.toKString() dispose_char(nativeStringRepresentation) return stringRepresentation } + internal abstract fun wrap( + outStrides: TorchTensorStrides, + outScope: DeferScope, + outTensorHandle: COpaquePointer + ): TorchTensor + + public fun copy(): TorchTensor = wrap( + outStrides = strides, + outScope = buffer.scope, + outTensorHandle = copy_tensor(buffer.tensorHandle!!)!! + ) + } public class TorchTensorFloat internal constructor( override val strides: TorchTensorStrides, scope: DeferScope, tensorHandle: COpaquePointer -): TorchTensor() { +) : TorchTensor() { override val buffer: TorchTensorBufferFloat = TorchTensorBufferFloat(scope, tensorHandle) + override fun wrap(outStrides: TorchTensorStrides, outScope: DeferScope, outTensorHandle: COpaquePointer) = + TorchTensorFloat( + strides = outStrides, scope = outScope, tensorHandle = outTensorHandle + ) } public class TorchTensorInt internal constructor( override val strides: TorchTensorStrides, scope: DeferScope, tensorHandle: COpaquePointer -): TorchTensor() { +) : TorchTensor() { override val buffer: TorchTensorBufferInt = TorchTensorBufferInt(scope, tensorHandle) + override fun wrap(outStrides: TorchTensorStrides, outScope: DeferScope, outTensorHandle: COpaquePointer) = + TorchTensorInt( + strides = outStrides, scope = outScope, tensorHandle = outTensorHandle + ) } public class TorchTensorFloatGPU internal constructor( override val strides: TorchTensorStrides, scope: DeferScope, tensorHandle: COpaquePointer -): TorchTensor() { +) : TorchTensor() { override val buffer: TorchTensorBufferFloatGPU = TorchTensorBufferFloatGPU(scope, tensorHandle) + override fun wrap(outStrides: TorchTensorStrides, outScope: DeferScope, outTensorHandle: COpaquePointer) = + TorchTensorFloatGPU( + strides = outStrides, scope = outScope, tensorHandle = outTensorHandle + ) } diff --git a/kmath-torch/src/nativeMain/kotlin/kscience/kmath/torch/TorchTensorBuffer.kt b/kmath-torch/src/nativeMain/kotlin/kscience/kmath/torch/TorchTensorBuffer.kt index 83213b0ca..62873482e 100644 --- a/kmath-torch/src/nativeMain/kotlin/kscience/kmath/torch/TorchTensorBuffer.kt +++ b/kmath-torch/src/nativeMain/kotlin/kscience/kmath/torch/TorchTensorBuffer.kt @@ -7,10 +7,13 @@ import ctorch.* public abstract class TorchTensorBuffer internal constructor( internal val scope: DeferScope, - internal val tensorHandle: COpaquePointer + internal var tensorHandle: COpaquePointer? ) : MutableBuffer { - override val size: Int = get_numel(tensorHandle) + override val size: Int + get(){ + return get_numel(tensorHandle!!) + } init { scope.defer(::close) @@ -18,13 +21,14 @@ public abstract class TorchTensorBuffer internal constructor( protected fun close() { dispose_tensor(tensorHandle) + tensorHandle = null } internal abstract fun wrap(outScope: DeferScope, outTensorHandle: COpaquePointer): TorchTensorBuffer override fun copy(): TorchTensorBuffer = wrap( outScope = scope, - outTensorHandle = copy_tensor(tensorHandle)!! + outTensorHandle = copy_tensor(tensorHandle!!)!! ) } @@ -33,7 +37,10 @@ public class TorchTensorBufferFloat internal constructor( tensorHandle: COpaquePointer ) : TorchTensorBuffer(scope, tensorHandle) { - private val tensorData: CPointer = get_data_float(tensorHandle)!! + private val tensorData: CPointer + get(){ + return get_data_float(tensorHandle!!)!! + } override operator fun get(index: Int): Float = tensorData[index] @@ -55,7 +62,10 @@ public class TorchTensorBufferInt internal constructor( tensorHandle: COpaquePointer ) : TorchTensorBuffer(scope, tensorHandle) { - private val tensorData: CPointer = get_data_int(tensorHandle)!! + private val tensorData: CPointer + get(){ + return get_data_int(tensorHandle!!)!! + } override operator fun get(index: Int): Int = tensorData[index] @@ -76,14 +86,14 @@ public class TorchTensorBufferFloatGPU internal constructor( tensorHandle: COpaquePointer ) : TorchTensorBuffer(scope, tensorHandle) { - override operator fun get(index: Int): Float = get_at_offset_float(tensorHandle, index) + override operator fun get(index: Int): Float = get_at_offset_float(tensorHandle!!, index) override operator fun set(index: Int, value: Float) { - set_at_offset_float(tensorHandle, index, value) + set_at_offset_float(tensorHandle!!, index, value) } override operator fun iterator(): Iterator { - val cpuCopy = copy_to_cpu(tensorHandle)!! + val cpuCopy = copy_to_cpu(tensorHandle!!)!! val tensorCpuData = get_data_float(cpuCopy)!! val iteratorResult = (1..size).map { tensorCpuData[it - 1] }.iterator() dispose_tensor(cpuCopy) diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt index 2805a376c..86cd05b3a 100644 --- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt +++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt @@ -11,7 +11,7 @@ internal class TestTorchTensor { @Test fun intTensorLayout() = memScoped { val array = (1..24).toList().toIntArray() - val shape = intArrayOf(3, 2, 4) + val shape = intArrayOf(4, 6) val tensor = TorchTensor.copyFromIntArray(scope = this, array = array, shape = shape) tensor.elements().forEach { assertEquals(tensor[it.first], it.second) @@ -30,4 +30,16 @@ internal class TestTorchTensor { assertTrue(tensor.buffer.contentEquals(array.asBuffer())) } + @Test + fun mutableStructure() = memScoped { + val array = (1..10).map { 1f * it }.toList().toFloatArray() + val shape = intArrayOf(10) + val tensor = TorchTensor.copyFromFloatArray(this, array, shape) + val tensorCopy = tensor.copy() + + tensor[intArrayOf(0)] = 99f + assertEquals(99f, tensor[intArrayOf(0)]) + assertEquals(1f, tensorCopy[intArrayOf(0)]) + } + } \ No newline at end of file diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestUtils.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestUtils.kt index d49cb3720..53759c5a0 100644 --- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestUtils.kt +++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestUtils.kt @@ -1,14 +1,12 @@ package kscience.kmath.torch -import kotlin.test.Ignore import kotlin.test.Test import kotlin.test.assertEquals -import kotlin.test.assertTrue internal class TestUtils { @Test - fun settingTorchThreadsCount(){ + fun settingTorchThreadsCount() { val numThreads = 2 setNumThreads(numThreads) assertEquals(numThreads, getNumThreads())