Initial drafts for TorchTensorAlgebra
This commit is contained in:
parent
cfe93886ac
commit
9b1a958491
@ -58,6 +58,10 @@ extern "C"
|
|||||||
TorchTensorHandle copy_to_cpu(TorchTensorHandle tensor_handle);
|
TorchTensorHandle copy_to_cpu(TorchTensorHandle tensor_handle);
|
||||||
TorchTensorHandle copy_to_gpu(TorchTensorHandle tensor_handle, int device);
|
TorchTensorHandle copy_to_gpu(TorchTensorHandle tensor_handle, int device);
|
||||||
|
|
||||||
|
TorchTensorHandle randn_float(int* shape, int shape_size);
|
||||||
|
|
||||||
|
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -33,13 +33,17 @@ namespace ctorch
|
|||||||
return *static_cast<torch::Tensor *>(tensor_handle);
|
return *static_cast<torch::Tensor *>(tensor_handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Dtype>
|
inline std::vector<int64_t> to_vec_int(int *arr, int arr_size)
|
||||||
inline torch::Tensor copy_from_blob(Dtype *data, int *shape, int dim, torch::Device device)
|
|
||||||
{
|
{
|
||||||
auto shape_vec = std::vector<int64_t>(dim);
|
auto vec = std::vector<int64_t>(arr_size);
|
||||||
shape_vec.assign(shape, shape + dim);
|
vec.assign(arr, arr + arr_size);
|
||||||
return torch::from_blob(data, shape_vec, dtype<Dtype>()).to(
|
return vec;
|
||||||
torch::TensorOptions().layout(torch::kStrided).device(device), false, true);
|
}
|
||||||
|
|
||||||
|
template <typename Dtype>
|
||||||
|
inline torch::Tensor copy_from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device)
|
||||||
|
{
|
||||||
|
return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int *to_dynamic_ints(const c10::IntArrayRef &arr)
|
inline int *to_dynamic_ints(const c10::IntArrayRef &arr)
|
||||||
@ -78,4 +82,10 @@ namespace ctorch
|
|||||||
ten.index(offset_to_index(offset, ten.strides())) = value;
|
ten.index(offset_to_index(offset, ten.strides())) = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Dtype>
|
||||||
|
inline torch::Tensor randn(std::vector<int64_t> shape, torch::Device device)
|
||||||
|
{
|
||||||
|
return torch::randn(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace ctorch
|
} // namespace ctorch
|
||||||
|
@ -27,36 +27,36 @@ void set_seed(int seed)
|
|||||||
|
|
||||||
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim)
|
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::copy_from_blob<double>(data, shape, dim, torch::kCPU));
|
return new torch::Tensor(ctorch::copy_from_blob<double>(data, ctorch::to_vec_int(shape, dim), torch::kCPU));
|
||||||
}
|
}
|
||||||
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim)
|
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::copy_from_blob<float>(data, shape, dim, torch::kCPU));
|
return new torch::Tensor(ctorch::copy_from_blob<float>(data, ctorch::to_vec_int(shape, dim), torch::kCPU));
|
||||||
}
|
}
|
||||||
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim)
|
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::copy_from_blob<long>(data, shape, dim, torch::kCPU));
|
return new torch::Tensor(ctorch::copy_from_blob<long>(data, ctorch::to_vec_int(shape, dim), torch::kCPU));
|
||||||
}
|
}
|
||||||
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim)
|
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::copy_from_blob<int>(data, shape, dim, torch::kCPU));
|
return new torch::Tensor(ctorch::copy_from_blob<int>(data, ctorch::to_vec_int(shape, dim), torch::kCPU));
|
||||||
}
|
}
|
||||||
|
|
||||||
TorchTensorHandle copy_from_blob_to_gpu_double(double *data, int *shape, int dim, int device)
|
TorchTensorHandle copy_from_blob_to_gpu_double(double *data, int *shape, int dim, int device)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::copy_from_blob<double>(data, shape, dim, torch::Device(torch::kCUDA, device)));
|
return new torch::Tensor(ctorch::copy_from_blob<double>(data, ctorch::to_vec_int(shape, dim), torch::Device(torch::kCUDA, device)));
|
||||||
}
|
}
|
||||||
TorchTensorHandle copy_from_blob_to_gpu_float(float *data, int *shape, int dim, int device)
|
TorchTensorHandle copy_from_blob_to_gpu_float(float *data, int *shape, int dim, int device)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::copy_from_blob<float>(data, shape, dim, torch::Device(torch::kCUDA, device)));
|
return new torch::Tensor(ctorch::copy_from_blob<float>(data, ctorch::to_vec_int(shape, dim), torch::Device(torch::kCUDA, device)));
|
||||||
}
|
}
|
||||||
TorchTensorHandle copy_from_blob_to_gpu_long(long *data, int *shape, int dim, int device)
|
TorchTensorHandle copy_from_blob_to_gpu_long(long *data, int *shape, int dim, int device)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::copy_from_blob<long>(data, shape, dim, torch::Device(torch::kCUDA, device)));
|
return new torch::Tensor(ctorch::copy_from_blob<long>(data, ctorch::to_vec_int(shape, dim), torch::Device(torch::kCUDA, device)));
|
||||||
}
|
}
|
||||||
TorchTensorHandle copy_from_blob_to_gpu_int(int *data, int *shape, int dim, int device)
|
TorchTensorHandle copy_from_blob_to_gpu_int(int *data, int *shape, int dim, int device)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::copy_from_blob<int>(data, shape, dim, torch::Device(torch::kCUDA, device)));
|
return new torch::Tensor(ctorch::copy_from_blob<int>(data, ctorch::to_vec_int(shape, dim), torch::Device(torch::kCUDA, device)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle)
|
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle)
|
||||||
@ -167,3 +167,11 @@ TorchTensorHandle copy_to_gpu(TorchTensorHandle tensor_handle, int device)
|
|||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::cast(tensor_handle).to(torch::Device(torch::kCUDA, device),false, true));
|
return new torch::Tensor(ctorch::cast(tensor_handle).to(torch::Device(torch::kCUDA, device),false, true));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle randn_float(int* shape, int shape_size){
|
||||||
|
return new torch::Tensor(ctorch::randn<float>(ctorch::to_vec_int(shape, shape_size), torch::kCPU));
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs){
|
||||||
|
return new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
|
||||||
|
}
|
@ -20,6 +20,6 @@ class TestTorchTensorGPU {
|
|||||||
tensor.elements().forEach {
|
tensor.elements().forEach {
|
||||||
assertEquals(tensor[it.first], it.second)
|
assertEquals(tensor[it.first], it.second)
|
||||||
}
|
}
|
||||||
assertTrue(tensor.buffer.contentEquals(array.asBuffer()))
|
assertTrue(tensor.asBuffer().contentEquals(array.asBuffer()))
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -0,0 +1,18 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlinx.cinterop.*
|
||||||
|
import ctorch.*
|
||||||
|
|
||||||
|
public abstract class TorchMemoryHolder internal constructor(
|
||||||
|
internal val scope: DeferScope,
|
||||||
|
internal var tensorHandle: COpaquePointer?
|
||||||
|
){
|
||||||
|
init {
|
||||||
|
scope.defer(::close)
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun close() {
|
||||||
|
dispose_tensor(tensorHandle)
|
||||||
|
tensorHandle = null
|
||||||
|
}
|
||||||
|
}
|
@ -5,22 +5,32 @@ import kscience.kmath.structures.*
|
|||||||
import kotlinx.cinterop.*
|
import kotlinx.cinterop.*
|
||||||
import ctorch.*
|
import ctorch.*
|
||||||
|
|
||||||
public abstract class TorchTensor<T, out TorchTensorBufferImpl : TorchTensorBuffer<T>> :
|
public sealed class TorchTensor<T, out TorchTensorBufferImpl : TorchTensorBuffer<T>> :
|
||||||
MutableNDBufferTrait<T, TorchTensorBufferImpl, TorchTensorStrides>() {
|
MutableNDBufferTrait<T, TorchTensorBufferImpl, TorchTensorStrides>() {
|
||||||
|
|
||||||
|
public fun asBuffer(): MutableBuffer<T> = buffer
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public fun copyFromFloatArray(scope: DeferScope, array: FloatArray, shape: IntArray): TorchTensorFloat {
|
public fun copyFromFloatArray(scope: DeferScope, array: FloatArray, shape: IntArray): TorchTensorFloat {
|
||||||
val tensorHandle: COpaquePointer = copy_from_blob_float(
|
val tensorHandle: COpaquePointer = copy_from_blob_float(
|
||||||
array.toCValues(), shape.toCValues(), shape.size
|
array.toCValues(), shape.toCValues(), shape.size
|
||||||
)!!
|
)!!
|
||||||
return TorchTensorFloat(populateStridesFromNative(tensorHandle, rawShape = shape), scope, tensorHandle)
|
return TorchTensorFloat(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = tensorHandle,
|
||||||
|
strides = populateStridesFromNative(tensorHandle, rawShape = shape)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun copyFromIntArray(scope: DeferScope, array: IntArray, shape: IntArray): TorchTensorInt {
|
public fun copyFromIntArray(scope: DeferScope, array: IntArray, shape: IntArray): TorchTensorInt {
|
||||||
val tensorHandle: COpaquePointer = copy_from_blob_int(
|
val tensorHandle: COpaquePointer = copy_from_blob_int(
|
||||||
array.toCValues(), shape.toCValues(), shape.size
|
array.toCValues(), shape.toCValues(), shape.size
|
||||||
)!!
|
)!!
|
||||||
return TorchTensorInt(populateStridesFromNative(tensorHandle, rawShape = shape), scope, tensorHandle)
|
return TorchTensorInt(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = tensorHandle,
|
||||||
|
strides = populateStridesFromNative(tensorHandle, rawShape = shape)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun copyFromFloatArrayToGPU(
|
public fun copyFromFloatArrayToGPU(
|
||||||
@ -32,7 +42,11 @@ public abstract class TorchTensor<T, out TorchTensorBufferImpl : TorchTensorBuff
|
|||||||
val tensorHandle: COpaquePointer = copy_from_blob_to_gpu_float(
|
val tensorHandle: COpaquePointer = copy_from_blob_to_gpu_float(
|
||||||
array.toCValues(), shape.toCValues(), shape.size, device
|
array.toCValues(), shape.toCValues(), shape.size, device
|
||||||
)!!
|
)!!
|
||||||
return TorchTensorFloatGPU(populateStridesFromNative(tensorHandle, rawShape = shape), scope, tensorHandle)
|
return TorchTensorFloatGPU(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = tensorHandle,
|
||||||
|
strides = populateStridesFromNative(tensorHandle, rawShape = shape)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,53 +57,63 @@ public abstract class TorchTensor<T, out TorchTensorBufferImpl : TorchTensorBuff
|
|||||||
return stringRepresentation
|
return stringRepresentation
|
||||||
}
|
}
|
||||||
|
|
||||||
internal abstract fun wrap(
|
protected abstract fun wrap(
|
||||||
outStrides: TorchTensorStrides,
|
|
||||||
outScope: DeferScope,
|
outScope: DeferScope,
|
||||||
outTensorHandle: COpaquePointer
|
outTensorHandle: COpaquePointer,
|
||||||
|
outStrides: TorchTensorStrides
|
||||||
): TorchTensor<T, TorchTensorBufferImpl>
|
): TorchTensor<T, TorchTensorBufferImpl>
|
||||||
|
|
||||||
public fun copy(): TorchTensor<T, TorchTensorBufferImpl> = wrap(
|
public fun copy(): TorchTensor<T, TorchTensorBufferImpl> = wrap(
|
||||||
outStrides = strides,
|
|
||||||
outScope = buffer.scope,
|
outScope = buffer.scope,
|
||||||
outTensorHandle = copy_tensor(buffer.tensorHandle!!)!!
|
outTensorHandle = copy_tensor(buffer.tensorHandle!!)!!,
|
||||||
|
outStrides = strides
|
||||||
)
|
)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorFloat internal constructor(
|
public class TorchTensorFloat internal constructor(
|
||||||
override val strides: TorchTensorStrides,
|
|
||||||
scope: DeferScope,
|
scope: DeferScope,
|
||||||
tensorHandle: COpaquePointer
|
tensorHandle: COpaquePointer,
|
||||||
|
override val strides: TorchTensorStrides
|
||||||
) : TorchTensor<Float, TorchTensorBufferFloat>() {
|
) : TorchTensor<Float, TorchTensorBufferFloat>() {
|
||||||
override val buffer: TorchTensorBufferFloat = TorchTensorBufferFloat(scope, tensorHandle)
|
override val buffer: TorchTensorBufferFloat = TorchTensorBufferFloat(scope, tensorHandle)
|
||||||
override fun wrap(outStrides: TorchTensorStrides, outScope: DeferScope, outTensorHandle: COpaquePointer) =
|
override fun wrap(
|
||||||
TorchTensorFloat(
|
outScope: DeferScope,
|
||||||
strides = outStrides, scope = outScope, tensorHandle = outTensorHandle
|
outTensorHandle: COpaquePointer,
|
||||||
)
|
outStrides: TorchTensorStrides
|
||||||
|
): TorchTensorFloat = TorchTensorFloat(
|
||||||
|
scope = outScope, tensorHandle = outTensorHandle, strides = outStrides
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorInt internal constructor(
|
public class TorchTensorInt internal constructor(
|
||||||
override val strides: TorchTensorStrides,
|
|
||||||
scope: DeferScope,
|
scope: DeferScope,
|
||||||
tensorHandle: COpaquePointer
|
tensorHandle: COpaquePointer,
|
||||||
|
override val strides: TorchTensorStrides
|
||||||
) : TorchTensor<Int, TorchTensorBufferInt>() {
|
) : TorchTensor<Int, TorchTensorBufferInt>() {
|
||||||
override val buffer: TorchTensorBufferInt = TorchTensorBufferInt(scope, tensorHandle)
|
override val buffer: TorchTensorBufferInt = TorchTensorBufferInt(scope, tensorHandle)
|
||||||
override fun wrap(outStrides: TorchTensorStrides, outScope: DeferScope, outTensorHandle: COpaquePointer) =
|
override fun wrap(
|
||||||
TorchTensorInt(
|
outScope: DeferScope,
|
||||||
strides = outStrides, scope = outScope, tensorHandle = outTensorHandle
|
outTensorHandle: COpaquePointer,
|
||||||
)
|
outStrides: TorchTensorStrides
|
||||||
|
): TorchTensorInt = TorchTensorInt(
|
||||||
|
scope = outScope, tensorHandle = outTensorHandle, strides = outStrides
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorFloatGPU internal constructor(
|
public class TorchTensorFloatGPU internal constructor(
|
||||||
override val strides: TorchTensorStrides,
|
|
||||||
scope: DeferScope,
|
scope: DeferScope,
|
||||||
tensorHandle: COpaquePointer
|
tensorHandle: COpaquePointer,
|
||||||
|
override val strides: TorchTensorStrides
|
||||||
) : TorchTensor<Float, TorchTensorBufferFloatGPU>() {
|
) : TorchTensor<Float, TorchTensorBufferFloatGPU>() {
|
||||||
override val buffer: TorchTensorBufferFloatGPU = TorchTensorBufferFloatGPU(scope, tensorHandle)
|
override val buffer: TorchTensorBufferFloatGPU = TorchTensorBufferFloatGPU(scope, tensorHandle)
|
||||||
override fun wrap(outStrides: TorchTensorStrides, outScope: DeferScope, outTensorHandle: COpaquePointer) =
|
override fun wrap(
|
||||||
|
outScope: DeferScope,
|
||||||
|
outTensorHandle: COpaquePointer,
|
||||||
|
outStrides: TorchTensorStrides
|
||||||
|
): TorchTensorFloatGPU =
|
||||||
TorchTensorFloatGPU(
|
TorchTensorFloatGPU(
|
||||||
strides = outStrides, scope = outScope, tensorHandle = outTensorHandle
|
scope = outScope, tensorHandle = outTensorHandle, strides = outStrides
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,65 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlinx.cinterop.*
|
||||||
|
import ctorch.*
|
||||||
|
|
||||||
|
|
||||||
|
public sealed class TorchTensorAlgebra<
|
||||||
|
T,
|
||||||
|
TorchTensorBufferImpl : TorchTensorBuffer<T>,
|
||||||
|
PrimitiveArrayType>
|
||||||
|
constructor(
|
||||||
|
internal val scope: DeferScope
|
||||||
|
) {
|
||||||
|
|
||||||
|
protected abstract fun wrap(
|
||||||
|
outTensorHandle: COpaquePointer,
|
||||||
|
outStrides: TorchTensorStrides
|
||||||
|
): TorchTensor<T, TorchTensorBufferImpl>
|
||||||
|
|
||||||
|
public infix fun TorchTensor<T, TorchTensorBufferImpl>.swap(other: TorchTensor<T, TorchTensorBufferImpl>): Unit {
|
||||||
|
check(this.shape contentEquals other.shape) {
|
||||||
|
"Attempt to swap tensors with different shapes"
|
||||||
|
}
|
||||||
|
this.buffer.tensorHandle = other.buffer.tensorHandle.also {
|
||||||
|
other.buffer.tensorHandle = this.buffer.tensorHandle
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract fun copyFromArray(array: PrimitiveArrayType, shape: IntArray): TorchTensor<T, TorchTensorBufferImpl>
|
||||||
|
|
||||||
|
public infix fun TorchTensor<T, TorchTensorBufferImpl>.dot(other: TorchTensor<T, TorchTensorBufferImpl>):
|
||||||
|
TorchTensor<T, TorchTensorBufferImpl> {
|
||||||
|
val resultHandle = matmul(this.buffer.tensorHandle, other.buffer.tensorHandle)!!
|
||||||
|
val strides = populateStridesFromNative(tensorHandle = resultHandle)
|
||||||
|
return wrap(resultHandle, strides)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public sealed class TorchTensorField<T, TorchTensorBufferImpl : TorchTensorBuffer<T>, PrimitiveArrayType>
|
||||||
|
constructor(scope: DeferScope) : TorchTensorAlgebra<T, TorchTensorBufferImpl, PrimitiveArrayType>(scope) {
|
||||||
|
public abstract fun randn(shape: IntArray): TorchTensor<T, TorchTensorBufferImpl>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||||
|
TorchTensorField<Float, TorchTensorBufferFloat, FloatArray>(scope) {
|
||||||
|
override fun wrap(
|
||||||
|
outTensorHandle: COpaquePointer,
|
||||||
|
outStrides: TorchTensorStrides
|
||||||
|
): TorchTensorFloat = TorchTensorFloat(scope = scope, tensorHandle = outTensorHandle, strides = outStrides)
|
||||||
|
|
||||||
|
override fun randn(shape: IntArray): TorchTensor<Float, TorchTensorBufferFloat> {
|
||||||
|
val tensorHandle = randn_float(shape.toCValues(), shape.size)!!
|
||||||
|
val strides = populateStridesFromNative(tensorHandle = tensorHandle, rawShape = shape)
|
||||||
|
return wrap(tensorHandle, strides)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun copyFromArray(array: FloatArray, shape: IntArray): TorchTensorFloat =
|
||||||
|
TorchTensor.copyFromFloatArray(scope, array, shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
|
||||||
|
memScoped { TorchTensorFloatAlgebra(this).block() }
|
@ -5,25 +5,16 @@ import kscience.kmath.structures.MutableBuffer
|
|||||||
import kotlinx.cinterop.*
|
import kotlinx.cinterop.*
|
||||||
import ctorch.*
|
import ctorch.*
|
||||||
|
|
||||||
public abstract class TorchTensorBuffer<T> internal constructor(
|
public sealed class TorchTensorBuffer<T> constructor(
|
||||||
internal val scope: DeferScope,
|
scope: DeferScope,
|
||||||
internal var tensorHandle: COpaquePointer?
|
tensorHandle: COpaquePointer?
|
||||||
) : MutableBuffer<T> {
|
) : MutableBuffer<T>, TorchMemoryHolder(scope, tensorHandle) {
|
||||||
|
|
||||||
override val size: Int
|
override val size: Int
|
||||||
get(){
|
get(){
|
||||||
return get_numel(tensorHandle!!)
|
return get_numel(tensorHandle!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
init {
|
|
||||||
scope.defer(::close)
|
|
||||||
}
|
|
||||||
|
|
||||||
protected fun close() {
|
|
||||||
dispose_tensor(tensorHandle)
|
|
||||||
tensorHandle = null
|
|
||||||
}
|
|
||||||
|
|
||||||
internal abstract fun wrap(outScope: DeferScope, outTensorHandle: COpaquePointer): TorchTensorBuffer<T>
|
internal abstract fun wrap(outScope: DeferScope, outTensorHandle: COpaquePointer): TorchTensorBuffer<T>
|
||||||
|
|
||||||
override fun copy(): TorchTensorBuffer<T> = wrap(
|
override fun copy(): TorchTensorBuffer<T> = wrap(
|
||||||
|
@ -16,7 +16,7 @@ internal class TestTorchTensor {
|
|||||||
tensor.elements().forEach {
|
tensor.elements().forEach {
|
||||||
assertEquals(tensor[it.first], it.second)
|
assertEquals(tensor[it.first], it.second)
|
||||||
}
|
}
|
||||||
assertTrue(tensor.buffer.contentEquals(array.asBuffer()))
|
assertTrue(tensor.asBuffer().contentEquals(array.asBuffer()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -27,7 +27,7 @@ internal class TestTorchTensor {
|
|||||||
tensor.elements().forEach {
|
tensor.elements().forEach {
|
||||||
assertEquals(tensor[it.first], it.second)
|
assertEquals(tensor[it.first], it.second)
|
||||||
}
|
}
|
||||||
assertTrue(tensor.buffer.contentEquals(array.asBuffer()))
|
assertTrue(tensor.asBuffer().contentEquals(array.asBuffer()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -0,0 +1,36 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
|
||||||
|
import kotlin.test.*
|
||||||
|
import kotlin.time.measureTime
|
||||||
|
|
||||||
|
|
||||||
|
class TestTorchTensorAlgebra {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun swappingTensors() = TorchTensorFloatAlgebra {
|
||||||
|
val tensorA = copyFromArray(floatArrayOf(1f, 2f, 3f), intArrayOf(3))
|
||||||
|
val tensorB = tensorA.copy()
|
||||||
|
val tensorC = copyFromArray(floatArrayOf(4f, 5f, 6f), intArrayOf(3))
|
||||||
|
tensorA swap tensorC
|
||||||
|
assertTrue(tensorB.asBuffer().contentEquals(tensorC.asBuffer()))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun dotOperation() = TorchTensorFloatAlgebra {
|
||||||
|
setSeed(987654)
|
||||||
|
var tensorA = randn(intArrayOf(1000, 1000))
|
||||||
|
val tensorB = randn(intArrayOf(1000, 1000))
|
||||||
|
measureTime {
|
||||||
|
repeat(100) {
|
||||||
|
TorchTensorFloatAlgebra {
|
||||||
|
tensorA swap (tensorA dot tensorB)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}.also(::println)
|
||||||
|
assertTrue(tensorA.shape contentEquals tensorB.shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -1,7 +1,6 @@
|
|||||||
package kscience.kmath.torch
|
package kscience.kmath.torch
|
||||||
|
|
||||||
import kotlin.test.Test
|
import kotlin.test.*
|
||||||
import kotlin.test.assertEquals
|
|
||||||
|
|
||||||
|
|
||||||
internal class TestUtils {
|
internal class TestUtils {
|
||||||
@ -11,4 +10,12 @@ internal class TestUtils {
|
|||||||
setNumThreads(numThreads)
|
setNumThreads(numThreads)
|
||||||
assertEquals(numThreads, getNumThreads())
|
assertEquals(numThreads, getNumThreads())
|
||||||
}
|
}
|
||||||
|
@Test
|
||||||
|
fun seedSetting() = TorchTensorFloatAlgebra {
|
||||||
|
setSeed(987654)
|
||||||
|
val tensorA = randn(intArrayOf(2,3))
|
||||||
|
setSeed(987654)
|
||||||
|
val tensorB = randn(intArrayOf(2,3))
|
||||||
|
assertTrue(tensorA.asBuffer().contentEquals(tensorB.asBuffer()))
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user