From fbb414731b5accfd696a51ad54b08db44de62106 Mon Sep 17 00:00:00 2001
From: rgrit91
Date: Thu, 14 Jan 2021 22:28:47 +0000
Subject: [PATCH] Hessian implementation

---
 kmath-torch/ctorch/include/ctorch.h           |  3 +-
 kmath-torch/ctorch/include/utils.hh           | 33 +++++++++++++++++++
 kmath-torch/ctorch/src/ctorch.cc              | 15 ++++-----
 .../kscience.kmath.torch/TorchTensor.kt       |  8 +++--
 .../TorchTensorAlgebra.kt                     | 12 +++++--
 .../kscience/kmath/torch/TestAutograd.kt      |  7 ++--
 6 files changed, 62 insertions(+), 16 deletions(-)

diff --git a/kmath-torch/ctorch/include/ctorch.h b/kmath-torch/ctorch/include/ctorch.h
index 3d6ffc779..f43ac70b9 100644
--- a/kmath-torch/ctorch/include/ctorch.h
+++ b/kmath-torch/ctorch/include/ctorch.h
@@ -147,7 +147,8 @@ extern "C"
     bool requires_grad(TorchTensorHandle tensor_handle);
     void requires_grad_(TorchTensorHandle tensor_handle, bool status);
     TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
-    TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
+    TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable, bool retain_graph);
+    TorchTensorHandle autohess_tensor(TorchTensorHandle value, TorchTensorHandle variable);
 
 #ifdef __cplusplus
 }
diff --git a/kmath-torch/ctorch/include/utils.hh b/kmath-torch/ctorch/include/utils.hh
index ef1da9d97..975e8f034 100644
--- a/kmath-torch/ctorch/include/utils.hh
+++ b/kmath-torch/ctorch/include/utils.hh
@@ -33,6 +33,16 @@ namespace ctorch
         return *static_cast<torch::Tensor *>(tensor_handle);
     }
 
+    inline char *tensor_to_char(const torch::Tensor &tensor)
+    {
+        std::stringstream bufrep;
+        bufrep << tensor;
+        auto rep = bufrep.str();
+        char *crep = (char *)malloc(rep.length() + 1);
+        std::strcpy(crep, rep.c_str());
+        return crep;
+    }
+
     inline int device_to_int(const torch::Tensor &tensor)
     {
         return (tensor.device().type() == torch::kCPU) ? 0 : 1 + tensor.device().index();
@@ -104,4 +114,27 @@ namespace ctorch
         return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
     }
 
+    // Hessian via double backward: each entry of the differentiable gradient is
+    // differentiated again; only the upper triangle is computed, then mirrored.
+    inline torch::Tensor hessian(const torch::Tensor &value, const torch::Tensor &variable)
+    {
+        auto nelem = variable.numel();
+        auto hess = value.new_zeros({nelem, nelem});
+        auto grad = torch::autograd::grad({value}, {variable}, {}, torch::nullopt, true)[0].view(nelem);
+        for (int64_t i = 0; i < nelem; i++)
+        {
+            auto row = grad[i].requires_grad()
+                           ? torch::autograd::grad({grad[i]}, {variable}, {}, true, true, true)[0].view(nelem).slice(0, i, nelem)
+                           : grad[i].new_zeros(nelem - i);
+            hess[i].slice(0, i, nelem).add_(row.type_as(hess));
+        }
+        auto ndim = variable.dim();
+        auto sizes = variable.sizes().data();
+        auto shape = std::vector<int64_t>(ndim);
+        shape.assign(sizes, sizes + ndim);
+        shape.reserve(2 * ndim);
+        std::copy_n(shape.begin(), ndim, std::back_inserter(shape));
+        return (hess + torch::triu(hess, 1).t()).view(shape);
+    }
+
 } // namespace ctorch
diff --git a/kmath-torch/ctorch/src/ctorch.cc b/kmath-torch/ctorch/src/ctorch.cc
index e5e1c55ef..b68be01cd 100644
--- a/kmath-torch/ctorch/src/ctorch.cc
+++ b/kmath-torch/ctorch/src/ctorch.cc
@@ -77,12 +77,7 @@ void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle)
 
 char *tensor_to_string(TorchTensorHandle tensor_handle)
 {
-    std::stringstream bufrep;
-    bufrep << ctorch::cast(tensor_handle);
-    auto rep = bufrep.str();
-    char *crep = (char *)malloc(rep.length() + 1);
-    std::strcpy(crep, rep.c_str());
-    return crep;
+    return ctorch::tensor_to_char(ctorch::cast(tensor_handle));
 }
 void dispose_char(char *ptr)
 {
@@ -456,7 +451,11 @@ TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
 {
     return new torch::Tensor(ctorch::cast(tensor_handle).detach());
 }
-TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
+TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable, bool retain_graph)
 {
-    return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
+    return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)}, {}, retain_graph)[0]);
+}
+TorchTensorHandle autohess_tensor(TorchTensorHandle value, TorchTensorHandle variable)
+{
+    return new torch::Tensor(ctorch::hessian(ctorch::cast(value), ctorch::cast(variable)));
 }
diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt
index bab55de5a..fb7c70abe 100644
--- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt
+++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt
@@ -42,10 +42,12 @@ public sealed class TorchTensor<T> constructor(
         return indices.map { it to get(it) }
     }
 
+    internal inline fun isValue() = check(dimension == 0) {
+        "This tensor has shape ${shape.toList()}"
+    }
+
     public fun value(): T {
-        check(dimension == 0) {
-            "This tensor has shape ${shape.toList()}"
-        }
+        isValue()
         return item()
     }
 
diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt
index fb07db744..0ae6954ff 100644
--- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt
+++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt
@@ -157,8 +157,16 @@ public sealed class TorchTensorFieldAlgebra
     {
-        require(this.device is TorchDevice.TorchCPU) {
+        check(this.device is TorchDevice.TorchCPU) {
             "This tensor is not available on CPU"
         }
         return get_data_long(this.tensorHandle)!!
diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt
index 4512df180..076188fd1 100644
--- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt
+++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt
@@ -14,10 +14,12 @@ internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU
         val expressionAtX = 0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
 
-        val gradientAtX = expressionAtX grad tensorX
+        val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
+        val hessianAtX = expressionAtX hess tensorX
         val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu
 
-        val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
+        val error = (gradientAtX - expectedGradientAtX).abs().sum().value() +
+                (hessianAtX - tensorSigma).abs().sum().value()
         assertTrue(error < TOLERANCE)
     }
 }
@@ -47,6 +49,7 @@ internal fun testingBatchedAutoGrad(batch: IntArray,
     }
 }
 
+
 internal class TestAutograd {
     @Test
     fun testAutoGrad() = testingAutoGrad(dim = 100)
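
Reviewer notes (not part of the patch):

ctorch::hessian works by double backward. The first torch::autograd::grad call
passes create_graph = true, so the flattened gradient is itself differentiable;
each gradient entry is then differentiated once more. Because the Hessian of a
scalar function is symmetric, only the upper-triangular tail of each row is
computed, and the strict lower triangle is recovered by the final
(hess + torch::triu(hess, 1).t()) mirroring. The same graph reuse explains the
Kotlin test: grad(tensorX, retainGraph = true) keeps the backward graph alive so
the subsequent "expressionAtX hess tensorX" can traverse it again.

Below is a minimal standalone LibTorch sketch (illustrative only, not code from
this patch) of the identity the test checks: for f(x) = 0.5 * x^T sigma x +
mu^T x + c with symmetric sigma, the gradient is sigma x + mu and the Hessian
is sigma.

#include <torch/torch.h>
#include <iostream>

int main()
{
    const int64_t dim = 4;

    // Symmetrize a random matrix so the expected Hessian is exactly sigma.
    auto a = torch::randn({dim, dim});
    auto sigma = 0.5 * (a + a.t());
    auto mu = torch::randn({dim});
    auto x = torch::randn({dim}, torch::requires_grad());

    // The quadratic form from testingAutoGrad.
    auto f = 0.5 * x.dot(torch::mv(sigma, x)) + mu.dot(x) + 25.9;

    // First backward pass: create_graph = true keeps the gradient differentiable.
    auto grad = torch::autograd::grad({f}, {x}, {}, torch::nullopt, /*create_graph=*/true)[0];

    // Second backward pass, one Hessian row per gradient entry;
    // retain_graph = true keeps the graph alive across iterations.
    auto hess = torch::zeros({dim, dim});
    for (int64_t i = 0; i < dim; i++)
        hess[i].copy_(torch::autograd::grad({grad[i]}, {x}, {}, /*retain_graph=*/true)[0]);

    // Both residuals should be ~0 up to floating-point noise.
    std::cout << (grad - (torch::mv(sigma, x) + mu)).abs().sum().item<float>() << "\n";
    std::cout << (hess - sigma).abs().sum().item<float>() << "\n";
}

Unlike ctorch::hessian, the sketch differentiates full rows rather than
upper-triangular tails; the triangular variant roughly halves the second pass
for large variables at the cost of the final mirroring step.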