Where is torch's autograd code?

August 5, 2024: the current PyTorch release is 2.4. The PyTorch source tree also contains several YAML files that are quite important and worth a look.

tools/autograd/derivatives.yaml contains derivative formula snippets such as:

- name: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
  self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim)
  result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result, dim, keepdim) * self_t.conj()).sum(dim, keepdim).conj()

- name: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor
  self: "accumulate ? grad : grad.put(index, zeros_like(source), false)"
  index: non_differentiable
  source: grad.take(index).reshape_as(source)
  result: self_t.put(index, source_t, accumulate)

- name: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)
  A: linalg_qr_backward(grad_Q, grad_R, Q, R, mode)
  Q, R: linalg_qr_jvp(A_t, Q, R, mode)

- name: rad2deg(Tensor self) -> Tensor
  self: rad2deg_backward(grad)
  result: auto_element_wise

- name: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)
  self: zeros_like(grad)
  result: self_t.zero_()

- name: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)
  self: zeros_like(grad)
  result: self_t.zero_()
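In these entries, a line keyed by an input name (self, source, A) is the backward formula for that input, and a result: line is the forward-mode (JVP) formula, written in terms of tangents such as self_t. To see what two of the backward formulas compute, here is a minimal pure-Python sketch (no torch) for the flattened 1-D case. The helpers prod_backward_1d, put_backward_1d and numeric_grad are hypothetical illustrations, not PyTorch functions; the real prod_backward is C++ and handles arbitrary dim and keepdim.

```python
# Hypothetical pure-Python sketch (no torch) of what the backward formulas
# for prod (1-D, dim=0) and put (1-D) compute. Not PyTorch's implementation.
from math import prod, isclose


def prod_backward_1d(grad, xs):
    """d(prod(xs))/dx_i = product of all other elements, times upstream grad.

    Naive O(n^2) version without division, so zeros in xs are handled.
    """
    return [grad * prod(xs[:i] + xs[i + 1:]) for i in range(len(xs))]


def put_backward_1d(grad, index, accumulate=False):
    """Mirrors the derivatives.yaml entry for put on a flat tensor.

    self:   accumulate ? grad : grad.put(index, zeros_like(source), false)
    source: grad.take(index)
    """
    grad_self = list(grad)
    if not accumulate:
        # Positions overwritten by source no longer depend on self.
        for i in index:
            grad_self[i] = 0.0
    grad_source = [grad[i] for i in index]
    return grad_self, grad_source


def numeric_grad(f, xs, eps=1e-6):
    """Central finite differences, used only to sanity-check the formulas."""
    out = []
    for i in range(len(xs)):
        up = xs[:i] + [xs[i] + eps] + xs[i + 1:]
        down = xs[:i] + [xs[i] - eps] + xs[i + 1:]
        out.append((f(up) - f(down)) / (2 * eps))
    return out


if __name__ == "__main__":
    xs = [2.0, 0.0, 3.0, 4.0]
    analytic = prod_backward_1d(1.0, xs)
    print(analytic)  # [0.0, 24.0, 0.0, 0.0]
    assert all(isclose(a, n, abs_tol=1e-4)
               for a, n in zip(analytic, numeric_grad(prod, xs)))
    print(put_backward_1d([1.0, 2.0, 3.0, 4.0], index=[0, 2]))
```

The finite-difference check is the same idea torch.autograd.gradcheck applies to the real formulas, just on plain Python lists.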

It looks as if the name field is Python-like pseudocode and everything else is C++ code?
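The name line is not actually Python: it uses the same operator-schema syntax as the func: entries in native_functions.yaml. For example, Tensor(a!) marks an argument that is written in place and aliased by the output, and a bare * starts the keyword-only arguments. A minimal sketch of splitting such a line into its parts follows; split_top_level, Schema and parse_schema are hypothetical names, and this is not PyTorch's own parser (torchgen has its own). It only handles simple schemas like the ones above.

```python
# Hypothetical sketch: split an operator schema (the `name:` line in
# derivatives.yaml) into base name, overload, arguments and return type.
# Not PyTorch's parser; only handles simple schemas.
from dataclasses import dataclass, field


def split_top_level(s, sep=","):
    """Split on `sep` only when not nested inside (), [] or {}."""
    parts, depth, cur = [], 0, []
    for ch in s:
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
        if ch == sep and depth == 0:
            parts.append("".join(cur).strip())
            cur = []
        else:
            cur.append(ch)
    tail = "".join(cur).strip()
    if tail:
        parts.append(tail)
    return parts


@dataclass
class Schema:
    name: str
    overload: str
    positional: list = field(default_factory=list)
    kwarg_only: list = field(default_factory=list)
    returns: str = ""


def parse_schema(text):
    head, _, returns = text.rpartition(" -> ")
    open_idx = head.index("(")  # first "(" starts the argument list
    name, _, overload = head[:open_idx].partition(".")
    schema = Schema(name=name, overload=overload, returns=returns.strip())
    keyword_only = False
    for arg in split_top_level(head[open_idx + 1:-1]):
        if arg == "*":
            keyword_only = True  # everything after a bare * is keyword-only
        elif keyword_only:
            schema.kwarg_only.append(arg)
        else:
            schema.positional.append(arg)
    return schema


if __name__ == "__main__":
    s = parse_schema("random_.from(Tensor(a!) self, int from, int? to, "
                     "*, Generator? generator=None) -> Tensor(a!)")
    print(s.name, s.overload, s.positional, s.kwarg_only, s.returns)
```

The text before the dot is the operator name and the part after it (dim_int, from, to) is the overload name, which is how random_.from and random_.to can carry separate derivative entries.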

aten/src/ATen/native/native_functions.yaml specifies how each native function is dispatched to a kernel on each backend, for example:

- func: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
  dispatch:
    CPU: batch_norm_cpu
    CUDA: batch_norm_cuda
    MPS: batch_norm_mps
    MkldnnCPU: mkldnn_batch_norm
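The idea behind a dispatch: table can be sketched as a lookup from dispatch key to kernel, chosen at call time. This is a minimal sketch under a big simplification: PyTorch's real dispatcher is C++ and picks the key from the input tensors' dispatch key sets with priorities. The Python functions below are stand-ins that only borrow the kernel names from the YAML entry, and call_native_batch_norm is a hypothetical helper.

```python
# Hypothetical sketch of a `dispatch:` table: one operator, several
# backend-specific kernels, selected by a dispatch key at call time.
# The kernels are stand-ins, not PyTorch's real C++ implementations.

def batch_norm_cpu(x):
    return f"cpu kernel on {x}"


def batch_norm_cuda(x):
    return f"cuda kernel on {x}"


def batch_norm_mps(x):
    return f"mps kernel on {x}"


def mkldnn_batch_norm(x):
    return f"mkldnn kernel on {x}"


# Mirrors the YAML entry above: dispatch key -> kernel.
NATIVE_BATCH_NORM_DISPATCH = {
    "CPU": batch_norm_cpu,
    "CUDA": batch_norm_cuda,
    "MPS": batch_norm_mps,
    "MkldnnCPU": mkldnn_batch_norm,
}


def call_native_batch_norm(dispatch_key, x):
    """Look up the kernel registered for `dispatch_key` and call it."""
    try:
        kernel = NATIVE_BATCH_NORM_DISPATCH[dispatch_key]
    except KeyError:
        raise NotImplementedError(
            f"native_batch_norm has no kernel for {dispatch_key!r}"
        ) from None
    return kernel(x)


if __name__ == "__main__":
    print(call_native_batch_norm("CUDA", "input"))
```

Keeping the table in YAML rather than in hand-written C++ lets the code generator emit the registration code for every backend from one declaration.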