forked from kscience/kmath
Float, Long and Int support
This commit is contained in:
parent
80f28dbcd5
commit
ef570254e6
@ -20,22 +20,21 @@ extern "C"
|
||||
|
||||
TorchTensorHandle empty_tensor();
|
||||
|
||||
double *get_data_double(TorchTensorHandle tensor_handle);
|
||||
float *get_data_float(TorchTensorHandle tensor_handle);
|
||||
long *get_data_long(TorchTensorHandle tensor_handle);
|
||||
int *get_data_int(TorchTensorHandle tensor_handle);
|
||||
|
||||
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy);
|
||||
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy);
|
||||
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy);
|
||||
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy);
|
||||
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device);
|
||||
TorchTensorHandle copy_to_double(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle copy_to_float(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle copy_to_long(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle copy_to_int(TorchTensorHandle tensor_handle);
|
||||
void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle);
|
||||
|
||||
double get_item_double(TorchTensorHandle tensor_handle);
|
||||
float get_item_float(TorchTensorHandle tensor_handle);
|
||||
long get_item_long(TorchTensorHandle tensor_handle);
|
||||
int get_item_int(TorchTensorHandle tensor_handle);
|
||||
char *tensor_to_string(TorchTensorHandle tensor_handle);
|
||||
void dispose_char(char *ptr);
|
||||
void dispose_tensor(TorchTensorHandle tensor_handle);
|
||||
|
||||
int get_dim(TorchTensorHandle tensor_handle);
|
||||
int get_numel(TorchTensorHandle tensor_handle);
|
||||
@ -43,9 +42,15 @@ extern "C"
|
||||
int get_stride_at(TorchTensorHandle tensor_handle, int d);
|
||||
int get_device(TorchTensorHandle tensor_handle);
|
||||
|
||||
char *tensor_to_string(TorchTensorHandle tensor_handle);
|
||||
void dispose_char(char *ptr);
|
||||
void dispose_tensor(TorchTensorHandle tensor_handle);
|
||||
double *get_data_double(TorchTensorHandle tensor_handle);
|
||||
float *get_data_float(TorchTensorHandle tensor_handle);
|
||||
long *get_data_long(TorchTensorHandle tensor_handle);
|
||||
int *get_data_int(TorchTensorHandle tensor_handle);
|
||||
|
||||
double get_item_double(TorchTensorHandle tensor_handle);
|
||||
float get_item_float(TorchTensorHandle tensor_handle);
|
||||
long get_item_long(TorchTensorHandle tensor_handle);
|
||||
int get_item_int(TorchTensorHandle tensor_handle);
|
||||
|
||||
double get_double(TorchTensorHandle tensor_handle, int *index);
|
||||
float get_float(TorchTensorHandle tensor_handle, int *index);
|
||||
@ -61,15 +66,14 @@ extern "C"
|
||||
TorchTensorHandle randn_float(int *shape, int shape_size, int device);
|
||||
TorchTensorHandle rand_float(int *shape, int shape_size, int device);
|
||||
|
||||
TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle randint_int(int low, int high, int *shape, int shape_size, int device);
|
||||
|
||||
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle full_long(long value, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle full_int(int value, int *shape, int shape_size, int device);
|
||||
|
||||
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
|
||||
TorchTensorHandle times_double(double value, TorchTensorHandle other);
|
||||
TorchTensorHandle times_float(float value, TorchTensorHandle other);
|
||||
TorchTensorHandle times_long(long value, TorchTensorHandle other);
|
||||
@ -115,10 +119,9 @@ extern "C"
|
||||
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
|
||||
void sum_tensor_assign(TorchTensorHandle tensor_handle);
|
||||
|
||||
bool requires_grad(TorchTensorHandle tensor_handle);
|
||||
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
|
||||
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
|
||||
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
|
||||
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2);
|
||||
|
||||
@ -132,6 +135,11 @@ extern "C"
|
||||
TorchTensorHandle V_handle,
|
||||
bool eigenvectors);
|
||||
|
||||
bool requires_grad(TorchTensorHandle tensor_handle);
|
||||
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
|
||||
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@ -92,6 +92,12 @@ namespace ctorch
|
||||
return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor randint(Dtype low, Dtype high, std::vector<int64_t> shape, torch::Device device)
|
||||
{
|
||||
return torch::randint(low, high, shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor full(Dtype value, std::vector<int64_t> shape, torch::Device device)
|
||||
{
|
||||
|
@ -30,44 +30,6 @@ TorchTensorHandle empty_tensor()
|
||||
return new torch::Tensor;
|
||||
}
|
||||
|
||||
double *get_data_double(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<double>();
|
||||
}
|
||||
float *get_data_float(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<float>();
|
||||
}
|
||||
long *get_data_long(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<long>();
|
||||
}
|
||||
int *get_data_int(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<int>();
|
||||
}
|
||||
|
||||
int get_dim(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).dim();
|
||||
}
|
||||
int get_numel(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).numel();
|
||||
}
|
||||
int get_shape_at(TorchTensorHandle tensor_handle, int d)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).size(d);
|
||||
}
|
||||
int get_stride_at(TorchTensorHandle tensor_handle, int d)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).stride(d);
|
||||
}
|
||||
int get_device(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::device_to_int(ctorch::cast(tensor_handle));
|
||||
}
|
||||
|
||||
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy)
|
||||
{
|
||||
return new torch::Tensor(ctorch::from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||
@ -92,6 +54,26 @@ TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::int_to_device(device), false, true));
|
||||
}
|
||||
TorchTensorHandle copy_to_double(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<double>(), false, true));
|
||||
}
|
||||
TorchTensorHandle copy_to_float(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<float>(), false, true));
|
||||
}
|
||||
TorchTensorHandle copy_to_long(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<long>(), false, true));
|
||||
}
|
||||
TorchTensorHandle copy_to_int(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<int>(), false, true));
|
||||
}
|
||||
void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle)
|
||||
{
|
||||
std::swap(ctorch::cast(lhs_handle), ctorch::cast(rhs_handle));
|
||||
}
|
||||
|
||||
char *tensor_to_string(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
@ -111,6 +93,61 @@ void dispose_tensor(TorchTensorHandle tensor_handle)
|
||||
delete static_cast<torch::Tensor *>(tensor_handle);
|
||||
}
|
||||
|
||||
int get_dim(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).dim();
|
||||
}
|
||||
int get_numel(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).numel();
|
||||
}
|
||||
int get_shape_at(TorchTensorHandle tensor_handle, int d)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).size(d);
|
||||
}
|
||||
int get_stride_at(TorchTensorHandle tensor_handle, int d)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).stride(d);
|
||||
}
|
||||
int get_device(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::device_to_int(ctorch::cast(tensor_handle));
|
||||
}
|
||||
|
||||
double *get_data_double(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<double>();
|
||||
}
|
||||
float *get_data_float(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<float>();
|
||||
}
|
||||
long *get_data_long(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<long>();
|
||||
}
|
||||
int *get_data_int(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<int>();
|
||||
}
|
||||
|
||||
double get_item_double(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<double>();
|
||||
}
|
||||
float get_item_float(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<float>();
|
||||
}
|
||||
long get_item_long(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<long>();
|
||||
}
|
||||
int get_item_int(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<int>();
|
||||
}
|
||||
|
||||
double get_double(TorchTensorHandle tensor_handle, int *index)
|
||||
{
|
||||
return ctorch::get<double>(tensor_handle, index);
|
||||
@ -144,23 +181,6 @@ void set_int(TorchTensorHandle tensor_handle, int *index, int value)
|
||||
ctorch::set<int>(tensor_handle, index, value);
|
||||
}
|
||||
|
||||
double get_item_double(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<double>();
|
||||
}
|
||||
float get_item_float(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<float>();
|
||||
}
|
||||
long get_item_long(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<long>();
|
||||
}
|
||||
int get_item_int(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<int>();
|
||||
}
|
||||
|
||||
TorchTensorHandle randn_double(int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::randn<double>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
@ -178,6 +198,15 @@ TorchTensorHandle rand_float(int *shape, int shape_size, int device)
|
||||
return new torch::Tensor(ctorch::rand<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::randint<long>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
TorchTensorHandle randint_int(int low, int high, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::randint<int>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::full<double>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
@ -195,19 +224,6 @@ TorchTensorHandle full_int(int value, int *shape, int shape_size, int device)
|
||||
return new torch::Tensor(ctorch::full<int>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
return new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
|
||||
}
|
||||
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||
}
|
||||
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
TorchTensorHandle plus_double(double value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(other) + value);
|
||||
@ -356,21 +372,17 @@ void sum_tensor_assign(TorchTensorHandle tensor_handle)
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum();
|
||||
}
|
||||
|
||||
bool requires_grad(TorchTensorHandle tensor_handle)
|
||||
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).requires_grad();
|
||||
return new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
|
||||
}
|
||||
void requires_grad_(TorchTensorHandle tensor_handle, bool status)
|
||||
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(tensor_handle).requires_grad_(status);
|
||||
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||
}
|
||||
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
|
||||
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).detach());
|
||||
}
|
||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
|
||||
{
|
||||
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
|
||||
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2)
|
||||
@ -390,11 +402,28 @@ void svd_tensor(TorchTensorHandle tensor_handle,
|
||||
}
|
||||
|
||||
void symeig_tensor(TorchTensorHandle tensor_handle,
|
||||
TorchTensorHandle S_handle,
|
||||
TorchTensorHandle V_handle,
|
||||
bool eigenvectors)
|
||||
TorchTensorHandle S_handle,
|
||||
TorchTensorHandle V_handle,
|
||||
bool eigenvectors)
|
||||
{
|
||||
auto [S, V] = torch::symeig(ctorch::cast(tensor_handle), eigenvectors);
|
||||
ctorch::cast(S_handle) = S;
|
||||
ctorch::cast(V_handle) = V;
|
||||
}
|
||||
|
||||
bool requires_grad(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).requires_grad();
|
||||
}
|
||||
void requires_grad_(TorchTensorHandle tensor_handle, bool status)
|
||||
{
|
||||
ctorch::cast(tensor_handle).requires_grad_(status);
|
||||
}
|
||||
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).detach());
|
||||
}
|
||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
|
||||
{
|
||||
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
|
||||
}
|
||||
|
@ -0,0 +1,20 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
class BenchmarkMatMultFloatGPU {
|
||||
@Test
|
||||
fun benchmarkMatMult20() =
|
||||
benchmarkingMatMultFloat(20, 10, 100000,
|
||||
device = TorchDevice.TorchCUDA(0))
|
||||
|
||||
@Test
|
||||
fun benchmarkMatMult200() =
|
||||
benchmarkingMatMultFloat(200, 10, 10000,
|
||||
device = TorchDevice.TorchCUDA(0))
|
||||
|
||||
@Test
|
||||
fun benchmarkMatMult2000() =
|
||||
benchmarkingMatMultFloat(2000, 10, 1000,
|
||||
device = TorchDevice.TorchCUDA(0))
|
||||
}
|
@ -29,5 +29,4 @@ class TestTorchTensorAlgebraGPU {
|
||||
fun testBatchedSymEig() =
|
||||
testingBatchedSymEig(device = TorchDevice.TorchCUDA(0))
|
||||
|
||||
|
||||
}
|
||||
|
@ -2,7 +2,11 @@ package kscience.kmath.torch
|
||||
|
||||
|
||||
public sealed class TorchDevice {
|
||||
public object TorchCPU: TorchDevice()
|
||||
public object TorchCPU: TorchDevice() {
|
||||
override fun toString(): String {
|
||||
return "TorchCPU"
|
||||
}
|
||||
}
|
||||
public data class TorchCUDA(val index: Int): TorchDevice()
|
||||
public fun toInt(): Int {
|
||||
when(this) {
|
||||
|
@ -53,6 +53,23 @@ public sealed class TorchTensor<T> constructor(
|
||||
get() = requires_grad(tensorHandle)
|
||||
set(value) = requires_grad_(tensorHandle, value)
|
||||
|
||||
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = copy_to_double(this.tensorHandle)!!
|
||||
)
|
||||
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
|
||||
scope = scope,
|
||||
tensorHandle = copy_to_float(this.tensorHandle)!!
|
||||
)
|
||||
public fun copyToLong(): TorchTensorLong = TorchTensorLong(
|
||||
scope = scope,
|
||||
tensorHandle = copy_to_long(this.tensorHandle)!!
|
||||
)
|
||||
public fun copyToInt(): TorchTensorInt = TorchTensorInt(
|
||||
scope = scope,
|
||||
tensorHandle = copy_to_int(this.tensorHandle)!!
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
public class TorchTensorReal internal constructor(
|
||||
@ -66,6 +83,39 @@ public class TorchTensorReal internal constructor(
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorFloat internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensor<Float>(scope, tensorHandle) {
|
||||
override fun item(): Float = get_item_float(tensorHandle)
|
||||
override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues())
|
||||
override fun set(index: IntArray, value: Float) {
|
||||
set_float(tensorHandle, index.toCValues(), value)
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorLong internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensor<Long>(scope, tensorHandle) {
|
||||
override fun item(): Long = get_item_long(tensorHandle)
|
||||
override fun get(index: IntArray): Long = get_long(tensorHandle, index.toCValues())
|
||||
override fun set(index: IntArray, value: Long) {
|
||||
set_long(tensorHandle, index.toCValues(), value)
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorInt internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensor<Int>(scope, tensorHandle) {
|
||||
override fun item(): Int = get_item_int(tensorHandle)
|
||||
override fun get(index: IntArray): Int = get_int(tensorHandle, index.toCValues())
|
||||
override fun set(index: IntArray, value: Int) {
|
||||
set_int(tensorHandle, index.toCValues(), value)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||
val res = IntArray(nDim)
|
||||
|
@ -42,13 +42,6 @@ public sealed class TorchTensorAlgebra<
|
||||
times_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
public operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType =
|
||||
wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||
|
||||
public operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
|
||||
div_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
public infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType =
|
||||
wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
|
||||
|
||||
@ -94,45 +87,37 @@ public sealed class TorchTensorAlgebra<
|
||||
sum_tensor_assign(tensorHandle)
|
||||
}
|
||||
|
||||
public fun diagEmbed(diagonalEntries: TorchTensorType,
|
||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1): TorchTensorType =
|
||||
public fun diagEmbed(
|
||||
diagonalEntries: TorchTensorType,
|
||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
||||
): TorchTensorType =
|
||||
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
|
||||
|
||||
public fun TorchTensorType.svd(): Triple<TorchTensorType,TorchTensorType,TorchTensorType> {
|
||||
val U = empty_tensor()!!
|
||||
val V = empty_tensor()!!
|
||||
val S = empty_tensor()!!
|
||||
svd_tensor(this.tensorHandle, U, S, V)
|
||||
return Triple(wrap(U), wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType> {
|
||||
val V = empty_tensor()!!
|
||||
val S = empty_tensor()!!
|
||||
symeig_tensor(this.tensorHandle, S, V, eigenvectors)
|
||||
return Pair(wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
||||
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||
|
||||
public fun TorchTensorType.copy(): TorchTensorType =
|
||||
wrap(tensorHandle = copy_tensor(this.tensorHandle)!!)
|
||||
wrap(copy_tensor(this.tensorHandle)!!)
|
||||
|
||||
public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType =
|
||||
wrap(tensorHandle = copy_to_device(this.tensorHandle, device.toInt())!!)
|
||||
|
||||
public fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
|
||||
wrap(copy_to_device(this.tensorHandle, device.toInt())!!)
|
||||
|
||||
public infix fun TorchTensorType.swap(otherTensor: TorchTensorType): Unit {
|
||||
swap_tensors(this.tensorHandle, otherTensor.tensorHandle)
|
||||
}
|
||||
}
|
||||
|
||||
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
||||
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
|
||||
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope) {
|
||||
|
||||
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType
|
||||
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType
|
||||
|
||||
public operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType =
|
||||
wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||
|
||||
public operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
|
||||
div_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
public fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
|
||||
public fun TorchTensorType.expAssign(): Unit {
|
||||
exp_tensor_assign(tensorHandle)
|
||||
@ -143,6 +128,35 @@ public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
||||
log_tensor_assign(tensorHandle)
|
||||
}
|
||||
|
||||
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
|
||||
val U = empty_tensor()!!
|
||||
val V = empty_tensor()!!
|
||||
val S = empty_tensor()!!
|
||||
svd_tensor(this.tensorHandle, U, S, V)
|
||||
return Triple(wrap(U), wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType> {
|
||||
val V = empty_tensor()!!
|
||||
val S = empty_tensor()!!
|
||||
symeig_tensor(this.tensorHandle, S, V, eigenvectors)
|
||||
return Pair(wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
||||
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||
|
||||
public fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
|
||||
}
|
||||
|
||||
public sealed class TorchTensorRingAlgebra<T, TVar : CPrimitiveVar,
|
||||
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
|
||||
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope) {
|
||||
public abstract fun randIntegral(
|
||||
low: T, high: T, shape: IntArray,
|
||||
device: TorchDevice = TorchDevice.TorchCPU
|
||||
): TorchTensorType
|
||||
}
|
||||
|
||||
public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||
@ -153,33 +167,11 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||
override fun TorchTensorReal.copyToArray(): DoubleArray =
|
||||
this.elements().map { it.second }.toList().toDoubleArray()
|
||||
|
||||
override fun copyFromArray(
|
||||
array: DoubleArray,
|
||||
shape: IntArray,
|
||||
device: TorchDevice
|
||||
): TorchTensorReal =
|
||||
TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = from_blob_double(
|
||||
array.toCValues(),
|
||||
shape.toCValues(),
|
||||
shape.size,
|
||||
device.toInt(),
|
||||
true
|
||||
)!!
|
||||
)
|
||||
override fun copyFromArray(array: DoubleArray, shape: IntArray, device: TorchDevice): TorchTensorReal =
|
||||
wrap(from_blob_double(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||
|
||||
override fun fromBlob(arrayBlob: CPointer<DoubleVar>, shape: IntArray): TorchTensorReal =
|
||||
TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = from_blob_double(
|
||||
arrayBlob,
|
||||
shape.toCValues(),
|
||||
shape.size,
|
||||
TorchDevice.TorchCPU.toInt(),
|
||||
false
|
||||
)!!
|
||||
)
|
||||
wrap(from_blob_double(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)
|
||||
|
||||
override fun TorchTensorReal.getData(): CPointer<DoubleVar> {
|
||||
require(this.device is TorchDevice.TorchCPU) {
|
||||
@ -188,65 +180,231 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||
return get_data_double(this.tensorHandle)!!
|
||||
}
|
||||
|
||||
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = randn_double(shape.toCValues(), shape.size, device.toInt())!!
|
||||
)
|
||||
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorReal =
|
||||
wrap(randn_double(shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override fun randUniform(shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
|
||||
)
|
||||
override fun randUniform(shape: IntArray, device: TorchDevice): TorchTensorReal =
|
||||
wrap(rand_double(shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = plus_double(this, other.tensorHandle)!!
|
||||
)
|
||||
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal =
|
||||
wrap(plus_double(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorReal.plus(value: Double): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = plus_double(value, this.tensorHandle)!!
|
||||
)
|
||||
override fun TorchTensorReal.plus(value: Double): TorchTensorReal =
|
||||
wrap(plus_double(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorReal.plusAssign(value: Double): Unit {
|
||||
plus_double_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = plus_double(-this, other.tensorHandle)!!
|
||||
)
|
||||
override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal =
|
||||
wrap(plus_double(-this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorReal.minus(value: Double): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = plus_double(-value, this.tensorHandle)!!
|
||||
)
|
||||
override fun TorchTensorReal.minus(value: Double): TorchTensorReal =
|
||||
wrap(plus_double(-value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorReal.minusAssign(value: Double): Unit {
|
||||
plus_double_assign(-value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Double.times(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = times_double(this, other.tensorHandle)!!
|
||||
)
|
||||
override operator fun Double.times(other: TorchTensorReal): TorchTensorReal =
|
||||
wrap(times_double(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorReal.times(value: Double): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = times_double(value, this.tensorHandle)!!
|
||||
)
|
||||
override fun TorchTensorReal.times(value: Double): TorchTensorReal =
|
||||
wrap(times_double(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorReal.timesAssign(value: Double): Unit {
|
||||
times_double_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal =
|
||||
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
}
|
||||
|
||||
override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = full_double(value, shape.toCValues(), shape.size, device.toInt())!!
|
||||
)
|
||||
public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||
TorchTensorFieldAlgebra<Float, FloatVar, FloatArray, TorchTensorFloat>(scope) {
|
||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat =
|
||||
TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorFloat.copyToArray(): FloatArray =
|
||||
this.elements().map { it.second }.toList().toFloatArray()
|
||||
|
||||
override fun copyFromArray(array: FloatArray, shape: IntArray, device: TorchDevice): TorchTensorFloat =
|
||||
wrap(from_blob_float(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||
|
||||
override fun fromBlob(arrayBlob: CPointer<FloatVar>, shape: IntArray): TorchTensorFloat =
|
||||
wrap(from_blob_float(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)
|
||||
|
||||
override fun TorchTensorFloat.getData(): CPointer<FloatVar> {
|
||||
require(this.device is TorchDevice.TorchCPU) {
|
||||
"This tensor is not on available on CPU"
|
||||
}
|
||||
return get_data_float(this.tensorHandle)!!
|
||||
}
|
||||
|
||||
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorFloat =
|
||||
wrap(randn_float(shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override fun randUniform(shape: IntArray, device: TorchDevice): TorchTensorFloat =
|
||||
wrap(rand_float(shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(plus_float(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.plus(value: Float): TorchTensorFloat =
|
||||
wrap(plus_float(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.plusAssign(value: Float): Unit {
|
||||
plus_float_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Float.minus(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(plus_float(-this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.minus(value: Float): TorchTensorFloat =
|
||||
wrap(plus_float(-value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.minusAssign(value: Float): Unit {
|
||||
plus_float_assign(-value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Float.times(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(times_float(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.times(value: Float): TorchTensorFloat =
|
||||
wrap(times_float(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.timesAssign(value: Float): Unit {
|
||||
times_float_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override fun full(value: Float, shape: IntArray, device: TorchDevice): TorchTensorFloat =
|
||||
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
}
|
||||
|
||||
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||
TorchTensorRingAlgebra<Long, LongVar, LongArray, TorchTensorLong>(scope) {
|
||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorLong =
|
||||
TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorLong.copyToArray(): LongArray =
|
||||
this.elements().map { it.second }.toList().toLongArray()
|
||||
|
||||
override fun copyFromArray(array: LongArray, shape: IntArray, device: TorchDevice): TorchTensorLong =
|
||||
wrap(from_blob_long(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||
|
||||
override fun fromBlob(arrayBlob: CPointer<LongVar>, shape: IntArray): TorchTensorLong =
|
||||
wrap(from_blob_long(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)
|
||||
|
||||
override fun TorchTensorLong.getData(): CPointer<LongVar> {
|
||||
require(this.device is TorchDevice.TorchCPU) {
|
||||
"This tensor is not on available on CPU"
|
||||
}
|
||||
return get_data_long(this.tensorHandle)!!
|
||||
}
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: TorchDevice): TorchTensorLong =
|
||||
wrap(randint_long(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override operator fun Long.plus(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(plus_long(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.plus(value: Long): TorchTensorLong =
|
||||
wrap(plus_long(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.plusAssign(value: Long): Unit {
|
||||
plus_long_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Long.minus(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(plus_long(-this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.minus(value: Long): TorchTensorLong =
|
||||
wrap(plus_long(-value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.minusAssign(value: Long): Unit {
|
||||
plus_long_assign(-value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Long.times(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(times_long(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.times(value: Long): TorchTensorLong =
|
||||
wrap(times_long(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.timesAssign(value: Long): Unit {
|
||||
times_long_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override fun full(value: Long, shape: IntArray, device: TorchDevice): TorchTensorLong =
|
||||
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
}
|
||||
|
||||
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||
TorchTensorRingAlgebra<Int, IntVar, IntArray, TorchTensorInt>(scope) {
|
||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorInt =
|
||||
TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorInt.copyToArray(): IntArray =
|
||||
this.elements().map { it.second }.toList().toIntArray()
|
||||
|
||||
override fun copyFromArray(array: IntArray, shape: IntArray, device: TorchDevice): TorchTensorInt =
|
||||
wrap(from_blob_int(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||
|
||||
override fun fromBlob(arrayBlob: CPointer<IntVar>, shape: IntArray): TorchTensorInt =
|
||||
wrap(from_blob_int(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)
|
||||
|
||||
override fun TorchTensorInt.getData(): CPointer<IntVar> {
|
||||
require(this.device is TorchDevice.TorchCPU) {
|
||||
"This tensor is not on available on CPU"
|
||||
}
|
||||
return get_data_int(this.tensorHandle)!!
|
||||
}
|
||||
|
||||
override fun randIntegral(low: Int, high: Int, shape: IntArray, device: TorchDevice): TorchTensorInt =
|
||||
wrap(randint_int(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override operator fun Int.plus(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(plus_int(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.plus(value: Int): TorchTensorInt =
|
||||
wrap(plus_int(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.plusAssign(value: Int): Unit {
|
||||
plus_int_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Int.minus(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(plus_int(-this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.minus(value: Int): TorchTensorInt =
|
||||
wrap(plus_int(-value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.minusAssign(value: Int): Unit {
|
||||
plus_int_assign(-value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Int.times(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(times_int(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.times(value: Int): TorchTensorInt =
|
||||
wrap(times_int(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.timesAssign(value: Int): Unit {
|
||||
times_int_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override fun full(value: Int, shape: IntArray, device: TorchDevice): TorchTensorInt =
|
||||
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
}
|
||||
|
||||
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||
memScoped { TorchTensorRealAlgebra(this).block() }
|
||||
memScoped { TorchTensorRealAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
|
||||
memScoped { TorchTensorFloatAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
|
||||
memScoped { TorchTensorLongAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
||||
memScoped { TorchTensorIntAlgebra(this).block() }
|
@ -3,14 +3,14 @@ package kscience.kmath.torch
|
||||
import kotlin.test.Test
|
||||
import kotlin.time.measureTime
|
||||
|
||||
internal fun benchmarkingDoubleMatrixMultiplication(
|
||||
internal fun benchmarkingMatMultDouble(
|
||||
scale: Int,
|
||||
numWarmUp: Int,
|
||||
numIter: Int,
|
||||
device: TorchDevice = TorchDevice.TorchCPU
|
||||
): Unit {
|
||||
TorchTensorRealAlgebra {
|
||||
println("Benchmarking $scale x $scale matrices over Double's: ")
|
||||
println("Benchmarking $scale x $scale matrices over Double's on $device: ")
|
||||
setSeed(SEED)
|
||||
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||
@ -20,15 +20,18 @@ internal fun benchmarkingDoubleMatrixMultiplication(
|
||||
}
|
||||
}
|
||||
|
||||
class BenchmarkMatrixMultiplicationDouble {
|
||||
internal class BenchmarkMatMultDouble {
|
||||
|
||||
@Test
|
||||
fun benchmarkMatrixMultiplication20() = benchmarkingDoubleMatrixMultiplication(20, 10, 100000)
|
||||
fun benchmarkMatMult20() =
|
||||
benchmarkingMatMultDouble(20, 10, 100000)
|
||||
|
||||
@Test
|
||||
fun benchmarkMatrixMultiplication200() = benchmarkingDoubleMatrixMultiplication(200, 10, 10000)
|
||||
fun benchmarkMatMult200() =
|
||||
benchmarkingMatMultDouble(200, 10, 10000)
|
||||
|
||||
@Test
|
||||
fun benchmarkMatrixMultiplication2000() = benchmarkingDoubleMatrixMultiplication(2000, 3, 20)
|
||||
fun benchmarkMatMult2000() =
|
||||
benchmarkingMatMultDouble(2000, 3, 20)
|
||||
|
||||
}
|
@ -0,0 +1,36 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
import kotlin.time.measureTime
|
||||
|
||||
internal fun benchmarkingMatMultFloat(
|
||||
scale: Int,
|
||||
numWarmUp: Int,
|
||||
numIter: Int,
|
||||
device: TorchDevice = TorchDevice.TorchCPU
|
||||
): Unit {
|
||||
TorchTensorFloatAlgebra {
|
||||
println("Benchmarking $scale x $scale matrices over Float's on $device: ")
|
||||
setSeed(SEED)
|
||||
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||
repeat(numWarmUp) { lhs dotAssign rhs }
|
||||
val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } }
|
||||
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
|
||||
}
|
||||
}
|
||||
|
||||
internal class BenchmarkMatMultFloat {
|
||||
|
||||
@Test
|
||||
fun benchmarkMatMult20() =
|
||||
benchmarkingMatMultFloat(20, 10, 100000)
|
||||
|
||||
@Test
|
||||
fun benchmarkMatMult200() =
|
||||
benchmarkingMatMultFloat(200, 10, 10000)
|
||||
|
||||
@Test
|
||||
fun benchmarkMatMult2000() =
|
||||
benchmarkingMatMultFloat(2000, 3, 20)
|
||||
}
|
@ -44,7 +44,6 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
|
||||
|
||||
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
|
||||
assertTrue(error < TOLERANCE)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -49,4 +49,15 @@ class TestTorchTensor {
|
||||
val detachedTensor = tensor.detachFromGraph()
|
||||
assertTrue(!detachedTensor.requiresGrad)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTypeMoving() = TorchTensorFloatAlgebra {
|
||||
val tensorInt = copyFromArray(floatArrayOf(1f,2f,3f), intArrayOf(3)).copyToInt()
|
||||
TorchTensorIntAlgebra {
|
||||
val temporalTensor = copyFromArray(intArrayOf(4,5,6),intArrayOf(3))
|
||||
tensorInt swap temporalTensor
|
||||
assertTrue(temporalTensor.copyToArray() contentEquals intArrayOf(1,2,3))
|
||||
}
|
||||
assertTrue(tensorInt.copyToFloat().copyToArray() contentEquals floatArrayOf(4f,5f,6f))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user