PyTorch C++ 函数派发
Stub 注册流程
所有的 stub 定义几乎都在 aten/src/ATen/native/DispatchStub.h 文件,可以慢慢看。里面有段注释:
// Implements instruction set specific function dispatch.
//
// Kernels that may make use of specialized instruction sets (e.g. AVX2) are
// compiled multiple times with different compiler flags (e.g. -mavx2). A
// DispatchStub contains a table of function pointers for a kernel. At runtime,
// the fastest available kernel is chosen based on the features reported by
// cpuinfo.
//
// Example:
//
// In native/MyKernel.h:
// using fn_type = void(*)(const Tensor& x);
// DECLARE_DISPATCH(fn_type, stub)
//
// In native/MyKernel.cpp
// DEFINE_DISPATCH(stub);
//
// In native/cpu/MyKernel.cpp:
// namespace {
// // use anonymous namespace so that different cpu versions won't conflict
// void kernel(const Tensor& x) { ... }
// }
// REGISTER_DISPATCH(stub, &kernel);
//
// To call:
// stub(kCPU, tensor);
//
// TODO: CPU instruction set selection should be folded into whatever
// the main dispatch mechanism is.
//
// Supported device types for registration:
// - CPU: Central Processing Unit
// - CUDA: NVIDIA GPUs
// - HIP: AMD GPUs
// - MPS: Apple Silicon GPUs (Metal Performance Shaders)
// - MTIA: Meta Training and Inference Devices
// - XPU: Intel GPUs
// - HPU: Reserved for HPU (Intel Gaudi) device types
// - PrivateUse1: Reserved for private/custom device types
//
// If you want to update the list of supported devices, add a new dispatch_ptr
// member in DispatchStubImpl.h and update the get_call_ptr switch.
// As well you will need to update the inlined list in 'is_device_supported`
//
//
// ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
DispatchStub
模板基类定义
见 aten/src/ATen/native/DispatchStub.h。DispatchStub
类型为:
template <typename rT, typename T, typename... Args>
struct DispatchStub<rT (*)(Args...), T>;
其中主要包含几类方法,一是调用,会根据设备类型来选择函数指针,强制转换后调用:
// ...
template <typename... ArgTypes>
rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
FnPtr call_ptr = get_call_ptr(device_type);
return (*call_ptr)(std::forward<ArgTypes>(args)...);
}
// ...
二是注册指针(只挑两个设备来展示,实际上还支持很多设备),后面注册 stub 的时候这些方法会被调用到:
// ...
void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
}
void set_mps_dispatch_ptr(FnPtr fn_ptr) {
impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
}
// ...
三是检查设备是否受支持(再深入进去看代码,实际上就是用设备类型来 switch,看对应指针是否注册):
// ...
// Returns true if the dispatcher has a kernel registered for this device
// type.
bool is_device_supported(const c10::DeviceType device_type) {
auto result = impl.try_get_call_ptr(device_type
, reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
, reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
, reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
, reinterpret_cast<void*>(VSX)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, reinterpret_cast<void*>(ZVECTOR)
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, reinterpret_cast<void*>(SVE256)
#endif
);
if (std::holds_alternative<ErrorType>(result)){
return false;
}
return true;
}
// ...
声明派发 DECLARE_DISPATCH
声明派发需要 stub 的名字和函数指针类型。声明过程包含对 stub 类的定义,以及对 stub 对象的声明。设备 + stub 名形成了一个独一无二的 stub 类。
#define DECLARE_DISPATCH(fn, name) \
struct name##_DECLARE_DISPATCH_type : DispatchStub<fn, name##_DECLARE_DISPATCH_type> { \
name##_DECLARE_DISPATCH_type() = default; \
name##_DECLARE_DISPATCH_type(const name##_DECLARE_DISPATCH_type&) = delete; \
name##_DECLARE_DISPATCH_type& operator=(const name##_DECLARE_DISPATCH_type&) = delete; \
name##_DECLARE_DISPATCH_type(name##_DECLARE_DISPATCH_type&&) = delete; \
name##_DECLARE_DISPATCH_type& operator=(name##_DECLARE_DISPATCH_type&&) = delete; \
~name##_DECLARE_DISPATCH_type() = default; \
}; \
extern TORCH_API struct name##_DECLARE_DISPATCH_type name;
比如 DECLARE_DISPATCH(index_put_fn, index_put_stub)
会定义一个 index_put_stub_DECLARE_DISPATCH_type
类,并且声明一个该类的对象为 index_put_stub
。这个类用 CRTP 继承 DispatchStub
类,fn
参数会用于父类的模板参数。
定义派发 DEFINE_DISPATCH
定义派发只需要一个 stub 名字。
#define DEFINE_DISPATCH(name) struct name##_DECLARE_DISPATCH_type name
比如 DEFINE_DISPATCH(index_put_stub);
会展开为:
struct index_put_stub_DECLARE_DISPATCH_type index_put_stub;
注册派发 REGISTER_DISPATCH
注册派发需要 stub 的名字和函数指针,不同设备的 REGISTER_DISPATCH
定义不同。
比如 CUDA 文件中的 REGISTER_DISPATCH(index_put_stub, &index_put_kernel)
,实际上宏展开后定义了这样的静态成员:
static RegisterCUDADispatch<struct index_put_stub_DECLARE_DISPATCH_type> \
index_put_stub__register(index_put_stub, &index_put_kernel);
看这个类的定义:
template <typename DispatchStub>
struct RegisterCUDADispatch {
RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
stub.set_cuda_dispatch_ptr(value);
}
};
index_put_stub__register
在静态初始化阶段构造函数会调用 set_cuda_dispatch_ptr
,从而将 CUDA 设备的函数指针记录在 stub 中。
Stub 调用
按照 DispatchStub
基类的定义,stub 是函数对象。一种可能的路线是 dispatch_*
→ *_stub
,在某些源码的位置则是直接调用。调用 stub 需要提供设备的类型作为第一个参数。
调用侧,以 add 为例(可以参考 https://blog.ezyang.com/2019/05/pytorch-internals/ ):
dispatch_add
Tensor::add
at::native::add
add_stub
在文件里给了个注释:
// NB: codegenned
DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub)
因为是代码生成,所以我们看不到 add_stub
的定义和实现注册。
也可以 _index_put_impl_
为例看 stub 调用。native_functions.yaml 中写到 index_put_
会调用到 _index_put_impl_
上来,进去检查里面确实有 index_put_stub
调用。那么 index_put_
是怎么联系到 _index_put_impl_
的?
在 代码 中,at::native::index_put_
调用了 at::_index_put_impl_
。
Tensor& index_put_(
Tensor& self,
const torch::List<std::optional<Tensor>>& indices,
const Tensor& value,
const bool accumulate) {
return at::_index_put_impl_(
self, indices, value, accumulate, /*unsafe=*/false);
}
因此路径为 at::native::index_put_
→ at::_index_put_impl_
→ index_put_stub
。
PyTorch 又是怎么把 torch.index_put_
和 torch.Tensor.index_put_
关联到 at::native::index_put_
上来的?这就要看下一篇
PyTorch C++ 代码生成 了。