PyTorch 注册反向传播的钩子

注册钩子

通过注册钩子,收集模型运行时的输出,可以对每一层的输出值进行调试。

假设模型是 model,我们可以把每一层的输入的梯度和输出的梯度保存在字典中:

module_names = {v: k for k, v in model.named_modules()}
grad_inputs = {}
grad_outputs = {}

def hook(m, grad_input, grad_output) -> Tuple[torch.Tensor] | None:
    nonlocal module_names, grad_inputs, grad_outputs
    name = module_names[m]
    grad_inputs[name] = grad_input
    grad_outputs[name] = grad_output

for m in model.modules(): # 或者在这里遍历 named_modules() 并记录名称和模型的对应关系
    m.register_full_backward_hook(hook)

注意 grad_input 的含义是输入的梯度,grad_output 的含义是输出的梯度。假设一个 module 的 forward() 函数负责计算 y = 2 * x,那么 grad_input 相当于 retain_grad() 之后的 x.gradgrad_output 相当于 retain_grad() 之后的 y.grad

另外还要注意:

  1. 我们要遍历每个子模型(torch.nn.Module 对象),按需注册钩子,因为 register_full_backward_hook 本身是不会递归注册的。
  2. 钩子的注册要放在模型开始使用之前,不能在循环中反复注册钩子。

其他的观察

  1. 打印 torch.nn.Module 时,其中的 torch.nn.parameter.Parameter 属性会被打印,torch.nn.Module 属性也会被递归打印出来。所以如果自己编写的模型带有非 nn.Module 的参数,最好存储为 torch.nn.parameter.Parameter 的形式。
  2. torch.nn.Modulestate_dict() 里的张量应该是 detach 过的,没有梯度。要利用梯度,可以用 named_parameters()
  3. state_dict() 包含 BatchNormalization 的 running_meanrunning_var,但是这两个信息不是参数,所以在 named_parameters() 中是没有的。因此比较权重要用 state_dict(),比较梯度要用 named_parameters()