如何对齐 PyTorch 的除法运算?

过程

我遇到的情况是:a 为 64 位浮点数(FP64)标量,b 为 32 位浮点数(FP32)张量,要计算 a / b

一种做法是:使用 1 / b * a 来代替 a / b。这样的结果看起来和 PyTorch 的计算是对齐的。

奇怪的是,在 aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu 这个代码的判断条件中,只有分母是 CPU 上的标量时,才会将除法转换成乘法运算,而我的遇到的情况是分子是标量,分母是张量,不符合这个条件。代码:

namespace at::native {
namespace binary_internal {

CONSTEXPR_EXCEPT_WIN_CUDA char div_name[] = "div_kernel";
void div_true_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (iter.common_dtype() == kComplexHalf) {
    // 省略
  }
  if (iter.is_cpu_scalar(2)) {
    // optimization for floating-point types: if the second operand is a CPU
    // scalar, compute a * reciprocal(b). Note that this may lose one bit of
    // precision compared to computing the division.
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
        kHalf, kBFloat16, common_dtype, "div_true_cuda", [&]() {
          using opmath_t = at::opmath_type<scalar_t>;
          auto inv_b = opmath_t(1.0) / iter.scalar_value<opmath_t>(2);
          iter.remove_operand(2);
          gpu_kernel(
              iter,
              BUnaryFunctor<scalar_t, scalar_t, scalar_t, MulFunctor<opmath_t>>(
                  MulFunctor<opmath_t>(), inv_b));
        });
  } else {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
        kHalf, kBFloat16, common_dtype, "div_true_cuda", [&]() {
          DivFunctor<scalar_t> f;
          gpu_kernel_with_scalars(iter, f);
        });
  }
}
} // namespace binary_internal

REGISTER_DISPATCH(div_true_stub, &binary_internal::div_true_kernel_cuda);

} // namespace at::native

接下来是做了一些实验。(注意:因为有随机数,下面的结果每次都可能不一样,但是大致表现是相同的。一开始记录结果的时候忘记了固定随机数种子。)

我尝试检查 PyTorch 是否也会对分母不是 CPU 标量的情况进行除法到乘法的优化,结果看上去是不会,因为 torch 的除法结果和直白的除法计算完全是一样的。

import torch
import numpy as np

for x in torch.randn([1024 * 1024]).tolist():
    a = torch.tensor(1.1).cuda()
    b = torch.tensor([x, 0.1]).cuda()
    c = (a / b)[0].cpu().numpy().view(np.int32)
    d = (np.float32(1.1) / np.float32(x)).view(np.int32)
    assert c == d, f'note: {x=}, {c=}, {d=}'

用 numpy 做实验,结果很符合直觉,32 位除法确实比 32 位乘法更贴近 64 位除法的计算结果。

import numpy as np

a_is_better = 0
b_is_better = 0
for x in np.random.random(1024).tolist():
    x = np.float32(x)
    a = np.float32(1.1) / x
    b = np.float32(1.1) * (np.float32(1) / x)
    c = np.float64(1.1) / x
    d1 = np.abs(a-c)
    d2 = np.abs(b-c)
    if d1 < d2:
        a_is_better += 1
    elif d2 < d1:
        b_is_better += 1
print(f'{a_is_better=}, {b_is_better=}')

# a_is_better=163, b_is_better=98

如果把 c = np.float64(1.1) / x 改成 c = 1.1 / x,那么结果则几乎完全倾向于 “a is better”,而不是像现在这样(a 好和 b 好的情况都有,但 a 好的情况更多)。

 a_is_better = 0
 b_is_better = 0
 for x in np.random.random(1024).tolist():
     x = np.float32(x)
     a = np.float32(1.1) / x
     b = np.float32(1.1) * (np.float32(1) / x)
-    c = np.float64(1.1) / x
+    c = 1.1 / x
     d1 = np.abs(a-c)
     d2 = np.abs(b-c)
     if d1 < d2:
         a_is_better += 1
     elif d2 < d1:
         b_is_better += 1
 print(f'{a_is_better=}, {b_is_better=}')

-# a_is_better=163, b_is_better=98
+# a_is_better=262, b_is_better=0

用 torch 做实验,“分子 1.1 是 Python 内置浮点数还是 numpy/torch 的 64 位浮点数”对结果的影响不同。换成 Python 内置浮点数之后,结果变成了“b is better”。这证明 torch 和 numpy 都对 Python 内置浮点数这样的情况有特殊计算路径

 import torch
 import numpy as np

 a_is_better = 0
 b_is_better = 0
 for x in np.random.random(1024).tolist():
     x = torch.tensor(x).cuda()
     a = torch.tensor(1.1) / x
     b = torch.tensor(1.1) * (torch.tensor(1) / x)
-    c = torch.tensor(1.1, dtype=torch.float64) / x
+    c = 1.1 / x
     d1 = torch.abs(a-c).item()
     d2 = torch.abs(b-c).item()
     if d1 < d2:
         a_is_better += 1
     elif d2 < d1:
         b_is_better += 1
 print(f'{a_is_better=}, {b_is_better=}')

-# a_is_better=151, b_is_better=101
+# a_is_better=0, b_is_better=278

Note

这里使用的分子 1.1 是在 CPU 上,分母 x 在 GPU 上,仍然能进行计算,说明 PyTorch 对于标量做了特殊处理,不要求和其他张量在同一个设备上。我将所有分子的末尾都加上 .cuda(),将其转移到 cuda:0 设备上,但是这并不影响结果。

经过检查,1.1 / x/ 操作符是写在 float 类型的 __truediv__ 方法中的(1 / x/ 则写在 int 类型的 __truediv__ 中)。这导致了被除数为 Python 内置类型时,代码执行的路径不同。如果被除数是 torch 的张量,那么使用的应该是 TensorBase(源代码注释:Defined in torch/csrc/autograd/python_variable.cpp)中的 __truediv__

既然被除数是 Python 内置类型时,调用的方法不同,那么 torch 和 numpy 又是怎么向 Python 注册这些方法的?intfloat 是 Python 内置类型(因此不可能知道 torch.Tensor 的存在) ,在遇到 __truediv__ 时会认为其参数是未知类型,进而尝试在其参数上调用 __rtruediv__ 方法(如果这个方法有实现)。

Note

这里 __rtruediv__r 字母不是 right 的意思,而是 reflected 的意思。

在文件 torch/_tensor.py 中,有对这个方法的定义:

class Tensor(torch._C.TensorBase):
    @_handle_torch_function_and_wrap_type_error_to_not_implemented
    def __rdiv__(self, other):
        return self.reciprocal() * other

    __rtruediv__ = __rdiv__
    __itruediv__ = _C.TensorBase.__idiv__

实际上就是把除法转换为了乘法来算!

再看 numpy:

# numpy/lib/mixins.py
def _reflected_binary_method(ufunc, name):
    """Implement a reflected binary method with a ufunc, e.g., __radd__."""
    def func(self, other):
        if _disables_array_ufunc(other):
            return NotImplemented
        return ufunc(other, self)
    func.__name__ = '__r{}__'.format(name)
    return func

def _numeric_methods(ufunc, name):
    """Implement forward, reflected and inplace binary methods with a ufunc."""
    return (_binary_method(ufunc, name),
            _reflected_binary_method(ufunc, name),
            _inplace_binary_method(ufunc, name))

class NDArrayOperatorsMixin:
    # ...
    __truediv__, __rtruediv__, __itruediv__ = _numeric_methods(
        um.true_divide, 'truediv')
    # ...

其实就是把 um.true_divide 这个 ufunc 的两个参数交换顺序。也就是说 1.1 / x 在 x 为 numpy.ndarray 时,会先匹配到 float.__truediv__(self, numpy.ndarray),再匹配到 numpy.ndarray.__rtruediv__(self, f),然后匹配到 um.true_divide(f, self)。剩下的细节都是 true_divide 函数在处理。

从直觉上来讲,numpy 的处理方式更容易理解一点。PyTorch 这种转除法为乘法的计算方式应该有其用意。

结论

# x: torch.Tensor
1.1 / x => x.__rtruediv__(1.1) => x.reciprocal() * 1.1

# y: numpy.ndarray
1.1 / y => y.__rtruediv__(1.1) => np._core.umath.true_divide(1.1, y)