Autograd support drafting

This commit is contained in:
rgrit91 2021-01-10 16:24:57 +00:00
parent 7894799e8e
commit 7105331149
10 changed files with 316 additions and 45 deletions

View File

@ -76,7 +76,7 @@ pluginManagement {
## Usage
Tensors implement the buffer protocol over `MutableNDStructure`. They can only be instantiated through provided factory methods and require scoping:
Tensors are implemented over the `MutableNDStructure`. They can only be instantiated through provided factory methods and require scoping:
```kotlin
TorchTensorRealAlgebra {
@ -94,4 +94,24 @@ TorchTensorRealAlgebra {
println(gpuRealTensor)
}
```
Enjoy a high performance automatic differentiation engine:
```kotlin
TorchTensorRealAlgebra {
val dim = 10
val device = TorchDevice.TorchCPU //or TorchDevice.TorchCUDA(0)
val x = randNormal(shape = intArrayOf(dim), device = device)
// x is the variable
x.requiresGrad = true
val X = randNormal(shape = intArrayOf(dim,dim), device = device)
val Q = X + X.transpose(0,1)
val mu = randNormal(shape = intArrayOf(dim), device = device)
val c = randNormal(shape = IntArray(0), device = device)
// expression to differentiate w.r.t. x
val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c
// value of the gradient at x
val gradf = f grad x
}
```

View File

@ -29,7 +29,7 @@ extern "C"
float get_item_float(TorchTensorHandle tensor_handle);
long get_item_long(TorchTensorHandle tensor_handle);
int get_item_int(TorchTensorHandle tensor_handle);
int get_dim(TorchTensorHandle tensor_handle);
int get_numel(TorchTensorHandle tensor_handle);
int get_shape_at(TorchTensorHandle tensor_handle, int d);
@ -40,25 +40,52 @@ extern "C"
void dispose_char(char *ptr);
void dispose_tensor(TorchTensorHandle tensor_handle);
double get_double(TorchTensorHandle tensor_handle, int* index);
float get_float(TorchTensorHandle tensor_handle, int* index);
long get_long(TorchTensorHandle tensor_handle, int* index);
int get_int(TorchTensorHandle tensor_handle, int* index);
void set_double(TorchTensorHandle tensor_handle, int* index, double value);
void set_float(TorchTensorHandle tensor_handle, int* index, float value);
void set_long(TorchTensorHandle tensor_handle, int* index, long value);
void set_int(TorchTensorHandle tensor_handle, int* index, int value);
double get_double(TorchTensorHandle tensor_handle, int *index);
float get_float(TorchTensorHandle tensor_handle, int *index);
long get_long(TorchTensorHandle tensor_handle, int *index);
int get_int(TorchTensorHandle tensor_handle, int *index);
void set_double(TorchTensorHandle tensor_handle, int *index, double value);
void set_float(TorchTensorHandle tensor_handle, int *index, float value);
void set_long(TorchTensorHandle tensor_handle, int *index, long value);
void set_int(TorchTensorHandle tensor_handle, int *index, int value);
TorchTensorHandle randn_double(int *shape, int shape_size, int device);
TorchTensorHandle rand_double(int *shape, int shape_size, int device);
TorchTensorHandle randn_float(int *shape, int shape_size, int device);
TorchTensorHandle rand_float(int *shape, int shape_size, int device);
TorchTensorHandle randn_double(int* shape, int shape_size, int device);
TorchTensorHandle rand_double(int* shape, int shape_size, int device);
TorchTensorHandle randn_float(int* shape, int shape_size, int device);
TorchTensorHandle rand_float(int* shape, int shape_size, int device);
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device);
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device);
TorchTensorHandle full_long(long value, int *shape, int shape_size, int device);
TorchTensorHandle full_int(int value, int *shape, int shape_size, int device);
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs);
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle times_double(double value, TorchTensorHandle other);
TorchTensorHandle times_float(float value, TorchTensorHandle other);
TorchTensorHandle times_long(long value, TorchTensorHandle other);
TorchTensorHandle times_int(int value, TorchTensorHandle other);
void times_assign_double(double value, TorchTensorHandle other);
void times_assign_float(float value, TorchTensorHandle other);
void times_assign_long(long value, TorchTensorHandle other);
void times_assign_int(int value, TorchTensorHandle other);
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j);
bool requires_grad(TorchTensorHandle tensor_handle);
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
void detach_from_graph(TorchTensorHandle tensor_handle);
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
#ifdef __cplusplus
}
#endif

View File

@ -92,4 +92,10 @@ namespace ctorch
return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
}
template <typename Dtype>
inline torch::Tensor full(Dtype value, std::vector<int64_t> shape, torch::Device device)
{
return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
}
} // namespace ctorch

View File

@ -156,12 +156,111 @@ TorchTensorHandle rand_float(int *shape, int shape_size, int device)
return new torch::Tensor(ctorch::rand<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
}
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device)
{
return new torch::Tensor(ctorch::full<double>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
}
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device)
{
return new torch::Tensor(ctorch::full<float>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
}
TorchTensorHandle full_long(long value, int *shape, int shape_size, int device)
{
return new torch::Tensor(ctorch::full<long>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
}
TorchTensorHandle full_int(int value, int *shape, int shape_size, int device)
{
return new torch::Tensor(ctorch::full<int>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
}
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
return new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
}
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
}
}
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
}
TorchTensorHandle times_double(double value, TorchTensorHandle other)
{
return new torch::Tensor(value * ctorch::cast(other));
}
TorchTensorHandle times_float(float value, TorchTensorHandle other)
{
return new torch::Tensor(value * ctorch::cast(other));
}
TorchTensorHandle times_long(long value, TorchTensorHandle other)
{
return new torch::Tensor(value * ctorch::cast(other));
}
TorchTensorHandle times_int(int value, TorchTensorHandle other)
{
return new torch::Tensor(value * ctorch::cast(other));
}
void times_assign_double(double value, TorchTensorHandle other)
{
ctorch::cast(other) *= value;
}
void times_assign_float(float value, TorchTensorHandle other)
{
ctorch::cast(other) *= value;
}
void times_assign_long(long value, TorchTensorHandle other)
{
ctorch::cast(other) *= value;
}
void times_assign_int(int value, TorchTensorHandle other)
{
ctorch::cast(other) *= value;
}
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
return new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs));
}
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
ctorch::cast(lhs) += ctorch::cast(rhs);
}
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
return new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs));
}
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
ctorch::cast(lhs) -= ctorch::cast(rhs);
}
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).abs());
}
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).sum());
}
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
{
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i,j));
}
bool requires_grad(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).requires_grad();
}
void requires_grad_(TorchTensorHandle tensor_handle, bool status)
{
ctorch::cast(tensor_handle).requires_grad_(status);
}
void detach_from_graph(TorchTensorHandle tensor_handle)
{
ctorch::cast(tensor_handle).detach();
}
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
{
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
}

View File

@ -6,8 +6,17 @@ import kotlin.test.*
class TestTorchTensorAlgebraGPU {
@Test
fun testScalarProduct() = testingScalarProduct(device = TorchDevice.TorchCUDA(0))
fun testScalarProduct() =
testingScalarProduct(device = TorchDevice.TorchCUDA(0))
@Test
fun testMatrixMultiplication() = testingMatrixMultiplication(device = TorchDevice.TorchCUDA(0))
fun testMatrixMultiplication() =
testingMatrixMultiplication(device = TorchDevice.TorchCUDA(0))
@Test
fun testLinearStructure() =
testingLinearStructure(device = TorchDevice.TorchCUDA(0))
@Test
fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0))
}

View File

@ -13,21 +13,4 @@ internal class TestUtilsGPU {
@Test
fun testSetSeed() = testingSetSeed(TorchDevice.TorchCUDA(0))
@Test
fun testReadmeFactory() = TorchTensorRealAlgebra {
val realTensor: TorchTensorReal = copyFromArray(
array = (1..10).map { it + 50.0 }.toList().toDoubleArray(),
shape = intArrayOf(2,5)
)
println(realTensor)
val gpuRealTensor: TorchTensorReal = copyFromArray(
array = (1..8).map { it * 2.5 }.toList().toDoubleArray(),
shape = intArrayOf(2, 2, 2),
device = TorchDevice.TorchCUDA(0)
)
println(gpuRealTensor)
}
}

View File

@ -62,6 +62,13 @@ public sealed class TorchTensor<T> constructor(
outTensorHandle = copy_to_device(tensorHandle, device.toInt())!!
)
public var requiresGrad: Boolean
get() = requires_grad(tensorHandle)
set(value) = requires_grad_(tensorHandle, value)
public fun detachFromGraph(): Unit {
detach_from_graph(tensorHandle)
}
}
public class TorchTensorReal internal constructor(

View File

@ -14,6 +14,11 @@ public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
): TorchTensor<T>
public abstract fun TorchTensor<T>.copyToArray(): PrimitiveArrayType
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>
public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
public abstract operator fun TorchTensor<T>.times(value: T): TorchTensor<T>
public abstract operator fun TorchTensor<T>.timesAssign(value: T): Unit
public infix fun TorchTensor<T>.dot(other: TorchTensor<T>): TorchTensor<T> =
wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
@ -21,6 +26,34 @@ public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
public infix fun TorchTensor<T>.dotAssign(other: TorchTensor<T>): Unit {
matmul_assign(this.tensorHandle, other.tensorHandle)
}
public infix fun TorchTensor<T>.dotRightAssign(other: TorchTensor<T>): Unit {
matmul_right_assign(this.tensorHandle, other.tensorHandle)
}
public operator fun TorchTensor<T>.plus(other: TorchTensor<T>): TorchTensor<T> =
wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensor<T>.plusAssign(other: TorchTensor<T>): Unit {
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
}
public operator fun TorchTensor<T>.minus(other: TorchTensor<T>): TorchTensor<T> =
wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensor<T>.minusAssign(other: TorchTensor<T>): Unit {
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
}
public fun TorchTensor<T>.abs(): TorchTensor<T> = wrap(abs_tensor(tensorHandle)!!)
public fun TorchTensor<T>.sum(): TorchTensor<T> = wrap(sum_tensor(tensorHandle)!!)
public fun TorchTensor<T>.transpose(i: Int, j: Int): TorchTensor<T> =
wrap(transpose_tensor(tensorHandle, i , j)!!)
public infix fun TorchTensor<T>.grad(variable: TorchTensor<T>): TorchTensor<T> =
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
}
public sealed class TorchTensorFieldAlgebra<T, PrimitiveArrayType>(scope: DeferScope) :
@ -60,6 +93,27 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
scope = scope,
tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
)
override operator fun Double.times(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = times_double(this, other.tensorHandle)!!
)
override fun TorchTensor<Double>.times(value: Double): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = times_double(value, this.tensorHandle)!!
)
override fun TorchTensor<Double>.timesAssign(value: Double): Unit {
times_assign_double(value, this.tensorHandle)
}
override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = full_double(value, shape.toCValues(), shape.size, device.toInt())!!
)
}
public fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =

View File

@ -19,4 +19,14 @@ class TestTorchTensor {
@Test
fun testCopyFromArray() = testingCopyFromArray()
@Test
fun testRequiresGrad() = TorchTensorRealAlgebra {
val tensor = randNormal(intArrayOf(3))
assertTrue(!tensor.requiresGrad)
tensor.requiresGrad = true
assertTrue(tensor.requiresGrad)
tensor.requiresGrad = false
assertTrue(!tensor.requiresGrad)
}
}

View File

@ -4,12 +4,13 @@ import kscience.kmath.linear.RealMatrixContext
import kscience.kmath.operations.invoke
import kscience.kmath.structures.Matrix
import kotlin.math.abs
import kotlin.math.exp
import kotlin.test.*
internal fun testingScalarProduct(device: TorchDevice = TorchDevice.TorchCPU): Unit {
TorchTensorRealAlgebra {
val lhs = randUniform(shape = intArrayOf(10), device = device)
val rhs = randUniform(shape = intArrayOf(10), device = device)
val lhs = randUniform(shape = intArrayOf(3), device = device)
val rhs = randUniform(shape = intArrayOf(3), device = device)
val product = lhs dot rhs
var expected = 0.0
lhs.elements().forEach {
@ -23,27 +24,76 @@ internal fun testingMatrixMultiplication(device: TorchDevice = TorchDevice.Torch
TorchTensorRealAlgebra {
setSeed(SEED)
val lhsTensor = randNormal(shape = intArrayOf(20, 20), device = device)
val rhsTensor = randNormal(shape = intArrayOf(20, 20), device = device)
val lhsTensor = randNormal(shape = intArrayOf(3, 3), device = device)
val rhsTensor = randNormal(shape = intArrayOf(3, 3), device = device)
val product = lhsTensor dot rhsTensor
val expected: Matrix<Double> = RealMatrixContext {
val lhs = produce(20, 20) { i, j -> lhsTensor[intArrayOf(i, j)] }
val rhs = produce(20, 20) { i, j -> rhsTensor[intArrayOf(i, j)] }
val lhs = produce(3, 3) { i, j -> lhsTensor[intArrayOf(i, j)] }
val rhs = produce(3, 3) { i, j -> rhsTensor[intArrayOf(i, j)] }
lhs dot rhs
}
lhsTensor dotAssign rhsTensor
val lhsTensorCopy = lhsTensor.copy()
val rhsTensorCopy = rhsTensor.copy()
lhsTensorCopy dotAssign rhsTensor
lhsTensor dotRightAssign rhsTensorCopy
var error: Double = 0.0
product.elements().forEach {
error += abs(expected[it.first] - it.second) +
abs(expected[it.first] - lhsTensor[it.first])
abs(expected[it.first] - lhsTensorCopy[it.first]) +
abs(expected[it.first] - rhsTensorCopy[it.first])
}
assertTrue(error < TOLERANCE)
}
}
internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU): Unit {
TorchTensorRealAlgebra {
val shape = intArrayOf(3)
val tensorA = full(value = -4.5, shape = shape, device = device)
val tensorB = full(value = 10.9, shape = shape, device = device)
val tensorC = full(value = 789.3, shape = shape, device = device)
val result = 15.8 * tensorA - 1.5 * tensorB + 0.02 * tensorC
val expected = copyFromArray(
array = (1..3).map { 15.8 * (-4.5) - 1.5 * 10.9 + 0.02 * 789.3 }.toDoubleArray(),
shape = shape,
device = device
)
val assignResult = full(value = 0.0, shape = shape, device = device)
tensorA *= 15.8
tensorB *= 1.5
tensorC *= 0.02
assignResult += tensorA
assignResult -= tensorB
assignResult += tensorC
val error = (expected - result).abs().sum().value() +
(expected - assignResult).abs().sum().value()
assertTrue(error < TOLERANCE)
}
}
internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit {
TorchTensorRealAlgebra {
setSeed(SEED)
val x = randNormal(shape = intArrayOf(dim), device = device)
x.requiresGrad = true
val X = randNormal(shape = intArrayOf(dim,dim), device = device)
val Q = X + X.transpose(0,1)
val mu = randNormal(shape = intArrayOf(dim), device = device)
val c = randNormal(shape = IntArray(0), device = device)
val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c
val gradf = f grad x
val error = (gradf - ((Q dot x) + mu)).abs().sum().value()
assertTrue(error < TOLERANCE)
}
}
internal class TestTorchTensorAlgebra {
@Test
@ -52,4 +102,10 @@ internal class TestTorchTensorAlgebra {
@Test
fun testMatrixMultiplication() = testingMatrixMultiplication()
@Test
fun testLinearStructure() = testingLinearStructure()
@Test
fun testAutoGrad() = testingAutoGrad(dim = 100)
}