Where does torch's automatic differentiation code live?
August 5, 2024: the current torch release is 2.4. The PyTorch source tree contains several YAML files that are quite important and worth knowing about.
tools/autograd/derivatives.yaml contains the derivative (backward) formulas for individual operators, for example:
- name: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
  self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim)
  result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result, dim, keepdim) * self_t.conj()).sum(dim, keepdim).conj()

- name: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor
  self: "accumulate ? grad : grad.put(index, zeros_like(source), false)"
  index: non_differentiable
  source: grad.take(index).reshape_as(source)
  result: self_t.put(index, source_t, accumulate)

- name: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)
  A: linalg_qr_backward(grad_Q, grad_R, Q, R, mode)
  Q, R: linalg_qr_jvp(A_t, Q, R, mode)

- name: rad2deg(Tensor self) -> Tensor
  self: rad2deg_backward(grad)
  result: auto_element_wise

- name: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)
  self: zeros_like(grad)
  result: self_t.zero_()

- name: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)
  self: zeros_like(grad)
  result: self_t.zero_()
It looks like the name field is a Python-like schema (pseudocode describing the operator's signature), while the other fields are C++ expressions that the code generator turns into the actual backward functions?
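To get a feel for what a formula like the prod.dim_int backward computes, here is a toy Python analogue (a hedged sketch using NumPy, not PyTorch's actual C++ prod_backward; the function name and simplification below are mine). For a product along a dimension, the derivative with respect to each element is the product of all the other elements, which equals result / x when no element is zero:

```python
import numpy as np

def prod_backward(grad, x, result, axis, keepdims):
    # Hypothetical Python analogue of the prod.dim_int backward formula:
    # d(prod(x))/dx_i = product of the other elements = result / x_i,
    # so the upstream grad is scaled by result / x (valid when no x_i is 0;
    # the real C++ implementation also handles zeros and complex dtypes).
    if not keepdims:
        grad = np.expand_dims(grad, axis)
        result = np.expand_dims(result, axis)
    return grad * result / x

x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
result = x.prod(axis=1)          # forward pass: [6., 120.]
grad_out = np.ones_like(result)  # upstream gradient of ones
gx = prod_backward(grad_out, x, result, axis=1, keepdims=False)
# gx[0] == [6., 3., 2.]: each entry is the product of the other two
```

This matches the structure of the YAML entry: grad, self, result, dim, and keepdim are exactly the arguments the formula passes to the generated C++ function.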
aten/src/ATen/native/native_functions.yaml records how each native function is dispatched to its backend-specific kernel, for example:
- func: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
  dispatch:
    CPU: batch_norm_cpu
    CUDA: batch_norm_cuda
    MPS: batch_norm_mps
    MkldnnCPU: mkldnn_batch_norm
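Conceptually, the dispatch: section is a table from a backend key to a kernel. A minimal Python sketch of that idea (the real dispatcher is C++ inside ATen; the function bodies and names below are hypothetical stand-ins):

```python
# Toy model of backend dispatch: one entry point, one kernel per backend,
# selected by a key, loosely mirroring the YAML entry above.
def batch_norm_cpu(x):
    return f"batch_norm_cpu({x})"

def batch_norm_cuda(x):
    return f"batch_norm_cuda({x})"

DISPATCH = {
    "CPU": batch_norm_cpu,
    "CUDA": batch_norm_cuda,
}

def native_batch_norm(x, device="CPU"):
    # Look up the kernel registered for this backend key and call it,
    # like the dispatch: table routes native_batch_norm to batch_norm_cpu,
    # batch_norm_cuda, etc.
    return DISPATCH[device](x)

out = native_batch_norm("t", device="CPU")  # → "batch_norm_cpu(t)"
```

In real PyTorch the key is derived from the input tensors' device and layout, and the table is populated by code generated from this YAML file.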