PyTorch C++ 代码生成
一些问题
为什么有些会生成 at::cuda 名字空间的函数,有些不会?(待解决)
提要
本文说明了 m.impl("index_put.out", ...)
到 at::native::index_put
的调用路径。结合
PyTorch C++ 函数派发 中 at::native::index_put_
→ at::_index_put_impl_
→ index_put_stub
的调用路径,补全了从 m.impl 到 stub 的全路径。
本文说明了 m.impl
/ at::cuda::index_out
→ index_stub
的调用路径。准确来说是介绍了 at::cuda::index_out
调用 meta 和 impl 的过程,meta 中对下标做预处理(包括 kBool 转 kLong 下标),impl 中调用 index_stub
进行计算。和 index_out
不同,index_put_
函数没有出现在 at::cuda
名字空间中,取而代之的是 at::cuda::_index_put_impl_
。
在这两个例子中,能找到的函数有:
理解 native_functions.yaml 中的函数定义在哪里
非 structured 情况:
- func: index_put_(...)
dispatch:
CompositeExplicitAutograd: index_put_ # 默认名字空间是 aten
\ at::index_put_ (build/aten/src/ATen/Functions.h)
\ at::_ops::index_put_::call (build/aten/src/ATen/OperatorsEverything.cpp)
\ c10::Dispatcher::singleton()
| .findSchemaOrThrow(index_put_::name, index_put_::overload_name)
| .typed<index_put_::schema>().call
|struct TORCH_API index_put_ (build/aten/src/ATen/MethodOperators.h)
| ... name = "aten::index_put_";
| ... overload_name = "";
\ m.impl("index_put_", ...) (build/aten/src/ATen/RegisterCompositeExplicitAutogradEverything.cpp)
\ at::native::index_put_ (aten/src/ATen/native/TensorAdvancedIndexing.cpp)
\ at::_index_put_impl_
结论是按照 dispatch: yy: xx 字段生成 at::xx 函数,最终调用到 at::native::xx 函数。
^^^^^^^^^^^^^^
如果是 structured: True,就对应 at::native::structured_xx::impl,在代码中通常为
at::native 名字空间下的 TORCH_IMPL_FUNC(xx)。
^^^^^^^^^^^^^^^^^^^
如果是 structured_delegate: zz,可能得去找 zz 的定义。
至于为什么有时候会生成 at::cuda 下的函数,有时候不会,这个不清楚。
生成代码
可以用 python3 -m torchgen.gen
来生成一部分代码,这个过程比较快,也不用准备好全套的构建环境。接下来以 index_put_
为例来看生成结果。
非 structured
算子:以 index_put_
为例
追踪 index_put_
的派发
在 native_functions.yaml 中:
- func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)
device_check: NoCheck # delegate to _index_put_impl_, which leverages TensorIterator
variants: function, method
dispatch:
CompositeExplicitAutograd: index_put_
autogen: index_put.out
# NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp:
# - Tensor & Tensor::index_put_(ArrayRef<TensorIndex> indices, Tensor const & rhs)
# - Tensor & Tensor::index_put_(ArrayRef<TensorIndex> indices, Scalar v)
# - Tensor & Tensor::index_put_(std::initializer_list<TensorIndex> indices, Tensor const & rhs)
# - Tensor & Tensor::index_put_(std::initializer_list<TensorIndex> indices, Scalar v)
首先是会有个 TORCH_LIBRARY_IMPL
,然后里面会有 m.impl("名称")
,可以按照 m.impl
来搜索。
在 build/aten/src/ATen/RegisterCompositeExplicitAutogradEverything.cpp 中(在 torch/library.h 中有 torch::Library
类型,该类型有下面看到的 impl
方法,还有其他地方常看到的 def
方法):
namespace {
at::Tensor wrapper_CompositeExplicitAutograd__index_put(const at::Tensor & self, const c10::List<::std::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate) {
// No device check
// DeviceGuard omitted
return at::native::index_put(self, indices, values, accumulate);
}
} // anonymous namespace
namespace {
at::Tensor & wrapper_CompositeExplicitAutograd_out_index_put_out(const at::Tensor & self, const c10::List<::std::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate, at::Tensor & out) {
// No device check
// DeviceGuard omitted
return at::native::index_put_out(self, indices, values, accumulate, out);
}
} // anonymous namespace
namespace {
at::Tensor & wrapper_CompositeExplicitAutograd__index_put_(at::Tensor & self, const c10::List<::std::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate) {
// No device check
// DeviceGuard omitted
return at::native::index_put_(self, indices, values, accumulate);
}
} // anonymous namespace
TORCH_LIBRARY_IMPL(aten, CompositeExplicitAutograd, m) {
m.impl("index_put",
TORCH_FN(wrapper_CompositeExplicitAutograd__index_put));
m.impl("index_put.out",
TORCH_FN(wrapper_CompositeExplicitAutograd_out_index_put_out));
m.impl("index_put_",
TORCH_FN(wrapper_CompositeExplicitAutograd__index_put_));
}
从这里我们发现一个规律:很多在 native_functions.yaml 中注册的函数会转而调用 at::native::
中的函数,这些函数是在 aten/src/ATen/native/ 文件夹下有定义的,比如 at::native::index_put
和 at::native::index_put_
。也有一些函数比如 at::native::index_put_out
是用代码生成的,如 native_functions.yaml 中 index_put_
函数的 autogen 字段所述。
我们先记住 m.impl
能记录函数派发,暂时不展开。
追踪 index_put.out
的 autogen
前面我们已经看到了 index_put
对应 at::native::index_put
,index_put_
对应 at::native::index_put_
,而且都能在 aten/src/ATen/native/TensorAdvancedIndexing.cpp 中找到定义,但是 at::native::index_put_out
是找不到定义的。这一个小节用来阅读 autogen
字段请求生成的代码。
代码生成后,在 build/aten/src/ATen/CompositeViewCopyKernels.cpp 中(和 autogen: index_put.out
对应):
// namespace at::native
at::Tensor & index_put_out(const at::Tensor & self, const c10::List<::std::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate, at::Tensor & out) {
auto tmp_output = at::_ops::index_put::call(self, indices, values, accumulate);
resize_out_helper(out, tmp_output);
copy_arg(out, tmp_output);
return out;
}
在 build/aten/src/ATen/OperatorsEverything.cpp 中:
// namespace at::_ops
// aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
static C10_NOINLINE c10::TypedOperatorHandle<index_put::schema> create_index_put_typed_handle() {
return c10::Dispatcher::singleton()
.findSchemaOrThrow(index_put::name, index_put::overload_name)
.typed<index_put::schema>();
}
// ✅
// aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
at::Tensor index_put::call(const at::Tensor & self, const c10::List<::std::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate) {
static auto op = create_index_put_typed_handle();
return op.call(self, indices, values, accumulate);
}
// aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
at::Tensor index_put::redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate) {
static auto op = create_index_put_typed_handle();
return op.redispatch(dispatchKeySet, self, indices, values, accumulate);
}
这个 findSchemaOrThrow
函数在 aten/src/ATen/core/dispatch/Dispatcher.cpp 中定义。aten/src/ATen/core/dispatch/Dispatcher.h 则有对函数接口较为详细的说明。具体来说会从自己的 table 来查找字符串名,看看有没有结果,具体来说是 operatorLookupTable_.read
。
这里还用到了 index_put
,可以搜索 struct TORCH_API index_put
:
定义为:
struct TORCH_API index_put {
using schema = at::Tensor (const at::Tensor &, const c10::List<::std::optional<at::Tensor>> &, const at::Tensor &, bool);
using ptr_schema = schema*;
// See Note [static constexpr char* members for windows NVCC]
static constexpr const char* name = "aten::index_put";
static constexpr const char* overload_name = "";
static constexpr const char* schema_str = "index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor";
static at::Tensor call(const at::Tensor & self, const c10::List<::std::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate);
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate);
};
总体来说调用链是:
m.impl("index_put.out", ...) // build/aten/src/ATen/RegisterCompositeExplicitAutogradEverything.cpp
at::native::index_put_out
at::_ops::index_put::call
...
所以说 at::native::index_put_out
实际上会派发给 at::native::index_put
。
是谁把算子写进查找表的呢?
相对地,Dispatcher::findOrRegisterName_
会调用 operatorLookupTable_.write
,这样就能将算子注册进去。这个函数被 register{Name,Impl,Def}
调用,可以按照扩展正则表达式搜索 \.register(Name|Impl|Def)\(
。然后发现 RegisterOperators::registerOp_
在调用 registerDef
和 registerImpl
。
链路:
explicit RegisterOperators(const std::string&, FuncType&&, Options&&) // (可选)
c10::RegisterOperators::op
c10::RegisterOperators::checkSchemaAndRegisterOp_
c10::RegisterOperators::registerOp_ // 会获取 Dispatcher 的单例,往里面加东西
c10::Dispatcher::singleton().register{Name,Def,Impl}
c10::Dispatcher::findOrRegisterName_
c10::Dispatcher::operatorLookupTable_.write
RegisterOperators
似乎是个临时类型,拿来注册用的。
还有一条是 Library
的路径:
torch::Library::_{impl,def}
c10::RegisterOperators::registerOp_ // 会获取 Dispatcher 的单例,往里面加东西
c10::Dispatcher::singleton().register{Name,Def,Impl}
c10::Dispatcher::findOrRegisterName_
c10::Dispatcher::operatorLookupTable_.write
aten/src/ATen/core/op_registration/README.md 有相关的说明,可以看看理清思路。
回忆一下,之前已经有 m.impl("index_put", ...)
了,它转而调用了 at::native::index_put
,因此这个表项已经被注册过了。现在可以把调用链补充完整:
// build/aten/src/ATen/RegisterCompositeExplicitAutogradEverything.cpp
m.impl("index_put.out", ...)
at::native::index_put_out // build/aten/src/ATen/CompositeViewCopyKernels.cpp
at::_ops::index_put::call // build/aten/src/ATen/OperatorsEverything.cpp
|
| find op from dispatcher
|
// build/aten/src/ATen/RegisterCompositeExplicitAutogradEverything.cpp
m.impl("index_put", ...)
// build/aten/src/ATen/RegisterCompositeExplicitAutogradEverything.cpp
TORCH_FN(wrapper_CompositeExplicitAutograd__index_put)
// aten/src/ATen/native/TensorAdvancedIndexing.cpp
at::native::index_put
通过再次派发,autogen 生成的函数能调用原函数,从而减少了编写接口的负担。
structured
算子:以 index_out
为例
at::cuda::index_out
这个函数被 at::native::masked_select_out_cuda_impl
调用了,我之前也一直疑惑 masked_select
是在哪里调用 nonzero
的,现在就通过这个例子来查找调用路径。我们先来看 at::cuda::index_out
函数的实现,接着看 m.impl
是怎么实现的。
在 native_functions.yaml 中找到的最相关的记录为:
- func: index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
structured: True
structured_inherits: TensorIteratorBase
precomputed:
- indices -> DimVector sizes, DimVector strides
dispatch:
CPU, CUDA, MPS: index_out
注意这个 structured
标记。很多算子在 native_functions.yaml 中都会有 structured
标记,或者 structured_delegate
标记。在代码生成之后,相关声明也会生成在 build/aten/src/ATen/NativeFunctions.h 文件中。按我理解,structured
把算子的运算流程分离成几步,允许开发者对每个方法分别定义。
生成的代码
在 build/aten/src/ATen/RegisterCUDAEverything.cpp 中:
namespace at {
at::Tensor & wrapper_CUDA_index_out_Tensor_out(const at::Tensor & self, const c10::List<::std::optional<at::Tensor>> & indices, at::Tensor & out) {
// No device check
structured_index_out_out op(out);
auto precompute = op.meta(self, at::IOptTensorListRef(indices));
(void)precompute;
op.impl(self, precompute.sizes, precompute.strides, op.maybe_get_output(0));
if (op.proxy_outputs_[0].has_value()) op.outputs_[0].get().copy_(*op.proxy_outputs_[0]);
return out;
}
namespace cuda {
at::Tensor & index_out(at::Tensor & out, const at::Tensor & self, const c10::List<::std::optional<at::Tensor>> & indices) {
return wrapper_CUDA_index_out_Tensor_out(self, indices, out);
}
}
}
最主要的是 structured_index_out_out
这个类,还有 meta
和 impl
两个方法。
impl
先看定义(impl
)。在 aten/src/ATen/native/TensorAdvancedIndexing.cpp 中:
TORCH_IMPL_FUNC(index_out)
(const Tensor& self, DimVector sizes, DimVector strides, const Tensor& result) {
index_stub(device_type(), *this, sizes, strides);
}
而 TORCH_IMPL_FUNC
定义如下:
#define TORCH_IMPL_FUNC(name) void structured_##name::impl
因此会展开成这样(因为是宏展开,代码里面搜不到):
void structured_index_out::impl
(const Tensor& self, DimVector sizes, DimVector strides, const Tensor& result) {
index_stub(device_type(), *this, sizes, strides);
}
对应的声明是代码生成的,在 build/aten/src/ATen/NativeFunctions.h 中:
struct TORCH_API structured_index_out : public at::meta::structured_index_Tensor {
void impl(const at::Tensor & self, at::DimVector sizes, at::DimVector strides, const at::Tensor & out);
};
meta
接着看预处理(meta
)。在 aten/src/ATen/TensorMeta.h 中:
// ...
#define TORCH_PRECOMPUTE_META_FUNC(name) \
structured_##name::meta_return_ty structured_##name::meta
#define TORCH_PRECOMPUTE_META_FUNC2(name, overload) \
structured_##name##_##overload::meta_return_ty \
structured_##name##_##overload::meta
// Use this to create a precompute struct in a meta function.
#define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<>
#define TORCH_PRECOMPUTE_STRUCT2(name, overload) \
structured_##name##_##overload::precompute_out<>
在 aten/src/ATen/native/TensorAdvancedIndexing.cpp 中:
TORCH_PRECOMPUTE_META_FUNC2(index, Tensor)
(const Tensor& self, at::IOptTensorListRef indices) {
// ...
auto info = at::native::make_info(self, std::move(indices));
build_index_op(*this, info, result);
return TORCH_PRECOMPUTE_STRUCT2(index, Tensor)()
.set_sizes(std::move(info.indexed_sizes))
.set_strides(std::move(info.indexed_strides));
}
因此 structured_index_Tensor
的 meta
函数能对 indices 进行预计算,在 make_info
中将 BoolTensor 转换成 LongTensor(代码略),并获取 sizes、strides 信息。
放在一起看
我们找到的 impl 是 structured_index_out::impl
,找到的 meta 是 structured_index_Tensor::meta
,还要将他们和 structured_index_out_out
联系起来。
在 build/aten/src/ATen/RegisterCUDAEverything.cpp 中:
struct structured_index_out_out final : public at::native::structured_index_out
因为继承关系,structured_index_out_out
自然得到了 structured_index_out::impl
。
在 build/aten/src/ATen/NativeFunctions.h 中:
struct TORCH_API structured_index_out : public at::meta::structured_index_Tensor
因为继承关系,structured_index_out
又得到了 structured_index_Tensor::meta
。
继承链为:
- at::TensorIteratorBase
- at::meta::structured_index_Tensor (provides meta)
- at::native::structured_index_out (provides impl)
- at::native::structured_index_out_out
所以 masked_select 调用了 index_out,后者按顺序调用 structured_index_out_out 的 meta 和 impl 方法。meta 方法来自 at::meta::structured_index_Tensor::meta,内部会调用 expandTensors 函数将 mask 转换成 LongTensor 下标(代码略)。impl 方法来自 at::native::structured_index_out::impl,会调用 index_stub。
再来看 m.impl
(at
名字空间,而非 at::cuda
)
也是结构化地先 meta 后 impl 的实现方式,和非 structured 的算子生成方式不同。这里在 wrapper 中没有调用 at::native 名字空间函数(at::native::index_out
),所以不需要为其提供定义,代码其他地方也找不到。回忆前面,index_out 的定义实际上是通过 TORCH_IMPL_FUNC(index_out)
提供的,即 at::native::structured_index_out::impl
,并非 at::native::index_out
函数。
namespace at {
at::Tensor & wrapper_CUDA_index_out_Tensor_out(const at::Tensor & self, const c10::List<::std::optional<at::Tensor>> & indices, at::Tensor & out) {
// No device check
structured_index_out_out op(out);
auto precompute = op.meta(self, at::IOptTensorListRef(indices));
(void)precompute;
op.impl(self, precompute.sizes, precompute.strides, op.maybe_get_output(0));
if (op.proxy_outputs_[0].has_value()) op.outputs_[0].get().copy_(*op.proxy_outputs_[0]);
return out;
}
TORCH_LIBRARY_IMPL(aten, CUDA, m) {
m.impl("index.Tensor", TORCH_FN(wrapper_CUDA_index_Tensor));
m.impl("index.Tensor_out", TORCH_FN(wrapper_CUDA_index_out_Tensor_out));
}
}
注意到这个 at::wrapper_CUDA_index_out_Tensor_out
在前面也用到了!实际上就是相同的实现,在 m.impl
用到了,在 at::cuda::index_out
中也用到了。