Hessian implementation
This commit is contained in:
parent
7d25aa2834
commit
fbb414731b
@ -147,7 +147,8 @@ extern "C"
|
||||
bool requires_grad(TorchTensorHandle tensor_handle);
|
||||
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
|
||||
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
|
||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable, bool retain_graph);
|
||||
TorchTensorHandle autohess_tensor(TorchTensorHandle value, TorchTensorHandle variable);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
@ -33,6 +33,16 @@ namespace ctorch
|
||||
return *static_cast<torch::Tensor *>(tensor_handle);
|
||||
}
|
||||
|
||||
inline char *tensor_to_char(const torch::Tensor &tensor)
|
||||
{
|
||||
std::stringstream bufrep;
|
||||
bufrep << tensor;
|
||||
auto rep = bufrep.str();
|
||||
char *crep = (char *)malloc(rep.length() + 1);
|
||||
std::strcpy(crep, rep.c_str());
|
||||
return crep;
|
||||
}
|
||||
|
||||
inline int device_to_int(const torch::Tensor &tensor)
|
||||
{
|
||||
return (tensor.device().type() == torch::kCPU) ? 0 : 1 + tensor.device().index();
|
||||
@ -104,4 +114,27 @@ namespace ctorch
|
||||
return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||
}
|
||||
|
||||
inline torch::Tensor hessian(const torch::Tensor &value, const torch::Tensor &variable)
|
||||
{
|
||||
auto nelem = variable.numel();
|
||||
auto hess = value.new_zeros({nelem, nelem});
|
||||
auto grad = torch::autograd::grad({value}, {variable}, {}, torch::nullopt, true)[0].view(nelem);
|
||||
int i = 0;
|
||||
for (int j = 0; j < nelem; j++)
|
||||
{
|
||||
auto row = grad[j].requires_grad()
|
||||
? torch::autograd::grad({grad[i]}, {variable}, {}, true, true, true)[0].view(nelem).slice(0, j, nelem)
|
||||
: grad[j].new_zeros(nelem - j);
|
||||
hess[i].slice(0, i, nelem).add_(row.type_as(hess));
|
||||
i++;
|
||||
}
|
||||
auto ndim = variable.dim();
|
||||
auto sizes = variable.sizes().data();
|
||||
auto shape = std::vector<int64_t>(ndim);
|
||||
shape.assign(sizes, sizes + ndim);
|
||||
shape.reserve(2 * ndim);
|
||||
std::copy_n(shape.begin(), ndim, std::back_inserter(shape));
|
||||
return (hess + torch::triu(hess, 1).t()).view(shape);
|
||||
}
|
||||
|
||||
} // namespace ctorch
|
||||
|
@ -77,12 +77,7 @@ void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle)
|
||||
|
||||
char *tensor_to_string(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
std::stringstream bufrep;
|
||||
bufrep << ctorch::cast(tensor_handle);
|
||||
auto rep = bufrep.str();
|
||||
char *crep = (char *)malloc(rep.length() + 1);
|
||||
std::strcpy(crep, rep.c_str());
|
||||
return crep;
|
||||
return ctorch::tensor_to_char(ctorch::cast(tensor_handle));
|
||||
}
|
||||
void dispose_char(char *ptr)
|
||||
{
|
||||
@ -456,7 +451,11 @@ TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).detach());
|
||||
}
|
||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
|
||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable, bool retain_graph)
|
||||
{
|
||||
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
|
||||
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)}, {}, retain_graph)[0]);
|
||||
}
|
||||
TorchTensorHandle autohess_tensor(TorchTensorHandle value, TorchTensorHandle variable)
|
||||
{
|
||||
return new torch::Tensor(ctorch::hessian(ctorch::cast(value), ctorch::cast(variable)));
|
||||
}
|
||||
|
@ -42,10 +42,12 @@ public sealed class TorchTensor<T> constructor(
|
||||
return indices.map { it to get(it) }
|
||||
}
|
||||
|
||||
internal inline fun isValue() = check(dimension == 0) {
|
||||
"This tensor has shape ${shape.toList()}"
|
||||
}
|
||||
|
||||
public fun value(): T {
|
||||
check(dimension == 0) {
|
||||
"This tensor has shape ${shape.toList()}"
|
||||
}
|
||||
isValue()
|
||||
return item()
|
||||
}
|
||||
|
||||
|
@ -157,8 +157,16 @@ public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
||||
return Pair(wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
public fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean=false): TorchTensorType {
|
||||
this.isValue()
|
||||
return wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle, retainGraph)!!)
|
||||
}
|
||||
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
||||
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||
this.grad(variable, false)
|
||||
public infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
|
||||
this.isValue()
|
||||
return wrap(autohess_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||
}
|
||||
|
||||
public fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
|
||||
@ -312,7 +320,7 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||
wrap(from_blob_long(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)
|
||||
|
||||
override fun TorchTensorLong.getData(): CPointer<LongVar> {
|
||||
require(this.device is TorchDevice.TorchCPU) {
|
||||
check(this.device is TorchDevice.TorchCPU) {
|
||||
"This tensor is not on available on CPU"
|
||||
}
|
||||
return get_data_long(this.tensorHandle)!!
|
||||
|
@ -14,10 +14,12 @@ internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCP
|
||||
val expressionAtX =
|
||||
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
|
||||
|
||||
val gradientAtX = expressionAtX grad tensorX
|
||||
val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
|
||||
val hessianAtX = expressionAtX hess tensorX
|
||||
val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu
|
||||
|
||||
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
|
||||
val error = (gradientAtX - expectedGradientAtX).abs().sum().value() +
|
||||
(hessianAtX - tensorSigma).abs().sum().value()
|
||||
assertTrue(error < TOLERANCE)
|
||||
}
|
||||
}
|
||||
@ -47,6 +49,7 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
internal class TestAutograd {
|
||||
@Test
|
||||
fun testAutoGrad() = testingAutoGrad(dim = 100)
|
||||
|
Loading…
Reference in New Issue
Block a user