fma 指令

今天写 atan 反向传播的 CUDA kernel 发现和 torch 算出来的不一样。核心代码如下:

__global__ void elementwise_atan_backward(float* in, float* din, float* dout, int N) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (N); i += blockDim.x * gridDim.x) {
    // din[i] = dout[i] / (in[i] * in[i] + 1.f);
    //
    // pytorch/tools/autograd/derivatives.yaml
    // - name: atan(Tensor self) -> Tensor
    //   self: grad / (self * self + 1).conj()
    //   result: auto_element_wise
    //
    float square = in[i] * in[i] + 0.f; // disable fma so we can align with torch
    din[i] = dout[i] / (square + 1.f);
  }
}

PyTorch 是用规则文件生成自动梯度求导的,torch.atan 的求导规则为 grad / (self * self + 1).conj(),其中 gradself 都是张量,conj() 在实数张量的场景下是可以不管的。我按照同样的方法写 kernel,发现和 torch 算出来的结果有极小的差异。

上面的代码 我也放在 Compiler Explorer 了,从汇编来看,in[i] * in[i] + 1.f 被优化成了一条指令,这样不仅更快,而且精度更高(从网上的资料来看,nvcc 默认会进行不少优化,fma 就是其中一种)。PyTorch 生成梯度求导规则的方法将乘法和加法放在了两个不同的 kernel 里面,导致编译器无法使用 fma 优化。

我尝试把乘法和加法拆开,但是编译器太聪明了,还是使用了 fma 优化:

float square = in[i] * in[i];
din[i] = dout[i] / (square + 1.f);

最后想到的一种方法是先让编译器用常数 0 完成 fma,防止乘法结果和 1.f 融合。

现代 CPU 和 GPU 基本都支持 fma。其实 fma 函数在 C++ 中也是有的,可以参考 https://en.cppreference.com/w/cpp/numeric/math/fma 。还有一篇看上去比较好的文章 https://momentsingraphics.de/FMA.html