PyTorch C++ 函数派发

Stub 注册流程

所有的 stub 定义几乎都在 aten/src/ATen/native/DispatchStub.h 文件,可以慢慢看。里面有段注释:

// Implements instruction set specific function dispatch.
//
// Kernels that may make use of specialized instruction sets (e.g. AVX2) are
// compiled multiple times with different compiler flags (e.g. -mavx2). A
// DispatchStub contains a table of function pointers for a kernel. At runtime,
// the fastest available kernel is chosen based on the features reported by
// cpuinfo.
//
// Example:
//
// In native/MyKernel.h:
//   using fn_type = void(*)(const Tensor& x);
//   DECLARE_DISPATCH(fn_type, stub)
//
// In native/MyKernel.cpp
//   DEFINE_DISPATCH(stub);
//
// In native/cpu/MyKernel.cpp:
//   namespace {
//     // use anonymous namespace so that different cpu versions won't conflict
//     void kernel(const Tensor& x) { ... }
//   }
//   REGISTER_DISPATCH(stub, &kernel);
//
// To call:
//   stub(kCPU, tensor);
//
// TODO: CPU instruction set selection should be folded into whatever
// the main dispatch mechanism is.
//
// Supported device types for registration:
//   - CPU: Central Processing Unit
//   - CUDA: NVIDIA GPUs
//   - HIP: AMD GPUs
//   - MPS: Apple Silicon GPUs (Metal Performance Shaders)
//   - MTIA: Meta Training and Inference Devices
//   - XPU: Intel GPUs
//   - HPU: Reserved for HPU (Intel Gaudi) device types
//   - PrivateUse1: Reserved for private/custom device types
//
// If you want to update the list of supported devices, add a new dispatch_ptr
// member in DispatchStubImpl.h and update the get_call_ptr switch.
// As well you will need to update the inlined list in 'is_device_supported`
//
//
// ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere

DispatchStub 模板基类定义

见 aten/src/ATen/native/DispatchStub.h。DispatchStub 类型为:

template <typename rT, typename T, typename... Args>
struct DispatchStub<rT (*)(Args...), T>;

其中主要包含几类方法,一是调用,会根据设备类型来选择函数指针,强制转换后调用:

	// ...
  template <typename... ArgTypes>
  rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
    FnPtr call_ptr = get_call_ptr(device_type);
    return (*call_ptr)(std::forward<ArgTypes>(args)...);
  }
	// ...

二是注册指针(只挑两个设备来展示,实际上还支持很多设备),后面注册 stub 的时候这些方法会被调用到:

	// ...
  void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
    impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
  void set_mps_dispatch_ptr(FnPtr fn_ptr) {
    impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
	// ...

三是检查设备是否受支持(再深入进去看代码,实际上就是用设备类型来 switch,看对应指针是否注册):

	// ...
  // Returns true if the dispatcher has a kernel registered for this device
  // type.
  bool is_device_supported(const c10::DeviceType device_type) {
    auto result = impl.try_get_call_ptr(device_type
      , reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
      , reinterpret_cast<void*>(VSX)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
      , reinterpret_cast<void*>(ZVECTOR)
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
      , reinterpret_cast<void*>(SVE256)
#endif
      );
    if (std::holds_alternative<ErrorType>(result)){
      return false;
    }
    return true;
  }
  // ...

声明派发 DECLARE_DISPATCH

声明派发需要 stub 的名字和函数指针类型。声明过程包含对 stub 类的定义,以及对 stub 对象的声明。设备 + stub 名形成了一个独一无二的 stub 类。

#define DECLARE_DISPATCH(fn, name)                                                         \
  struct name##_DECLARE_DISPATCH_type : DispatchStub<fn, name##_DECLARE_DISPATCH_type> {   \
    name##_DECLARE_DISPATCH_type() = default;                                              \
    name##_DECLARE_DISPATCH_type(const name##_DECLARE_DISPATCH_type&) = delete;            \
    name##_DECLARE_DISPATCH_type& operator=(const name##_DECLARE_DISPATCH_type&) = delete; \
    name##_DECLARE_DISPATCH_type(name##_DECLARE_DISPATCH_type&&) = delete;                 \
    name##_DECLARE_DISPATCH_type& operator=(name##_DECLARE_DISPATCH_type&&) = delete;      \
    ~name##_DECLARE_DISPATCH_type() = default;                                             \
  };                                                                                       \
  extern TORCH_API struct name##_DECLARE_DISPATCH_type name;

比如 DECLARE_DISPATCH(index_put_fn, index_put_stub) 会定义一个 index_put_stub_DECLARE_DISPATCH_type 类,并且声明一个该类的对象为 index_put_stub。这个类用 CRTP 继承 DispatchStub 类,fn 参数会用于父类的模板参数。

定义派发 DEFINE_DISPATCH

定义派发只需要一个 stub 名字。

#define DEFINE_DISPATCH(name) struct name##_DECLARE_DISPATCH_type name

比如 DEFINE_DISPATCH(index_put_stub); 会展开为:

struct index_put_stub_DECLARE_DISPATCH_type index_put_stub;

注册派发 REGISTER_DISPATCH

注册派发需要 stub 的名字和函数指针,不同设备的 REGISTER_DISPATCH 定义不同。

比如 CUDA 文件中的 REGISTER_DISPATCH(index_put_stub, &index_put_kernel),实际上宏展开后定义了这样的静态成员:

static RegisterCUDADispatch<struct index_put_stub_DECLARE_DISPATCH_type> \
		index_put_stub__register(index_put_stub, &index_put_kernel);

看这个类的定义:

template <typename DispatchStub>
struct RegisterCUDADispatch {
  RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_cuda_dispatch_ptr(value);
  }
};

index_put_stub__register 在静态初始化阶段构造函数会调用 set_cuda_dispatch_ptr,从而将 CUDA 设备的函数指针记录在 stub 中。

Stub 调用

按照 DispatchStub 基类的定义,stub 是函数对象。一种可能的路线是 dispatch_**_stub,在某些源码的位置则是直接调用。调用 stub 需要提供设备的类型作为第一个参数

调用侧,以 add 为例(可以参考 https://blog.ezyang.com/2019/05/pytorch-internals/ ):

  1. dispatch_add
  2. Tensor::add
  3. at::native::add
  4. add_stub

在文件里给了个注释:

// NB: codegenned
DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub)

因为是代码生成,所以我们看不到 add_stub 的定义和实现注册。

也可以 _index_put_impl_ 为例看 stub 调用。native_functions.yaml 中写到 index_put_ 会调用到 _index_put_impl_ 上来,进去检查里面确实有 index_put_stub 调用。那么 index_put_ 是怎么联系到 _index_put_impl_ 的?

代码 中,at::native::index_put_ 调用了 at::_index_put_impl_

Tensor& index_put_(
    Tensor& self,
    const torch::List<std::optional<Tensor>>& indices,
    const Tensor& value,
    const bool accumulate) {
  return at::_index_put_impl_(
      self, indices, value, accumulate, /*unsafe=*/false);
}

因此路径为 at::native::index_put_at::_index_put_impl_index_put_stub

PyTorch 又是怎么把 torch.index_put_torch.Tensor.index_put_ 关联到 at::native::index_put_ 上来的?这就要看下一篇 PyTorch C++ 代码生成 了。