PyTorch 的 CPU 计算为什么使用 double 作为 32 位浮点数的累加类型?
Tip
本文没有得到最终结论,只是一些个人猜想。
CUDA Kernel 常用 float 类型这件事 这篇笔记提到一个 issue 里面说 ATen 使用 double 作为 32 位浮点数的累加类型是因为用 float 会挂掉一个 batchnorm 的测试。这在我看来不可思议,因为 GPU 上面测试都没有挂,为什么 CPU 上面反而会挂呢?本文记录笔者看 batchnorm 的实现、试图找到原因的过程。
PyTorch 的 batchnorm 要对每一轮的输入计算均值和标准差,其中计算均值就需要将输入都加起来(再除以元素总数),这中间就可能产生累加误差。标准差的计算也类似,会有累加产生的误差。
从 aten/src/ATen/native/native_functions.yaml 中可以找到 native_batch_norm()
函数是在哪里实现的:
- func: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
dispatch:
CPU: batch_norm_cpu
CUDA: batch_norm_cuda
MPS: batch_norm_mps
MkldnnCPU: mkldnn_batch_norm
CUDA 中的累加函数是 batch_norm_collect_statistics_kernel()
。计算方差时用了 Welford 算法,看不懂。Welford 算法是一种只进行一次遍历的在线计算方差的方法,能够有效降低舍入误差。可以参考 https://shuai.guru/welford-variance/ 和 https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance 。
CPU 中的累加函数是 batch_norm_cpu_collect_stats_kernel()
。该函数又分情况调用了很多函数,可以先参考 batch_norm_cpu_collect_stats_contiguous_impl()
(代码如下),可以看出并行是在多通道层面上的,每个通道上是把每个像素的值一个一个加起来的。按照 https://en.wikipedia.org/wiki/Pairwise_summation 的说法,这样的 naive 累加误差会比 pairwise 的算法误差大。而算方差的时候,这段代码又进行了第二次遍历,没有使用一次遍历算法,误差应该是会比 Welford 更小一点?
// batch_norm_cpu_collect_stats_contiguous_impl() 函数中的代码
at::parallel_for(0, n_channel, 1, [&](int64_t begin, int64_t end) {
for (const auto c : c10::irange(begin, end)) {
// compute mean per input
accscalar_t sum = 0;
for (const auto n : c10::irange(n_batch)) {
for (const auto i : c10::irange(image_size)) {
auto offset = n * n_channel * image_size + c * image_size + i;
sum += input_data[offset];
}
}
scalar_t mean = sum / N;
mean_data[c] = mean;
// compute variance per input
accscalar_t _var_sum = 0;
for (const auto n : c10::irange(n_batch)) {
for (const auto i : c10::irange(image_size)) {
auto offset = n * n_channel * image_size + c * image_size + i;
auto x = input_data[offset];
_var_sum += (x - mean) * (x - mean);
}
}
var_sum_data[c] = _var_sum;
}
});
https://github.com/pytorch/pytorch/issues/113414#issuecomment-1809470317 中说 CPU 上 batchnorm 在单精度计算时使用 float 类型而不是 double 类型作为累加类型,会导致一个单元测试失败,是不是因为 CPU 在计算均值和方差的求和过程中没有用 pairwise 或者 kahan 算法?(2024 年 8 月 27 日:可能是相比于 pairwise,按顺序加对 cache 更友好;kahan 会慢。)