Float, Long and Int support

This commit is contained in:
rgrit91 2021-01-13 21:08:14 +00:00
parent 80f28dbcd5
commit ef570254e6
12 changed files with 527 additions and 204 deletions

View File

@ -20,22 +20,21 @@ extern "C"
TorchTensorHandle empty_tensor(); TorchTensorHandle empty_tensor();
double *get_data_double(TorchTensorHandle tensor_handle);
float *get_data_float(TorchTensorHandle tensor_handle);
long *get_data_long(TorchTensorHandle tensor_handle);
int *get_data_int(TorchTensorHandle tensor_handle);
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy); TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy);
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy); TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy);
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy); TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy);
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy); TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy);
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle); TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device); TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device);
TorchTensorHandle copy_to_double(TorchTensorHandle tensor_handle);
TorchTensorHandle copy_to_float(TorchTensorHandle tensor_handle);
TorchTensorHandle copy_to_long(TorchTensorHandle tensor_handle);
TorchTensorHandle copy_to_int(TorchTensorHandle tensor_handle);
void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle);
double get_item_double(TorchTensorHandle tensor_handle); char *tensor_to_string(TorchTensorHandle tensor_handle);
float get_item_float(TorchTensorHandle tensor_handle); void dispose_char(char *ptr);
long get_item_long(TorchTensorHandle tensor_handle); void dispose_tensor(TorchTensorHandle tensor_handle);
int get_item_int(TorchTensorHandle tensor_handle);
int get_dim(TorchTensorHandle tensor_handle); int get_dim(TorchTensorHandle tensor_handle);
int get_numel(TorchTensorHandle tensor_handle); int get_numel(TorchTensorHandle tensor_handle);
@ -43,9 +42,15 @@ extern "C"
int get_stride_at(TorchTensorHandle tensor_handle, int d); int get_stride_at(TorchTensorHandle tensor_handle, int d);
int get_device(TorchTensorHandle tensor_handle); int get_device(TorchTensorHandle tensor_handle);
char *tensor_to_string(TorchTensorHandle tensor_handle); double *get_data_double(TorchTensorHandle tensor_handle);
void dispose_char(char *ptr); float *get_data_float(TorchTensorHandle tensor_handle);
void dispose_tensor(TorchTensorHandle tensor_handle); long *get_data_long(TorchTensorHandle tensor_handle);
int *get_data_int(TorchTensorHandle tensor_handle);
double get_item_double(TorchTensorHandle tensor_handle);
float get_item_float(TorchTensorHandle tensor_handle);
long get_item_long(TorchTensorHandle tensor_handle);
int get_item_int(TorchTensorHandle tensor_handle);
double get_double(TorchTensorHandle tensor_handle, int *index); double get_double(TorchTensorHandle tensor_handle, int *index);
float get_float(TorchTensorHandle tensor_handle, int *index); float get_float(TorchTensorHandle tensor_handle, int *index);
@ -61,15 +66,14 @@ extern "C"
TorchTensorHandle randn_float(int *shape, int shape_size, int device); TorchTensorHandle randn_float(int *shape, int shape_size, int device);
TorchTensorHandle rand_float(int *shape, int shape_size, int device); TorchTensorHandle rand_float(int *shape, int shape_size, int device);
TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device);
TorchTensorHandle randint_int(int low, int high, int *shape, int shape_size, int device);
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device); TorchTensorHandle full_double(double value, int *shape, int shape_size, int device);
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device); TorchTensorHandle full_float(float value, int *shape, int shape_size, int device);
TorchTensorHandle full_long(long value, int *shape, int shape_size, int device); TorchTensorHandle full_long(long value, int *shape, int shape_size, int device);
TorchTensorHandle full_int(int value, int *shape, int shape_size, int device); TorchTensorHandle full_int(int value, int *shape, int shape_size, int device);
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs);
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle times_double(double value, TorchTensorHandle other); TorchTensorHandle times_double(double value, TorchTensorHandle other);
TorchTensorHandle times_float(float value, TorchTensorHandle other); TorchTensorHandle times_float(float value, TorchTensorHandle other);
TorchTensorHandle times_long(long value, TorchTensorHandle other); TorchTensorHandle times_long(long value, TorchTensorHandle other);
@ -115,10 +119,9 @@ extern "C"
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle); TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
void sum_tensor_assign(TorchTensorHandle tensor_handle); void sum_tensor_assign(TorchTensorHandle tensor_handle);
bool requires_grad(TorchTensorHandle tensor_handle); TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs);
void requires_grad_(TorchTensorHandle tensor_handle, bool status); void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle); void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2); TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2);
@ -132,6 +135,11 @@ extern "C"
TorchTensorHandle V_handle, TorchTensorHandle V_handle,
bool eigenvectors); bool eigenvectors);
bool requires_grad(TorchTensorHandle tensor_handle);
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -92,6 +92,12 @@ namespace ctorch
return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device)); return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
} }
template <typename Dtype>
inline torch::Tensor randint(Dtype low, Dtype high, std::vector<int64_t> shape, torch::Device device)
{
return torch::randint(low, high, shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
}
template <typename Dtype> template <typename Dtype>
inline torch::Tensor full(Dtype value, std::vector<int64_t> shape, torch::Device device) inline torch::Tensor full(Dtype value, std::vector<int64_t> shape, torch::Device device)
{ {

View File

@ -30,44 +30,6 @@ TorchTensorHandle empty_tensor()
return new torch::Tensor; return new torch::Tensor;
} }
double *get_data_double(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<double>();
}
float *get_data_float(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<float>();
}
long *get_data_long(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<long>();
}
int *get_data_int(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<int>();
}
int get_dim(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).dim();
}
int get_numel(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).numel();
}
int get_shape_at(TorchTensorHandle tensor_handle, int d)
{
return ctorch::cast(tensor_handle).size(d);
}
int get_stride_at(TorchTensorHandle tensor_handle, int d)
{
return ctorch::cast(tensor_handle).stride(d);
}
int get_device(TorchTensorHandle tensor_handle)
{
return ctorch::device_to_int(ctorch::cast(tensor_handle));
}
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy) TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy)
{ {
return new torch::Tensor(ctorch::from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy)); return new torch::Tensor(ctorch::from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
@ -92,6 +54,26 @@ TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device)
{ {
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::int_to_device(device), false, true)); return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::int_to_device(device), false, true));
} }
TorchTensorHandle copy_to_double(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<double>(), false, true));
}
TorchTensorHandle copy_to_float(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<float>(), false, true));
}
TorchTensorHandle copy_to_long(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<long>(), false, true));
}
TorchTensorHandle copy_to_int(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<int>(), false, true));
}
void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle)
{
std::swap(ctorch::cast(lhs_handle), ctorch::cast(rhs_handle));
}
char *tensor_to_string(TorchTensorHandle tensor_handle) char *tensor_to_string(TorchTensorHandle tensor_handle)
{ {
@ -111,6 +93,61 @@ void dispose_tensor(TorchTensorHandle tensor_handle)
delete static_cast<torch::Tensor *>(tensor_handle); delete static_cast<torch::Tensor *>(tensor_handle);
} }
int get_dim(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).dim();
}
int get_numel(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).numel();
}
int get_shape_at(TorchTensorHandle tensor_handle, int d)
{
return ctorch::cast(tensor_handle).size(d);
}
int get_stride_at(TorchTensorHandle tensor_handle, int d)
{
return ctorch::cast(tensor_handle).stride(d);
}
int get_device(TorchTensorHandle tensor_handle)
{
return ctorch::device_to_int(ctorch::cast(tensor_handle));
}
double *get_data_double(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<double>();
}
float *get_data_float(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<float>();
}
long *get_data_long(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<long>();
}
int *get_data_int(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<int>();
}
double get_item_double(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).item<double>();
}
float get_item_float(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).item<float>();
}
long get_item_long(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).item<long>();
}
int get_item_int(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).item<int>();
}
double get_double(TorchTensorHandle tensor_handle, int *index) double get_double(TorchTensorHandle tensor_handle, int *index)
{ {
return ctorch::get<double>(tensor_handle, index); return ctorch::get<double>(tensor_handle, index);
@ -144,23 +181,6 @@ void set_int(TorchTensorHandle tensor_handle, int *index, int value)
ctorch::set<int>(tensor_handle, index, value); ctorch::set<int>(tensor_handle, index, value);
} }
double get_item_double(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).item<double>();
}
float get_item_float(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).item<float>();
}
long get_item_long(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).item<long>();
}
int get_item_int(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).item<int>();
}
TorchTensorHandle randn_double(int *shape, int shape_size, int device) TorchTensorHandle randn_double(int *shape, int shape_size, int device)
{ {
return new torch::Tensor(ctorch::randn<double>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device))); return new torch::Tensor(ctorch::randn<double>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
@ -178,6 +198,15 @@ TorchTensorHandle rand_float(int *shape, int shape_size, int device)
return new torch::Tensor(ctorch::rand<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device))); return new torch::Tensor(ctorch::rand<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
} }
TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device)
{
return new torch::Tensor(ctorch::randint<long>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
}
TorchTensorHandle randint_int(int low, int high, int *shape, int shape_size, int device)
{
return new torch::Tensor(ctorch::randint<int>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
}
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device) TorchTensorHandle full_double(double value, int *shape, int shape_size, int device)
{ {
return new torch::Tensor(ctorch::full<double>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device))); return new torch::Tensor(ctorch::full<double>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
@ -195,19 +224,6 @@ TorchTensorHandle full_int(int value, int *shape, int shape_size, int device)
return new torch::Tensor(ctorch::full<int>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device))); return new torch::Tensor(ctorch::full<int>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
} }
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
return new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
}
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
}
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
}
TorchTensorHandle plus_double(double value, TorchTensorHandle other) TorchTensorHandle plus_double(double value, TorchTensorHandle other)
{ {
return new torch::Tensor(ctorch::cast(other) + value); return new torch::Tensor(ctorch::cast(other) + value);
@ -356,21 +372,17 @@ void sum_tensor_assign(TorchTensorHandle tensor_handle)
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum(); ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum();
} }
bool requires_grad(TorchTensorHandle tensor_handle) TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
{ {
return ctorch::cast(tensor_handle).requires_grad(); return new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
} }
void requires_grad_(TorchTensorHandle tensor_handle, bool status) void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{ {
ctorch::cast(tensor_handle).requires_grad_(status); ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
} }
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle) void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{ {
return new torch::Tensor(ctorch::cast(tensor_handle).detach()); ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
}
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
{
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
} }
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2) TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2)
@ -398,3 +410,20 @@ void symeig_tensor(TorchTensorHandle tensor_handle,
ctorch::cast(S_handle) = S; ctorch::cast(S_handle) = S;
ctorch::cast(V_handle) = V; ctorch::cast(V_handle) = V;
} }
bool requires_grad(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).requires_grad();
}
void requires_grad_(TorchTensorHandle tensor_handle, bool status)
{
ctorch::cast(tensor_handle).requires_grad_(status);
}
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).detach());
}
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
{
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
}

View File

@ -0,0 +1,20 @@
package kscience.kmath.torch
import kotlin.test.Test
class BenchmarkMatMultFloatGPU {
@Test
fun benchmarkMatMult20() =
benchmarkingMatMultFloat(20, 10, 100000,
device = TorchDevice.TorchCUDA(0))
@Test
fun benchmarkMatMult200() =
benchmarkingMatMultFloat(200, 10, 10000,
device = TorchDevice.TorchCUDA(0))
@Test
fun benchmarkMatMult2000() =
benchmarkingMatMultFloat(2000, 10, 1000,
device = TorchDevice.TorchCUDA(0))
}

View File

@ -29,5 +29,4 @@ class TestTorchTensorAlgebraGPU {
fun testBatchedSymEig() = fun testBatchedSymEig() =
testingBatchedSymEig(device = TorchDevice.TorchCUDA(0)) testingBatchedSymEig(device = TorchDevice.TorchCUDA(0))
} }

View File

@ -2,7 +2,11 @@ package kscience.kmath.torch
public sealed class TorchDevice { public sealed class TorchDevice {
public object TorchCPU: TorchDevice() public object TorchCPU: TorchDevice() {
override fun toString(): String {
return "TorchCPU"
}
}
public data class TorchCUDA(val index: Int): TorchDevice() public data class TorchCUDA(val index: Int): TorchDevice()
public fun toInt(): Int { public fun toInt(): Int {
when(this) { when(this) {

View File

@ -53,6 +53,23 @@ public sealed class TorchTensor<T> constructor(
get() = requires_grad(tensorHandle) get() = requires_grad(tensorHandle)
set(value) = requires_grad_(tensorHandle, value) set(value) = requires_grad_(tensorHandle, value)
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = copy_to_double(this.tensorHandle)!!
)
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
scope = scope,
tensorHandle = copy_to_float(this.tensorHandle)!!
)
public fun copyToLong(): TorchTensorLong = TorchTensorLong(
scope = scope,
tensorHandle = copy_to_long(this.tensorHandle)!!
)
public fun copyToInt(): TorchTensorInt = TorchTensorInt(
scope = scope,
tensorHandle = copy_to_int(this.tensorHandle)!!
)
} }
public class TorchTensorReal internal constructor( public class TorchTensorReal internal constructor(
@ -66,6 +83,39 @@ public class TorchTensorReal internal constructor(
} }
} }
public class TorchTensorFloat internal constructor(
scope: DeferScope,
tensorHandle: COpaquePointer
) : TorchTensor<Float>(scope, tensorHandle) {
override fun item(): Float = get_item_float(tensorHandle)
override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues())
override fun set(index: IntArray, value: Float) {
set_float(tensorHandle, index.toCValues(), value)
}
}
public class TorchTensorLong internal constructor(
scope: DeferScope,
tensorHandle: COpaquePointer
) : TorchTensor<Long>(scope, tensorHandle) {
override fun item(): Long = get_item_long(tensorHandle)
override fun get(index: IntArray): Long = get_long(tensorHandle, index.toCValues())
override fun set(index: IntArray, value: Long) {
set_long(tensorHandle, index.toCValues(), value)
}
}
public class TorchTensorInt internal constructor(
scope: DeferScope,
tensorHandle: COpaquePointer
) : TorchTensor<Int>(scope, tensorHandle) {
override fun item(): Int = get_item_int(tensorHandle)
override fun get(index: IntArray): Int = get_int(tensorHandle, index.toCValues())
override fun set(index: IntArray, value: Int) {
set_int(tensorHandle, index.toCValues(), value)
}
}
private inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray { private inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
val res = IntArray(nDim) val res = IntArray(nDim)

View File

@ -42,13 +42,6 @@ public sealed class TorchTensorAlgebra<
times_tensor_assign(this.tensorHandle, other.tensorHandle) times_tensor_assign(this.tensorHandle, other.tensorHandle)
} }
public operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType =
wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
div_tensor_assign(this.tensorHandle, other.tensorHandle)
}
public infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType = public infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType =
wrap(matmul(this.tensorHandle, other.tensorHandle)!!) wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
@ -94,11 +87,48 @@ public sealed class TorchTensorAlgebra<
sum_tensor_assign(tensorHandle) sum_tensor_assign(tensorHandle)
} }
public fun diagEmbed(diagonalEntries: TorchTensorType, public fun diagEmbed(
offset: Int = 0, dim1: Int = -2, dim2: Int = -1): TorchTensorType = diagonalEntries: TorchTensorType,
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
): TorchTensorType =
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!) wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
public fun TorchTensorType.svd(): Triple<TorchTensorType,TorchTensorType,TorchTensorType> { public fun TorchTensorType.copy(): TorchTensorType =
wrap(copy_tensor(this.tensorHandle)!!)
public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType =
wrap(copy_to_device(this.tensorHandle, device.toInt())!!)
public infix fun TorchTensorType.swap(otherTensor: TorchTensorType): Unit {
swap_tensors(this.tensorHandle, otherTensor.tensorHandle)
}
}
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope) {
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType
public operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType =
wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
div_tensor_assign(this.tensorHandle, other.tensorHandle)
}
public fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
public fun TorchTensorType.expAssign(): Unit {
exp_tensor_assign(tensorHandle)
}
public fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
public fun TorchTensorType.logAssign(): Unit {
log_tensor_assign(tensorHandle)
}
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
val U = empty_tensor()!! val U = empty_tensor()!!
val V = empty_tensor()!! val V = empty_tensor()!!
val S = empty_tensor()!! val S = empty_tensor()!!
@ -116,33 +146,17 @@ public sealed class TorchTensorAlgebra<
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType = public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!) wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
public fun TorchTensorType.copy(): TorchTensorType =
wrap(tensorHandle = copy_tensor(this.tensorHandle)!!)
public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType =
wrap(tensorHandle = copy_to_device(this.tensorHandle, device.toInt())!!)
public fun TorchTensorType.detachFromGraph(): TorchTensorType = public fun TorchTensorType.detachFromGraph(): TorchTensorType =
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!) wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
} }
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar, public sealed class TorchTensorRingAlgebra<T, TVar : CPrimitiveVar,
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) : PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope) { TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope) {
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType public abstract fun randIntegral(
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType low: T, high: T, shape: IntArray,
device: TorchDevice = TorchDevice.TorchCPU
public fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!) ): TorchTensorType
public fun TorchTensorType.expAssign(): Unit {
exp_tensor_assign(tensorHandle)
}
public fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
public fun TorchTensorType.logAssign(): Unit {
log_tensor_assign(tensorHandle)
}
} }
public class TorchTensorRealAlgebra(scope: DeferScope) : public class TorchTensorRealAlgebra(scope: DeferScope) :
@ -153,33 +167,11 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
override fun TorchTensorReal.copyToArray(): DoubleArray = override fun TorchTensorReal.copyToArray(): DoubleArray =
this.elements().map { it.second }.toList().toDoubleArray() this.elements().map { it.second }.toList().toDoubleArray()
override fun copyFromArray( override fun copyFromArray(array: DoubleArray, shape: IntArray, device: TorchDevice): TorchTensorReal =
array: DoubleArray, wrap(from_blob_double(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
shape: IntArray,
device: TorchDevice
): TorchTensorReal =
TorchTensorReal(
scope = scope,
tensorHandle = from_blob_double(
array.toCValues(),
shape.toCValues(),
shape.size,
device.toInt(),
true
)!!
)
override fun fromBlob(arrayBlob: CPointer<DoubleVar>, shape: IntArray): TorchTensorReal = override fun fromBlob(arrayBlob: CPointer<DoubleVar>, shape: IntArray): TorchTensorReal =
TorchTensorReal( wrap(from_blob_double(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)
scope = scope,
tensorHandle = from_blob_double(
arrayBlob,
shape.toCValues(),
shape.size,
TorchDevice.TorchCPU.toInt(),
false
)!!
)
override fun TorchTensorReal.getData(): CPointer<DoubleVar> { override fun TorchTensorReal.getData(): CPointer<DoubleVar> {
require(this.device is TorchDevice.TorchCPU) { require(this.device is TorchDevice.TorchCPU) {
@ -188,65 +180,231 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
return get_data_double(this.tensorHandle)!! return get_data_double(this.tensorHandle)!!
} }
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal( override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorReal =
scope = scope, wrap(randn_double(shape.toCValues(), shape.size, device.toInt())!!)
tensorHandle = randn_double(shape.toCValues(), shape.size, device.toInt())!!
)
override fun randUniform(shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal( override fun randUniform(shape: IntArray, device: TorchDevice): TorchTensorReal =
scope = scope, wrap(rand_double(shape.toCValues(), shape.size, device.toInt())!!)
tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
)
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal = TorchTensorReal( override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal =
scope = scope, wrap(plus_double(this, other.tensorHandle)!!)
tensorHandle = plus_double(this, other.tensorHandle)!!
)
override fun TorchTensorReal.plus(value: Double): TorchTensorReal = TorchTensorReal( override fun TorchTensorReal.plus(value: Double): TorchTensorReal =
scope = scope, wrap(plus_double(value, this.tensorHandle)!!)
tensorHandle = plus_double(value, this.tensorHandle)!!
)
override fun TorchTensorReal.plusAssign(value: Double): Unit { override fun TorchTensorReal.plusAssign(value: Double): Unit {
plus_double_assign(value, this.tensorHandle) plus_double_assign(value, this.tensorHandle)
} }
override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal = TorchTensorReal( override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal =
scope = scope, wrap(plus_double(-this, other.tensorHandle)!!)
tensorHandle = plus_double(-this, other.tensorHandle)!!
)
override fun TorchTensorReal.minus(value: Double): TorchTensorReal = TorchTensorReal( override fun TorchTensorReal.minus(value: Double): TorchTensorReal =
scope = scope, wrap(plus_double(-value, this.tensorHandle)!!)
tensorHandle = plus_double(-value, this.tensorHandle)!!
)
override fun TorchTensorReal.minusAssign(value: Double): Unit { override fun TorchTensorReal.minusAssign(value: Double): Unit {
plus_double_assign(-value, this.tensorHandle) plus_double_assign(-value, this.tensorHandle)
} }
override operator fun Double.times(other: TorchTensorReal): TorchTensorReal = TorchTensorReal( override operator fun Double.times(other: TorchTensorReal): TorchTensorReal =
scope = scope, wrap(times_double(this, other.tensorHandle)!!)
tensorHandle = times_double(this, other.tensorHandle)!!
)
override fun TorchTensorReal.times(value: Double): TorchTensorReal = TorchTensorReal( override fun TorchTensorReal.times(value: Double): TorchTensorReal =
scope = scope, wrap(times_double(value, this.tensorHandle)!!)
tensorHandle = times_double(value, this.tensorHandle)!!
)
override fun TorchTensorReal.timesAssign(value: Double): Unit { override fun TorchTensorReal.timesAssign(value: Double): Unit {
times_double_assign(value, this.tensorHandle) times_double_assign(value, this.tensorHandle)
} }
override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal =
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
}
override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal( public class TorchTensorFloatAlgebra(scope: DeferScope) :
scope = scope, TorchTensorFieldAlgebra<Float, FloatVar, FloatArray, TorchTensorFloat>(scope) {
tensorHandle = full_double(value, shape.toCValues(), shape.size, device.toInt())!! override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat =
) TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
override fun TorchTensorFloat.copyToArray(): FloatArray =
this.elements().map { it.second }.toList().toFloatArray()
override fun copyFromArray(array: FloatArray, shape: IntArray, device: TorchDevice): TorchTensorFloat =
wrap(from_blob_float(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
override fun fromBlob(arrayBlob: CPointer<FloatVar>, shape: IntArray): TorchTensorFloat =
wrap(from_blob_float(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)
override fun TorchTensorFloat.getData(): CPointer<FloatVar> {
require(this.device is TorchDevice.TorchCPU) {
"This tensor is not on available on CPU"
}
return get_data_float(this.tensorHandle)!!
}
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorFloat =
wrap(randn_float(shape.toCValues(), shape.size, device.toInt())!!)
override fun randUniform(shape: IntArray, device: TorchDevice): TorchTensorFloat =
wrap(rand_float(shape.toCValues(), shape.size, device.toInt())!!)
override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat =
wrap(plus_float(this, other.tensorHandle)!!)
override fun TorchTensorFloat.plus(value: Float): TorchTensorFloat =
wrap(plus_float(value, this.tensorHandle)!!)
override fun TorchTensorFloat.plusAssign(value: Float): Unit {
plus_float_assign(value, this.tensorHandle)
}
override operator fun Float.minus(other: TorchTensorFloat): TorchTensorFloat =
wrap(plus_float(-this, other.tensorHandle)!!)
override fun TorchTensorFloat.minus(value: Float): TorchTensorFloat =
wrap(plus_float(-value, this.tensorHandle)!!)
override fun TorchTensorFloat.minusAssign(value: Float): Unit {
plus_float_assign(-value, this.tensorHandle)
}
override operator fun Float.times(other: TorchTensorFloat): TorchTensorFloat =
wrap(times_float(this, other.tensorHandle)!!)
override fun TorchTensorFloat.times(value: Float): TorchTensorFloat =
wrap(times_float(value, this.tensorHandle)!!)
override fun TorchTensorFloat.timesAssign(value: Float): Unit {
times_float_assign(value, this.tensorHandle)
}
override fun full(value: Float, shape: IntArray, device: TorchDevice): TorchTensorFloat =
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
}
public class TorchTensorLongAlgebra(scope: DeferScope) :
TorchTensorRingAlgebra<Long, LongVar, LongArray, TorchTensorLong>(scope) {
override fun wrap(tensorHandle: COpaquePointer): TorchTensorLong =
TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
override fun TorchTensorLong.copyToArray(): LongArray =
this.elements().map { it.second }.toList().toLongArray()
override fun copyFromArray(array: LongArray, shape: IntArray, device: TorchDevice): TorchTensorLong =
wrap(from_blob_long(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
override fun fromBlob(arrayBlob: CPointer<LongVar>, shape: IntArray): TorchTensorLong =
wrap(from_blob_long(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)
override fun TorchTensorLong.getData(): CPointer<LongVar> {
require(this.device is TorchDevice.TorchCPU) {
"This tensor is not on available on CPU"
}
return get_data_long(this.tensorHandle)!!
}
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: TorchDevice): TorchTensorLong =
wrap(randint_long(low, high, shape.toCValues(), shape.size, device.toInt())!!)
override operator fun Long.plus(other: TorchTensorLong): TorchTensorLong =
wrap(plus_long(this, other.tensorHandle)!!)
override fun TorchTensorLong.plus(value: Long): TorchTensorLong =
wrap(plus_long(value, this.tensorHandle)!!)
override fun TorchTensorLong.plusAssign(value: Long): Unit {
plus_long_assign(value, this.tensorHandle)
}
override operator fun Long.minus(other: TorchTensorLong): TorchTensorLong =
wrap(plus_long(-this, other.tensorHandle)!!)
override fun TorchTensorLong.minus(value: Long): TorchTensorLong =
wrap(plus_long(-value, this.tensorHandle)!!)
override fun TorchTensorLong.minusAssign(value: Long): Unit {
plus_long_assign(-value, this.tensorHandle)
}
override operator fun Long.times(other: TorchTensorLong): TorchTensorLong =
wrap(times_long(this, other.tensorHandle)!!)
override fun TorchTensorLong.times(value: Long): TorchTensorLong =
wrap(times_long(value, this.tensorHandle)!!)
override fun TorchTensorLong.timesAssign(value: Long): Unit {
times_long_assign(value, this.tensorHandle)
}
override fun full(value: Long, shape: IntArray, device: TorchDevice): TorchTensorLong =
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
}
public class TorchTensorIntAlgebra(scope: DeferScope) :
TorchTensorRingAlgebra<Int, IntVar, IntArray, TorchTensorInt>(scope) {
override fun wrap(tensorHandle: COpaquePointer): TorchTensorInt =
TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
override fun TorchTensorInt.copyToArray(): IntArray =
this.elements().map { it.second }.toList().toIntArray()
override fun copyFromArray(array: IntArray, shape: IntArray, device: TorchDevice): TorchTensorInt =
wrap(from_blob_int(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
override fun fromBlob(arrayBlob: CPointer<IntVar>, shape: IntArray): TorchTensorInt =
wrap(from_blob_int(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)
override fun TorchTensorInt.getData(): CPointer<IntVar> {
require(this.device is TorchDevice.TorchCPU) {
"This tensor is not on available on CPU"
}
return get_data_int(this.tensorHandle)!!
}
override fun randIntegral(low: Int, high: Int, shape: IntArray, device: TorchDevice): TorchTensorInt =
wrap(randint_int(low, high, shape.toCValues(), shape.size, device.toInt())!!)
override operator fun Int.plus(other: TorchTensorInt): TorchTensorInt =
wrap(plus_int(this, other.tensorHandle)!!)
override fun TorchTensorInt.plus(value: Int): TorchTensorInt =
wrap(plus_int(value, this.tensorHandle)!!)
override fun TorchTensorInt.plusAssign(value: Int): Unit {
plus_int_assign(value, this.tensorHandle)
}
override operator fun Int.minus(other: TorchTensorInt): TorchTensorInt =
wrap(plus_int(-this, other.tensorHandle)!!)
override fun TorchTensorInt.minus(value: Int): TorchTensorInt =
wrap(plus_int(-value, this.tensorHandle)!!)
override fun TorchTensorInt.minusAssign(value: Int): Unit {
plus_int_assign(-value, this.tensorHandle)
}
override operator fun Int.times(other: TorchTensorInt): TorchTensorInt =
wrap(times_int(this, other.tensorHandle)!!)
override fun TorchTensorInt.times(value: Int): TorchTensorInt =
wrap(times_int(value, this.tensorHandle)!!)
override fun TorchTensorInt.timesAssign(value: Int): Unit {
times_int_assign(value, this.tensorHandle)
}
override fun full(value: Int, shape: IntArray, device: TorchDevice): TorchTensorInt =
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
} }
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R = public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
memScoped { TorchTensorRealAlgebra(this).block() } memScoped { TorchTensorRealAlgebra(this).block() }
public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
memScoped { TorchTensorFloatAlgebra(this).block() }
public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
memScoped { TorchTensorLongAlgebra(this).block() }
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
memScoped { TorchTensorIntAlgebra(this).block() }

View File

@ -3,14 +3,14 @@ package kscience.kmath.torch
import kotlin.test.Test import kotlin.test.Test
import kotlin.time.measureTime import kotlin.time.measureTime
internal fun benchmarkingDoubleMatrixMultiplication( internal fun benchmarkingMatMultDouble(
scale: Int, scale: Int,
numWarmUp: Int, numWarmUp: Int,
numIter: Int, numIter: Int,
device: TorchDevice = TorchDevice.TorchCPU device: TorchDevice = TorchDevice.TorchCPU
): Unit { ): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
println("Benchmarking $scale x $scale matrices over Double's: ") println("Benchmarking $scale x $scale matrices over Double's on $device: ")
setSeed(SEED) setSeed(SEED)
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device) val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device) val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
@ -20,15 +20,18 @@ internal fun benchmarkingDoubleMatrixMultiplication(
} }
} }
class BenchmarkMatrixMultiplicationDouble { internal class BenchmarkMatMultDouble {
@Test @Test
fun benchmarkMatrixMultiplication20() = benchmarkingDoubleMatrixMultiplication(20, 10, 100000) fun benchmarkMatMult20() =
benchmarkingMatMultDouble(20, 10, 100000)
@Test @Test
fun benchmarkMatrixMultiplication200() = benchmarkingDoubleMatrixMultiplication(200, 10, 10000) fun benchmarkMatMult200() =
benchmarkingMatMultDouble(200, 10, 10000)
@Test @Test
fun benchmarkMatrixMultiplication2000() = benchmarkingDoubleMatrixMultiplication(2000, 3, 20) fun benchmarkMatMult2000() =
benchmarkingMatMultDouble(2000, 3, 20)
} }

View File

@ -0,0 +1,36 @@
package kscience.kmath.torch
import kotlin.test.Test
import kotlin.time.measureTime
internal fun benchmarkingMatMultFloat(
scale: Int,
numWarmUp: Int,
numIter: Int,
device: TorchDevice = TorchDevice.TorchCPU
): Unit {
TorchTensorFloatAlgebra {
println("Benchmarking $scale x $scale matrices over Float's on $device: ")
setSeed(SEED)
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
repeat(numWarmUp) { lhs dotAssign rhs }
val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } }
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
}
}
internal class BenchmarkMatMultFloat {
@Test
fun benchmarkMatMult20() =
benchmarkingMatMultFloat(20, 10, 100000)
@Test
fun benchmarkMatMult200() =
benchmarkingMatMultFloat(200, 10, 10000)
@Test
fun benchmarkMatMult2000() =
benchmarkingMatMultFloat(2000, 3, 20)
}

View File

@ -44,7 +44,6 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
val error = (gradientAtX - expectedGradientAtX).abs().sum().value() val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
assertTrue(error < TOLERANCE) assertTrue(error < TOLERANCE)
} }
} }

View File

@ -49,4 +49,15 @@ class TestTorchTensor {
val detachedTensor = tensor.detachFromGraph() val detachedTensor = tensor.detachFromGraph()
assertTrue(!detachedTensor.requiresGrad) assertTrue(!detachedTensor.requiresGrad)
} }
@Test
fun testTypeMoving() = TorchTensorFloatAlgebra {
val tensorInt = copyFromArray(floatArrayOf(1f,2f,3f), intArrayOf(3)).copyToInt()
TorchTensorIntAlgebra {
val temporalTensor = copyFromArray(intArrayOf(4,5,6),intArrayOf(3))
tensorInt swap temporalTensor
assertTrue(temporalTensor.copyToArray() contentEquals intArrayOf(1,2,3))
}
assertTrue(tensorInt.copyToFloat().copyToArray() contentEquals floatArrayOf(4f,5f,6f))
}
} }