forked from kscience/kmath
Hessian implementation
This commit is contained in:
parent 7d25aa2834
commit fbb414731b
@@ -147,7 +147,8 @@ extern "C"
     bool requires_grad(TorchTensorHandle tensor_handle);
     void requires_grad_(TorchTensorHandle tensor_handle, bool status);
     TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
-    TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
+    TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable, bool retain_graph);
+    TorchTensorHandle autohess_tensor(TorchTensorHandle value, TorchTensorHandle variable);

 #ifdef __cplusplus
 }
@@ -33,6 +33,16 @@ namespace ctorch
         return *static_cast<torch::Tensor *>(tensor_handle);
     }

+    inline char *tensor_to_char(const torch::Tensor &tensor)
+    {
+        std::stringstream bufrep;
+        bufrep << tensor;
+        auto rep = bufrep.str();
+        char *crep = (char *)malloc(rep.length() + 1);
+        std::strcpy(crep, rep.c_str());
+        return crep;
+    }
+
     inline int device_to_int(const torch::Tensor &tensor)
     {
         return (tensor.device().type() == torch::kCPU) ? 0 : 1 + tensor.device().index();
@@ -104,4 +114,27 @@ namespace ctorch
         return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
     }

+    inline torch::Tensor hessian(const torch::Tensor &value, const torch::Tensor &variable)
+    {
+        auto nelem = variable.numel();
+        auto hess = value.new_zeros({nelem, nelem});
+        auto grad = torch::autograd::grad({value}, {variable}, {}, torch::nullopt, true)[0].view(nelem);
+        int i = 0;
+        for (int j = 0; j < nelem; j++)
+        {
+            auto row = grad[j].requires_grad()
+                           ? torch::autograd::grad({grad[i]}, {variable}, {}, true, true, true)[0].view(nelem).slice(0, j, nelem)
+                           : grad[j].new_zeros(nelem - j);
+            hess[i].slice(0, i, nelem).add_(row.type_as(hess));
+            i++;
+        }
+        auto ndim = variable.dim();
+        auto sizes = variable.sizes().data();
+        auto shape = std::vector<int64_t>(ndim);
+        shape.assign(sizes, sizes + ndim);
+        shape.reserve(2 * ndim);
+        std::copy_n(shape.begin(), ndim, std::back_inserter(shape));
+        return (hess + torch::triu(hess, 1).t()).view(shape);
+    }
+
 } // namespace ctorch
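In effect, `ctorch::hessian` builds only the upper triangle of the Hessian from the first gradient and then symmetrizes. A sketch of the math it implements, for a scalar `value` f of a `variable` x flattened to n = numel elements:

\[
g = \nabla_x f, \qquad U_{jk} = \frac{\partial g_j}{\partial x_k} \quad (k \ge j), \qquad H = U + \operatorname{triu}(U, 1)^{\top},
\]

which is correct because the Hessian of a twice-differentiable scalar is symmetric, \(H_{jk} = H_{kj}\). The shape bookkeeping at the end reshapes the n-by-n matrix so the result carries shape `variable.shape + variable.shape`.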
@@ -77,12 +77,7 @@ void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle)

 char *tensor_to_string(TorchTensorHandle tensor_handle)
 {
-    std::stringstream bufrep;
-    bufrep << ctorch::cast(tensor_handle);
-    auto rep = bufrep.str();
-    char *crep = (char *)malloc(rep.length() + 1);
-    std::strcpy(crep, rep.c_str());
-    return crep;
+    return ctorch::tensor_to_char(ctorch::cast(tensor_handle));
 }
 void dispose_char(char *ptr)
 {
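`tensor_to_string` now delegates to the new `ctorch::tensor_to_char` helper, but the returned string is still `malloc`'d on the C++ side, so the caller remains responsible for freeing it via `dispose_char`. A minimal sketch of that contract from Kotlin/Native (`toKString` is standard kotlinx.cinterop; the binding names come from the C declarations in this diff, and `tensorHandle` is a placeholder for any valid handle):

    // Sketch only: assumes cinterop bindings generated from the C header above.
    val cstr = tensor_to_string(tensorHandle)!! // malloc'd by ctorch::tensor_to_char
    val rendered = cstr.toKString()             // copy the bytes into a Kotlin String
    dispose_char(cstr)                          // release the C-side allocation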
@@ -456,7 +451,11 @@ TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
 {
     return new torch::Tensor(ctorch::cast(tensor_handle).detach());
 }
-TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
+TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable, bool retain_graph)
 {
-    return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
+    return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)}, {}, retain_graph)[0]);
+}
+TorchTensorHandle autohess_tensor(TorchTensorHandle value, TorchTensorHandle variable)
+{
+    return new torch::Tensor(ctorch::hessian(ctorch::cast(value), ctorch::cast(variable)));
 }
@@ -42,10 +42,12 @@ public sealed class TorchTensor<T> constructor(
         return indices.map { it to get(it) }
     }

-    public fun value(): T {
-        check(dimension == 0) {
+    internal inline fun isValue() = check(dimension == 0) {
         "This tensor has shape ${shape.toList()}"
     }
+
+    public fun value(): T {
+        isValue()
         return item()
     }

@@ -157,8 +157,16 @@ public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
         return Pair(wrap(S), wrap(V))
     }

+    public fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean = false): TorchTensorType {
+        this.isValue()
+        return wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle, retainGraph)!!)
+    }
     public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
-        wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
+        this.grad(variable, false)
+    public infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
+        this.isValue()
+        return wrap(autohess_tensor(this.tensorHandle, variable.tensorHandle)!!)
+    }

     public fun TorchTensorType.detachFromGraph(): TorchTensorType =
         wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
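Taken together, the new Kotlin surface reads naturally at the call site. A hypothetical usage sketch (the tensors and the algebra receiver are assumed to be set up elsewhere; only `dot`, `grad`, `hess`, and the arithmetic operators appear in this diff):

    // Sketch only, inside a TorchTensorFieldAlgebra scope; tensorX, tensorSigma
    // and tensorMu are assumed pre-built tensors, with tensorX requiring grad.
    // grad/hess call isValue(), so f must be a 0-dimensional (scalar) tensor.
    val f = 0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX)
    val g = f.grad(tensorX, retainGraph = true) // keep the graph for a second pass
    val h = f hess tensorX                      // Hessian via autohess_tensor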
@@ -312,7 +320,7 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
         wrap(from_blob_long(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)

     override fun TorchTensorLong.getData(): CPointer<LongVar> {
-        require(this.device is TorchDevice.TorchCPU) {
+        check(this.device is TorchDevice.TorchCPU) {
             "This tensor is not on available on CPU"
         }
         return get_data_long(this.tensorHandle)!!
@@ -14,10 +14,12 @@ internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCP
         val expressionAtX =
             0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9

-        val gradientAtX = expressionAtX grad tensorX
+        val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
+        val hessianAtX = expressionAtX hess tensorX
         val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu

-        val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
+        val error = (gradientAtX - expectedGradientAtX).abs().sum().value() +
+                (hessianAtX - tensorSigma).abs().sum().value()
         assertTrue(error < TOLERANCE)
     }
 }
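The updated test checks both derivatives of the quadratic expression against closed forms; for symmetric \(\Sigma\),

\[
f(x) = \tfrac{1}{2}\, x^{\top} \Sigma x + \mu^{\top} x + c \quad\Longrightarrow\quad \nabla f(x) = \Sigma x + \mu, \qquad \nabla^2 f(x) = \Sigma,
\]

hence `expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu`, and the Hessian is compared directly against `tensorSigma`. The `retainGraph = true` on the first `grad` call matters: without it the autograd graph is freed after the gradient pass, and the subsequent `hess` evaluation of the same expression would fail.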
@@ -47,6 +49,7 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
     }
 }

+
 internal class TestAutograd {
     @Test
     fun testAutoGrad() = testingAutoGrad(dim = 100)