PyTorch 注册反向传播的钩子
注册钩子
通过注册钩子,收集模型运行时的输出,可以对每一层的输出值进行调试。
假设模型是 model
,我们可以把每一层的输入的梯度和输出的梯度保存在字典中:
module_names = {v: k for k, v in model.named_modules()}
grad_inputs = {}
grad_outputs = {}
def hook(m, grad_input, grad_output) -> Tuple[torch.Tensor] | None:
nonlocal module_names, grad_inputs, grad_outputs
name = module_names[m]
grad_inputs[name] = grad_input
grad_outputs[name] = grad_output
for m in model.modules(): # 或者在这里遍历 named_modules() 并记录名称和模型的对应关系
m.register_full_backward_hook(hook)
注意 grad_input
的含义是输入的梯度,grad_output
的含义是输出的梯度。假设一个 module 的 forward()
函数负责计算 y = 2 * x
,那么 grad_input
相当于 retain_grad()
之后的 x.grad
,grad_output
相当于 retain_grad()
之后的 y.grad
。
另外还要注意:
- 我们要遍历每个子模型(
torch.nn.Module
对象),按需注册钩子,因为register_full_backward_hook
本身是不会递归注册的。 - 钩子的注册要放在模型开始使用之前,不能在循环中反复注册钩子。
其他的观察
- 打印
torch.nn.Module
时,其中的torch.nn.parameter.Parameter
属性会被打印,torch.nn.Module
属性也会被递归打印出来。所以如果自己编写的模型带有非nn.Module
的参数,最好存储为torch.nn.parameter.Parameter
的形式。 torch.nn.Module
的state_dict()
里的张量应该是 detach 过的,没有梯度。要利用梯度,可以用named_parameters()
。state_dict()
包含 BatchNormalization 的running_mean
和running_var
,但是这两个信息不是参数,所以在named_parameters()
中是没有的。因此比较权重要用state_dict()
,比较梯度要用named_parameters()
。