Hessian implementation

This commit is contained in:
rgrit91 2021-01-14 22:28:47 +00:00
parent 7d25aa2834
commit fbb414731b
6 changed files with 62 additions and 16 deletions

View File

@ -147,7 +147,8 @@ extern "C"
bool requires_grad(TorchTensorHandle tensor_handle); bool requires_grad(TorchTensorHandle tensor_handle);
void requires_grad_(TorchTensorHandle tensor_handle, bool status); void requires_grad_(TorchTensorHandle tensor_handle, bool status);
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle); TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable); TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable, bool retain_graph);
TorchTensorHandle autohess_tensor(TorchTensorHandle value, TorchTensorHandle variable);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -33,6 +33,16 @@ namespace ctorch
return *static_cast<torch::Tensor *>(tensor_handle); return *static_cast<torch::Tensor *>(tensor_handle);
} }
// Render a tensor's textual representation into a freshly malloc'ed,
// NUL-terminated C string. Ownership passes to the caller, who must
// release the buffer with free() (see dispose_char on the C API side).
inline char *tensor_to_char(const torch::Tensor &tensor)
{
    std::ostringstream out;
    out << tensor;
    const std::string text = out.str();
    const std::size_t size = text.size() + 1; // +1 for the terminating NUL
    char *buffer = (char *)malloc(size);
    std::memcpy(buffer, text.c_str(), size);
    return buffer;
}
inline int device_to_int(const torch::Tensor &tensor) inline int device_to_int(const torch::Tensor &tensor)
{ {
return (tensor.device().type() == torch::kCPU) ? 0 : 1 + tensor.device().index(); return (tensor.device().type() == torch::kCPU) ? 0 : 1 + tensor.device().index();
@ -104,4 +114,27 @@ namespace ctorch
return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device)); return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
} }
// Compute the Hessian (matrix of second derivatives) of a scalar `value`
// with respect to `variable`.
//
// The result has shape (s..., s...) where s... is variable's shape, i.e.
// the variable's shape repeated twice. Only the upper triangle is computed
// by repeated differentiation of the flattened gradient; the lower triangle
// is filled in by symmetry at the end.
inline torch::Tensor hessian(const torch::Tensor &value, const torch::Tensor &variable)
{
    const auto nelem = variable.numel();
    auto hess = value.new_zeros({nelem, nelem});
    // First derivative, flattened to a vector; the graph is created
    // (last arg true) so each component can be differentiated again.
    auto grad = torch::autograd::grad({value}, {variable}, {}, torch::nullopt, true)[0].view(nelem);
    // Fill the upper triangle row by row. A gradient component that does not
    // require grad has an identically zero second derivative.
    // (Original code kept two always-equal counters i and j; one suffices.
    //  Use int64_t to match numel()'s type.)
    for (int64_t row = 0; row < nelem; row++)
    {
        auto upper = grad[row].requires_grad()
                         ? torch::autograd::grad({grad[row]}, {variable}, {}, true, true, true)[0].view(nelem).slice(0, row, nelem)
                         : grad[row].new_zeros(nelem - row);
        hess[row].slice(0, row, nelem).add_(upper.type_as(hess));
    }
    // Build the output shape as the variable's shape repeated twice.
    // (Replaces the original copy_n from a vector into itself via
    //  back_inserter, which relied on a prior reserve to stay valid.)
    std::vector<int64_t> shape(variable.sizes().begin(), variable.sizes().end());
    shape.insert(shape.end(), variable.sizes().begin(), variable.sizes().end());
    // Mirror the strict upper triangle to obtain the full symmetric Hessian.
    return (hess + torch::triu(hess, 1).t()).view(shape);
}
} // namespace ctorch } // namespace ctorch

View File

@ -77,12 +77,7 @@ void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle)
char *tensor_to_string(TorchTensorHandle tensor_handle) char *tensor_to_string(TorchTensorHandle tensor_handle)
{ {
std::stringstream bufrep; return ctorch::tensor_to_char(ctorch::cast(tensor_handle));
bufrep << ctorch::cast(tensor_handle);
auto rep = bufrep.str();
char *crep = (char *)malloc(rep.length() + 1);
std::strcpy(crep, rep.c_str());
return crep;
} }
void dispose_char(char *ptr) void dispose_char(char *ptr)
{ {
@ -456,7 +451,11 @@ TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
{ {
return new torch::Tensor(ctorch::cast(tensor_handle).detach()); return new torch::Tensor(ctorch::cast(tensor_handle).detach());
} }
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable) TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable, bool retain_graph)
{ {
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]); return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)}, {}, retain_graph)[0]);
}
TorchTensorHandle autohess_tensor(TorchTensorHandle value, TorchTensorHandle variable)
{
return new torch::Tensor(ctorch::hessian(ctorch::cast(value), ctorch::cast(variable)));
} }

View File

@ -42,10 +42,12 @@ public sealed class TorchTensor<T> constructor(
return indices.map { it to get(it) } return indices.map { it to get(it) }
} }
public fun value(): T { internal inline fun isValue() = check(dimension == 0) {
check(dimension == 0) {
"This tensor has shape ${shape.toList()}" "This tensor has shape ${shape.toList()}"
} }
public fun value(): T {
isValue()
return item() return item()
} }

View File

@ -157,8 +157,16 @@ public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
return Pair(wrap(S), wrap(V)) return Pair(wrap(S), wrap(V))
} }
/**
 * Differentiates this tensor with respect to [variable].
 *
 * The receiver must be a scalar (dimension 0) — enforced by [isValue].
 *
 * @param variable the tensor to differentiate against.
 * @param retainGraph whether the computational graph is retained after the
 *        backward pass, allowing further differentiation of the same value
 *        (e.g. a subsequent [hess] call). Defaults to false.
 * @return the gradient tensor.
 */
public fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean = false): TorchTensorType {
    this.isValue()
    return wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle, retainGraph)!!)
}
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType = public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!) this.grad(variable, false)
/**
 * Computes the Hessian of this tensor with respect to [variable].
 *
 * The receiver must be a scalar (dimension 0) — enforced by [isValue].
 */
public infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
    isValue()
    val resultHandle = autohess_tensor(tensorHandle, variable.tensorHandle)!!
    return wrap(resultHandle)
}
public fun TorchTensorType.detachFromGraph(): TorchTensorType = public fun TorchTensorType.detachFromGraph(): TorchTensorType =
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!) wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
@ -312,7 +320,7 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
wrap(from_blob_long(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!) wrap(from_blob_long(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)
override fun TorchTensorLong.getData(): CPointer<LongVar> { override fun TorchTensorLong.getData(): CPointer<LongVar> {
require(this.device is TorchDevice.TorchCPU) { check(this.device is TorchDevice.TorchCPU) {
"This tensor is not available on CPU" "This tensor is not available on CPU"
} }
return get_data_long(this.tensorHandle)!! return get_data_long(this.tensorHandle)!!

View File

@ -14,10 +14,12 @@ internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCP
val expressionAtX = val expressionAtX =
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9 0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
val gradientAtX = expressionAtX grad tensorX val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
val hessianAtX = expressionAtX hess tensorX
val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu
val error = (gradientAtX - expectedGradientAtX).abs().sum().value() val error = (gradientAtX - expectedGradientAtX).abs().sum().value() +
(hessianAtX - tensorSigma).abs().sum().value()
assertTrue(error < TOLERANCE) assertTrue(error < TOLERANCE)
} }
} }
@ -47,6 +49,7 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
} }
} }
internal class TestAutograd { internal class TestAutograd {
@Test @Test
fun testAutoGrad() = testingAutoGrad(dim = 100) fun testAutoGrad() = testingAutoGrad(dim = 100)