fma 指令
今天写 atan 反向传播的 CUDA kernel 发现和 torch 算出来的不一样。核心代码如下:
__global__ void elementwise_atan_backward(float* in, float* din, float* dout, int N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (N); i += blockDim.x * gridDim.x) {
// din[i] = dout[i] / (in[i] * in[i] + 1.f);
//
// pytorch/tools/autograd/derivatives.yaml
// - name: atan(Tensor self) -> Tensor
// self: grad / (self * self + 1).conj()
// result: auto_element_wise
//
float square = in[i] * in[i] + 0.f; // disable fma so we can align with torch
din[i] = dout[i] / (square + 1.f);
}
}
PyTorch 是用规则文件生成自动梯度求导的,torch.atan
的求导规则为 grad / (self * self + 1).conj()
,其中 grad
和 self
都是张量,conj()
在实数张量的场景下是可以不管的。我按照同样的方法写 kernel,发现和 torch 算出来的结果有极小的差异。
上面的代码 我也放在 Compiler Explorer 了,从汇编来看,in[i] * in[i] + 1.f
被优化成了一条指令,这样不仅更快,而且精度更高(从网上的资料来看,nvcc 默认会进行不少优化,fma 就是其中一种)。PyTorch 生成梯度求导规则的方法将乘法和加法放在了两个不同的 kernel 里面,导致编译器无法使用 fma 优化。
我尝试把乘法和加法拆开,但是编译器太聪明了,还是使用了 fma 优化:
float square = in[i] * in[i];
din[i] = dout[i] / (square + 1.f);
最后想到的一种方法是先让编译器用常数 0 完成 fma,防止乘法结果和 1.f 融合。
现代 CPU 和 GPU 基本都支持 fma。其实 fma 函数在 C++ 中也是有的,可以参考 https://en.cppreference.com/w/cpp/numeric/math/fma 。还有一篇看上去比较好的文章 https://momentsingraphics.de/FMA.html 。