This article digs into a fair amount of PyTorch's C++ source code (version 1.4.0a) and is aimed at readers who already have some grounding in the PyTorch codebase; it also touches on the basics of Python's C/C++ extension mechanism. The first line of each code snippet states the file it comes from. Note that some of the files shown are auto-generated: they do not exist in the source tree as checked out and only appear after a build.
Also keep in mind that PyTorch is still under active development, so its interfaces change fairly often. By the time you read this, the code shown may differ slightly from the master branch, but most of the logic is unchanged; what matters is understanding the core mechanism.
Let's get started!
We have a tensor. No, two: create two random tensors and add them together.
import torch

res = torch.rand(3, 4)[0] + torch.rand(3, 4)
Running this prints:
tensor([[0.3091, 0.5503, 1.0780, 0.9044],
        [0.5770, 0.5245, 0.3225, 1.4672],
        [0.1581, 1.0439, 0.3313, 0.9924]])
The output itself doesn't matter. Let's first break the line above down into smaller steps:
_t1 = torch.rand(3, 4)
_t2 = _t1.__getitem__(0)
del _t1
_t3 = torch.rand(3, 4)
res = _t2.__add__(_t3)
del _t2
del _t3
# only res remains at the end
Let's look at what happens in the first statement:
_t1 = torch.rand(3, 4)   # <--
_t2 = _t1.__getitem__(0)
del _t1
_t3 = torch.rand(3, 4)
res = _t2.__add__(_t3)
del _t2
del _t3
torch.rand actually lives in the torch._C._VariableFunctions module. It is not a Python function, just the name of a method inside that module: calling torch.rand invokes the rand method exposed on the torch module, and that module is produced through Python's C/C++ extension mechanism. In practice, the code behind torch.rand is auto-generated from a YAML file.
That YAML file is essentially a list of function signatures used to drive code generation. Many files in the PyTorch source tree are generated automatically by gen.py; the reason is that many of these functions look very much alike, so generating them avoids most of the repetitive work.
// aten/src/ATen/native/native_functions.yaml
- func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
- func: rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
- func: rand(int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
- func: rand(int[] size, *, Tensor(a!) out) -> Tensor(a!)
- func: rand(int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
- func: rand_like(Tensor self) -> Tensor
- func: rand_like(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
From the declarations above, the code generator emits the entry for rand (and for the other functions) at the ${py_method_defs} placeholder in the following template:
// tools/autograd/templates/python_torch_functions.cpp
static PyMethodDef torch_functions[] = {
  {"arange", (PyCFunction)THPVariable_arange, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"as_tensor", (PyCFunction)THPVariable_as_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"dsmm", (PyCFunction)THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"from_numpy", (PyCFunction)THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
  {"hsmm", (PyCFunction)THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"_promote_types", (PyCFunction)THPVariable__promote_types, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"nonzero", (PyCFunction)THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"randint", (PyCFunction)THPVariable_randint, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"range", (PyCFunction)THPVariable_range, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"saddmm", (PyCFunction)THPVariable_sspaddmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"sparse_coo_tensor", (PyCFunction)THPVariable_sparse_coo_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"spmm", (PyCFunction)THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"tensor", (PyCFunction)THPVariable_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"get_device", (PyCFunction)THPVariable_get_device, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  ${py_method_defs}
  {NULL}
};
The generator takes the signatures from native_functions.yaml, expands them at the ${py_method_defs} placeholder shown above, and writes out a new file. In that generated file we can find our "rand":
//torch/csrc/autograd/generated/python_torch_functions.cpp
static PyMethodDef torch_functions[] = {
  {"arange", (PyCFunction)(void(*)(void))THPVariable_arange, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"as_tensor", (PyCFunction)(void(*)(void))THPVariable_as_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"dsmm", (PyCFunction)(void(*)(void))THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"from_numpy", (PyCFunction)THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
  {"hsmm", (PyCFunction)(void(*)(void))THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"nonzero", (PyCFunction)(void(*)(void))THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"randint", (PyCFunction)(void(*)(void))THPVariable_randint, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"range", (PyCFunction)(void(*)(void))THPVariable_range, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"saddmm", (PyCFunction)(void(*)(void))THPVariable_sspaddmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"sparse_coo_tensor", (PyCFunction)(void(*)(void))THPVariable_sparse_coo_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"spmm", (PyCFunction)(void(*)(void))THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"tensor", (PyCFunction)(void(*)(void))THPVariable_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"get_device", (PyCFunction)(void(*)(void))THPVariable_get_device, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  // everything above is identical to the template; the entries below are auto-generated
  {"numel", (PyCFunction)(void(*)(void))THPVariable_numel, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"__and__", (PyCFunction)(void(*)(void))THPVariable___and__, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  ...
  {"quantized_rnn_tanh_cell", (PyCFunction)(void(*)(void))THPVariable_quantized_rnn_tanh_cell, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"rand", (PyCFunction)(void(*)(void))THPVariable_rand, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"rand_like", (PyCFunction)(void(*)(void))THPVariable_rand_like, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"randint_like", (PyCFunction)(void(*)(void))THPVariable_randint_like, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"randn", (PyCFunction)(void(*)(void))THPVariable_randn, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"randn_like", (PyCFunction)(void(*)(void))THPVariable_randn_like, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"randperm", (PyCFunction)(void(*)(void))THPVariable_randperm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  ...
  {"zeros", (PyCFunction)(void(*)(void))THPVariable_zeros, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"zeros_like", (PyCFunction)(void(*)(void))THPVariable_zeros_like, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {NULL}
};
From the code above we can see that "rand" is bound to the function THPVariable_rand. Before we dig into that function, we need to look at the initialization step, because for these functions to be callable from Python, the method table above (tp_methods) has to be attached to a type object (PyTypeObject):
// tools/autograd/templates/python_torch_functions.cpp
static PyTypeObject THPVariableFunctions = {
  PyVarObject_HEAD_INIT(NULL, 0)
  "torch._C._VariableFunctions",   /* tp_name */
  Py_TPFLAGS_DEFAULT,              /* tp_flags */
  NULL,                            /* tp_doc */
  torch_functions,                 /* tp_methods */
  ...
};
Then, during initialization, that type object is exposed on a Python module:
void initTorchFunctions(PyObject* module) {
  if (PyType_Ready(&THPVariableFunctions) < 0) {
    throw python_error();
  }
  Py_INCREF(&THPVariableFunctions);
  if (PyModule_AddObject(module, "_VariableFunctions", (PyObject*)&THPVariableFunctions) < 0) {
    throw python_error();
  }
}
As a result, when we make the call on the Python side, the method is looked up in the generated torch._C._VariableFunctions:
for name in dir(_C._VariableFunctions):
    if name.startswith('__'):
        continue
    globals()[name] = getattr(_C._VariableFunctions, name)
Good. Now let's look at the function behind the entry {"rand", (PyCFunction)(void(*)(void))THPVariable_rand, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}.
// torch/csrc/autograd/generated/python_torch_functions.cpp
static PyObject * THPVariable_rand(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "rand(IntArrayRef size, *, DimnameList? names, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
    "rand(IntArrayRef size, *, Generator generator, DimnameList? names, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
    "rand(IntArrayRef size, *, Generator generator, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
    "rand(IntArrayRef size, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
  }, /*traceable=*/true);

  ParsedArgs<9> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    auto size = r.intlist(0);
    auto __names = r.toDimnameListOptional(1);
    c10::optional<DimnameList> names = __names ? c10::make_optional(DimnameList(__names.value())) : c10::nullopt;
    auto dtype = r.scalartype(2);
    auto device = r.device(4);
    const auto options = TensorOptions()
        .dtype(dtype)
        .device(device)
        .layout(r.layout(3).layout)
        .requires_grad(r.toBool(6))
        .pinned_memory(r.toBool(5));
    return wrap(dispatch_rand(size, names, options));
  ...
  } else if (r.idx == 3) { // this is the branch that finally runs
    if (r.isNone(1)) {
      auto size = r.intlist(0);
      auto dtype = r.scalartype(2);
      auto device = r.device(4);
      const auto options = TensorOptions()
          .dtype(dtype)
          .device(device)
          .layout(r.layout(3).layout)
          .requires_grad(r.toBool(6))
          .pinned_memory(r.toBool(5));
      return wrap(dispatch_rand(size, options));
    } else {
      check_out_type_matches(r.tensor(1), r.scalartype(2), r.isNone(2),
                             r.layout(3), r.isNone(3), r.device(4), r.isNone(4));
      return wrap(dispatch_rand(r.intlist(0), r.tensor(1)).set_requires_grad(r.toBool(6)));
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
As we can see, what this function ultimately executes is dispatch_rand. Note that dispatch_rand releases the GIL, so the C++ code running here and the Python code running in the interpreter no longer block each other:
// torch/csrc/autograd/generated/python_torch_functions_dispatch.h
inline Tensor dispatch_rand(IntArrayRef size, const TensorOptions & options) {
  maybe_initialize_cuda(options);
  AutoNoGIL no_gil;  /* release the GIL */
  return torch::rand(size, options);
}
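AutoNoGIL is the RAII guard PyTorch uses for this. As an illustration of the general pattern only (not PyTorch's actual implementation), such a guard can be built on CPython's PyEval_SaveThread / PyEval_RestoreThread pair:

#include <Python.h>

// Hypothetical RAII guard: releases the GIL in the constructor and
// re-acquires it in the destructor, which is what AutoNoGIL does for the
// duration of the C++ computation.
struct NoGILGuard {
  NoGILGuard() : state_(PyEval_SaveThread()) {}    // release the GIL
  ~NoGILGuard() { PyEval_RestoreThread(state_); }  // take it back
  PyThreadState* state_;
};

void heavy_numeric_work() {
  NoGILGuard no_gil;  // other Python threads may run while we compute
  // ... pure C++ work that never touches Python objects ...
}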
Next we step into torch::rand. One thing to notice: what this function finally returns is the tensor produced by autograd::make_variable, which means that if we did not need a differentiable tensor we could return the result of at::rand directly.
This is also why the PyTorch C++ frontend documentation points out that a Tensor built directly with at::rand has no autograd support:
// torch/csrc/autograd/generated/variable_factories.h
inline at::Tensor rand(at::IntArrayRef size, const at::TensorOptions & options = {}) {
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    // this branch is not taken, because we are not using the JIT tracer
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::rand");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "size", size);
    jit::tracer::addInputs(node, "options", options);
    tracer_state->graph->insertNode(node);
    jit::tracer::setTracingState(nullptr);
  }
  at::Tensor tensor = ([&]() {
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    return at::rand(size, at::TensorOptions(options).is_variable(false));
  })();
  at::Tensor result = autograd::make_variable(std::move(tensor), /*requires_grad=*/options.requires_grad());
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
Then we continue into at::rand:
// build/aten/src/ATen/Functions.h
static inline Tensor rand(IntArrayRef size, const TensorOptions & options) {
#ifdef USE_STATIC_DISPATCH
    return TypeDefault::rand(size, options);
#else
    // execution continues from here
    globalLegacyTypeDispatch().initForTensorTypeSet(at::detail::multi_dispatch_tensor_type_set(options));
    static auto table = globalATenDispatch().getOpTable("aten::rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor");
    return table->callUnboxed<Tensor, IntArrayRef, const TensorOptions &>(size, options);
#endif
}
Notice that the argument passed to getOpTable in the code above is a string describing the function we want to call; in other words, getOpTable can locate the corresponding function from its string schema. First, let's see what globalATenDispatch() is:
// aten/src/ATen/core/ATenDispatch.cpp
// looks like a singleton, doesn't it?
ATenDispatch & globalATenDispatch() {
  static ATenDispatch singleton;
  return singleton;
}
It is the classic singleton pattern. So what is ATenDispatch? Look at the registerOp method; it should feel familiar: this is clearly a class implementing an operator registration mechanism.
// aten/src/ATen/core/ATenDispatch.h
class CAFFE2_API ATenDispatch {
 public:
  template<class FuncType>
  ATenDispatch& registerOp(TensorTypeId id, const char* schema, FuncType* fn) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (op_tables_.find(schema) == op_tables_.end()) {
      op_tables_.insert(std::make_pair(schema, ATenOpTable(schema)));
    }
    op_tables_.at(schema).registerOp(id, reinterpret_cast<void*>(fn));
    return *this;
  }

  ATenDispatch& registerFallbackBoxedOp(TensorTypeId id, FallbackBoxedFunction* fn) {
    std::lock_guard<std::mutex> lock(mutex_);
    boxed_fallback_table_[static_cast<size_t>(id)] = fn;
    return *this;
  }

  const ATenOpTable* getOpTable(const char* schema) const {
    auto iter = op_tables_.find(schema);
    TORCH_CHECK(iter != op_tables_.end(),
        "No functions are registered for schema ", schema);
    return &iter->second;
  }

  FallbackBoxedFunction* getFallbackBoxedOp(TensorTypeId tid) const {
    return boxed_fallback_table_[static_cast<size_t>(tid)];
  }

 private:
  std::unordered_map<std::string, ATenOpTable> op_tables_;
  FallbackBoxedFunction* boxed_fallback_table_[static_cast<int64_t>(TensorTypeId::NumTensorIds)] = {nullptr};
  std::mutex mutex_;
};
So what is ATenOpTable, the value type of that std::string-keyed map? The comment below already explains it well: the class stores the implementation for each backend, plus an implementation for variables.
// ATenOpTable stores the implementations for each backend, in addition to
// an implementation for variables.
// aten/src/ATen/core/ATenDispatch.h
class CAFFE2_API ATenOpTable {
 public:
  ATenOpTable(std::string schema) : schema_(std::move(schema)) {}

  // NB: No universal forwarding
  template<class Result, class... Args>
  Result callUnboxed(Args... args) const;

 private:
  void registerOp(TensorTypeId tid, void* fn) {
    TORCH_CHECK(function_table_[static_cast<int64_t>(tid)] == nullptr,
        "Attempting to register function for schema ", schema_,
        " and tensor type ", toString(tid),
        " but there is already a function registered");
    function_table_[static_cast<int64_t>(tid)] = fn;
  }

  C10_NORETURN void reportError(TensorTypeId tid) const;

  friend class ATenDispatch;

  std::string schema_;
  void* function_table_[static_cast<int64_t>(TensorTypeId::NumTensorIds)] = {nullptr};
};
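What makes this work is type erasure: each implementation is stored as a void* in function_table_, and callUnboxed casts it back to the concrete signature supplied through its template arguments. A deliberately simplified sketch of that idea (the names here are made up; this is not the real ATen code):

#include <string>
#include <unordered_map>

// Toy schema -> function table with a single slot instead of one per backend.
struct MiniOpTable {
  void* fn = nullptr;  // stored type-erased, exactly like function_table_

  template <class Result, class... Args>
  Result callUnboxed(Args... args) const {
    using FnType = Result (*)(Args...);            // rebuild the real signature
    return reinterpret_cast<FnType>(fn)(args...);  // cast back and call
  }
};

int myRand(int size) { return size * 42; }  // stand-in for a registered kernel

int main() {
  std::unordered_map<std::string, MiniOpTable> tables;
  tables["aten::rand"].fn = reinterpret_cast<void*>(&myRand);  // registerOp
  // getOpTable + callUnboxed: the template arguments restore the type.
  return tables.at("aten::rand").callUnboxed<int, int>(3);
}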
Good. Back to the last line of the rand function above, return table->callUnboxed<Tensor, IntArrayRef, const TensorOptions &>(size, options);. Here table is an instance of ATenOpTable, and callUnboxed is one of its methods; driven by the template arguments, it resolves and calls the specific registered function:
// build/aten/src/ATen/TypeDefault.cpp
Tensor rand(IntArrayRef size, const TensorOptions & options) {
    const DeviceGuard device_guard(options.device());
    return at::native::rand(size, options);
}
Stepping into at::native::rand:
// aten/src/ATen/native/TensorFactories.cpp
Tensor rand(IntArrayRef size, const TensorOptions& options) {
  return native::rand(size, nullptr, options);
}
And into native::rand:
// aten/src/ATen/native/TensorFactories.cpp
Tensor rand(IntArrayRef size, Generator* generator, const TensorOptions& options) {
  auto result = at::empty(size, options);
  return result.uniform_(0, 1, generator);
}
Then into at::empty:
// build/aten/src/ATen/Functions.h
static inline Tensor empty(IntArrayRef size, const TensorOptions & options, c10::optional<MemoryFormat> memory_format) {
#ifdef USE_STATIC_DISPATCH
    switch (tensorTypeIdToBackend(impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(options)))) {
        case Backend::CPU:
            return CPUType::empty(size, options, memory_format);
            break;
        case Backend::SparseCPU:
            return SparseCPUType::empty(size, options, memory_format);
            break;
        default:
            AT_ERROR("empty not implemented for ", at::toString(at::detail::multi_dispatch_tensor_type_set(options)));
    }
#else
    globalLegacyTypeDispatch().initForTensorTypeSet(at::detail::multi_dispatch_tensor_type_set(options));
    static auto table = globalATenDispatch().getOpTable("aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor");
    return table->callUnboxed<Tensor, IntArrayRef, const TensorOptions &, c10::optional<MemoryFormat>>(size, options, memory_format);
#endif
}
As before, the schema string "aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor" is used to look up the op. Note that the functions backing this op are also auto-generated, one per backend.
// aten/src/ATen/native/native_functions.yaml
- func: empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  dispatch:
    CPU: empty_cpu
    CUDA: empty_cuda
    MkldnnCPU: empty_mkldnn
    SparseCPU: empty_sparse
    SparseCUDA: empty_sparse
So the function finally unboxed from the table is:
// build/aten/src/ATen/CPUType.cpp
Tensor empty(IntArrayRef size, const TensorOptions & options, c10::optional<MemoryFormat> memory_format) {
    const DeviceGuard device_guard(options.device());
    return at::native::empty_cpu(size, options, memory_format);
}
Let's step into at::native::empty_cpu.
// aten/src/ATen/native/TensorFactories.cpp
Tensor empty_cpu(IntArrayRef size, const TensorOptions& options, c10::optional<c10::MemoryFormat> optional_memory_format) {
  AT_ASSERT(options.device().type() == DeviceType::CPU);
  AT_ASSERT(!options.is_variable());  // is_variable should have been 'unpacked'  // TODO: remove this when Variable and Tensor are merged
  check_size_nonnegative(size);

  c10::Allocator* allocator;
  if (options.pinned_memory()) {
    allocator = detail::getCUDAHooks().getPinnedMemoryAllocator();
  } else {
    allocator = at::getCPUAllocator();  // this is the line that runs
  }

  int64_t nelements = prod_intlist(size);
  auto dtype = options.dtype();
  auto storage_impl = c10::make_intrusive<StorageImpl>(
      dtype,
      nelements,
      allocator->allocate(nelements * dtype.itemsize()),
      allocator,
      /*resizeable=*/true);

  auto tensor = detail::make_tensor<TensorImpl>(std::move(storage_impl), at::TensorTypeId::CPUTensorId);
  // Default TensorImpl has size [0]
  if (size.size() != 1 || size[0] != 0) {
    tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
  }
  auto memory_format = optional_memory_format.value_or(MemoryFormat::Contiguous);
  tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
  return tensor;
}
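To make the arithmetic concrete: for our float tensor of size {3, 4}, the storage requested is prod(sizes) * itemsize bytes. A numbers-only illustration (not the ATen code path itself):

#include <cstdint>
#include <cstdio>

int main() {
  int64_t sizes[] = {3, 4};
  int64_t nelements = 1;
  for (int64_t s : sizes) nelements *= s;  // prod_intlist -> 12
  size_t itemsize = sizeof(float);         // dtype.itemsize() for float32 -> 4
  size_t nbytes = nelements * itemsize;    // 48 bytes handed to allocator->allocate()
  std::printf("nelements=%lld nbytes=%zu\n", (long long)nelements, nbytes);
}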
Because we are calling the CPU flavor of empty, memory has to be allocated on the CPU. The first task is to obtain the right allocation "method" and assign it to a c10::Allocator*. The chain starts here:
// aten/src/ATen/Context.cpp
Allocator* getCPUAllocator() {
  return getTHDefaultAllocator();
}
One level deeper:
// aten/src/TH/THAllocator.cpp
at::Allocator* getTHDefaultAllocator() {
  return c10::GetCPUAllocator();
}
Deeper still:
// c10/core/CPUAllocator.cpp
at::Allocator* GetCPUAllocator() {
  return GetAllocator(DeviceType::CPU);
}
Going one more level down, we find that this allocator is an entry of allocator_array; the GetAllocator function below fetches it by index:
// c10/core/Allocator.cpp
at::Allocator* GetAllocator(const at::DeviceType& t) {
  auto* alloc = allocator_array[static_cast<int>(t)];
  AT_ASSERTM(alloc, "Allocator for ", t, " is not set.");
  return alloc;
}
Where does allocator_array come from? It is a global array that stores the various allocators, paired with SetAllocator and GetAllocator to set and fetch the allocator for a given device type:
// c10/core/Allocator.cpp
C10_API at::Allocator* allocator_array[at::COMPILE_TIME_MAX_DEVICE_TYPES];

// SetAllocator associates a device type (used as the array index) with its allocator
void SetAllocator(at::DeviceType t, at::Allocator* alloc) {
  allocator_array[static_cast<int>(t)] = alloc;
}
Allocators are then registered through the REGISTER_ALLOCATOR macro.
// c10/core/Allocator.h
template <DeviceType t>
struct AllocatorRegisterer {
  explicit AllocatorRegisterer(Allocator* alloc) {
    SetAllocator(t, alloc);
  }
};

#define REGISTER_ALLOCATOR(t, f)                  \
  namespace {                                     \
  static AllocatorRegisterer<t> g_allocator_d(f); \
  }
For example, DefaultCPUAllocator is registered under DeviceType::CPU, and DeviceType::CPU is just an enum member that corresponds to an integer.
// c10/core/CPUAllocator.cpp
static DefaultCPUAllocator g_cpu_alloc;
REGISTER_ALLOCATOR(DeviceType::CPU, &g_cpu_alloc);
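Expanding REGISTER_ALLOCATOR(DeviceType::CPU, &g_cpu_alloc) by hand (modulo formatting) shows there is no magic: the macro just plants a static object in an anonymous namespace whose constructor runs SetAllocator during static initialization:

namespace {
static AllocatorRegisterer<DeviceType::CPU> g_allocator_d(&g_cpu_alloc);
}
// ...and that constructor simply executes:
//   SetAllocator(DeviceType::CPU, &g_cpu_alloc);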
DefaultCPUAllocator is the allocator class actually used to carve out CPU memory; it derives from at::Allocator:
// c10/core/CPUAllocator.cpp
struct C10_API DefaultCPUAllocator final : at::Allocator {
  DefaultCPUAllocator() {}
  ~DefaultCPUAllocator() override {}

  at::DataPtr allocate(size_t nbytes) const override {
    void* data = alloc_cpu(nbytes);
    if (FLAGS_caffe2_report_cpu_memory_usage && nbytes > 0) {
      getMemoryAllocationReporter().New(data, nbytes);
      return {data, data, &ReportAndDelete, at::Device(at::DeviceType::CPU)};
    }
    return {data, data, &free_cpu, at::Device(at::DeviceType::CPU)};
  }

  static void ReportAndDelete(void* ptr) {
    if (!ptr) {
      return;
    }
    getMemoryAllocationReporter().Delete(ptr);
    free_cpu(ptr);
  }

  at::DeleterFnPtr raw_deleter() const override {
    if (FLAGS_caffe2_report_cpu_memory_usage) {
      return &ReportAndDelete;
    }
    return &free_cpu;
  }

 protected:
  static MemoryAllocationReporter& getMemoryAllocationReporter() {
    static MemoryAllocationReporter reporter_;
    return reporter_;
  }
};
The functions that actually allocate and release the memory are alloc_cpu and free_cpu; they are called whenever space is allocated or freed:
// c10/core/CPUAllocator.cpp
void* alloc_cpu(size_t nbytes) {
  if (nbytes == 0) {
    return nullptr;
  }
  // We might have clowny upstream code that tries to alloc a negative number
  // of bytes. Let's catch it early.
  CAFFE_ENFORCE(
      ((ptrdiff_t)nbytes) >= 0,
      "alloc_cpu() seems to have been called with negative number: ",
      nbytes);

  void* data;
#ifdef __ANDROID__
  data = memalign(gAlignment, nbytes);
#elif defined(_MSC_VER)
  data = _aligned_malloc(nbytes, gAlignment);
#else
  int err = posix_memalign(&data, gAlignment, nbytes);
  if (err != 0) {
    CAFFE_THROW(
        "DefaultCPUAllocator: can't allocate memory: you tried to allocate ",
        nbytes, " bytes. Error code ", err, " (", strerror(err), ")");
  }
#endif

  CAFFE_ENFORCE(
      data,
      "DefaultCPUAllocator: not enough memory: you tried to allocate ",
      nbytes, " bytes. Buy new RAM!");

  // move data to a thread's NUMA node
  NUMAMove(data, nbytes, GetCurrentNUMANode());
  CHECK(
      !FLAGS_caffe2_cpu_allocator_do_zero_fill ||
      !FLAGS_caffe2_cpu_allocator_do_junk_fill)
      << "Cannot request both zero-fill and junk-fill at the same time";
  if (FLAGS_caffe2_cpu_allocator_do_zero_fill) {
    memset(data, 0, nbytes);
  } else if (FLAGS_caffe2_cpu_allocator_do_junk_fill) {
    memset_junk(data, nbytes);
  }

  return data;
}

void free_cpu(void* data) {
#ifdef _MSC_VER
  _aligned_free(data);
#else
  free(data);
#endif
}
Now back to at::native::empty_cpu. empty_cpu has to construct a tensor, and a tensor first needs its storage, i.e. the memory that actually backs the Tensor. StorageImpl is a subclass of intrusive_ptr_target, and in the code the storage_impl is created through c10::make_intrusive:
Tensor empty_cpu(IntArrayRef size, const TensorOptions& options) {
  ......
  int64_t nelements = prod_intlist(size);
  auto dtype = options.dtype();
  auto storage_impl = c10::make_intrusive<StorageImpl>(
      dtype,
      nelements,
      allocator->allocate(nelements * dtype.itemsize()),
      allocator,
      /*resizeable=*/true);
make_intrusive is a function template. TTarget is the StorageImpl class we pass in, and Args&&... args in the parameter list corresponds to class... Args in the template header, a variadic parameter pack: the arguments given in c10::make_intrusive<StorageImpl>(dtype, nelements, allocator->allocate(nelements * dtype.itemsize()), allocator, /*resizeable=*/true) are carried through Args and become args.
// c10/util/intrusive_ptr.h
template <
    class TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>,
    class... Args>
inline intrusive_ptr<TTarget, NullType> make_intrusive(Args&&... args) {
  return intrusive_ptr<TTarget, NullType>::make(std::forward<Args>(args)...);
}
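The Args&&... args plus std::forward combination is ordinary perfect forwarding: whatever the caller passes is handed on to StorageImpl's constructor unchanged (lvalues stay lvalues, rvalues stay rvalues). A generic sketch of the same pattern, with made-up names:

#include <cstdint>
#include <string>
#include <utility>

struct Payload {
  Payload(std::string name, int64_t numel) : name(std::move(name)), numel(numel) {}
  std::string name;
  int64_t numel;
};

// Hypothetical factory that forwards its arguments to T's constructor,
// mirroring what make_intrusive does for StorageImpl.
template <class T, class... Args>
T* make_raw(Args&&... args) {
  return new T(std::forward<Args>(args)...);
}

int main() {
  Payload* p = make_raw<Payload>("storage", 12);  // "storage" and 12 reach Payload's constructor intact
  delete p;
}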
The make function finally returns a TTarget object wrapped in an intrusive_ptr, where TTarget is StorageImpl:
template <class... Args>
static intrusive_ptr make(Args&&... args) {
  auto result = intrusive_ptr(new TTarget(std::forward<Args>(args)...));
  // We can't use retain_(), because we also have to increase weakcount
  // and because we allow raising these values from 0, which retain_()
  // has an assertion against.
  ++result.target_->refcount_;
  ++result.target_->weakcount_;
  return result;
}
intrusive_ptr is a smart pointer that works together with intrusive_ptr_target: only classes inheriting from intrusive_ptr_target can be managed by intrusive_ptr<T>. Unlike shared_ptr<T>, intrusive_ptr<T> keeps the reference count inside the object itself, so there is no separate control block and no way for two independent counts of the same object to drift apart.
// c10/util/intrusive_ptr.h
template <
    class TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>>
class intrusive_ptr final {
 public:
  intrusive_ptr(const intrusive_ptr& rhs) : target_(rhs.target_) {
    retain_();
  }
  ~intrusive_ptr() noexcept {
    reset_();
  }

 private:
  TTarget* target_;

  void retain_() {
    size_t new_refcount = ++target_->refcount_;
  }
  void reset_() noexcept {
    if (target_ != NullType::singleton() && --target_->refcount_ == 0) {
      auto weak_count = --target_->weakcount_;
      const_cast<c10::guts::remove_const_t<TTarget>*>(target_)->release_resources();
      if (weak_count == 0) {
        delete target_;
      }
    }
These are the two core member variables of intrusive_ptr_target that hold the counts; both are atomic:
class C10_API intrusive_ptr_target {
  mutable std::atomic<size_t> refcount_;
  mutable std::atomic<size_t> weakcount_;
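The trick that keeps intrusive_ptr cheap is precisely that the counter lives inside the object being pointed to, rather than in a separate control block. A minimal sketch of intrusive reference counting under simplifying assumptions (no weak references, no null handling):

#include <atomic>
#include <cstddef>

struct RefTarget {                               // analogue of intrusive_ptr_target
  mutable std::atomic<std::size_t> refcount_{0};
  virtual ~RefTarget() = default;
};

template <class T>
class Ref {                                      // analogue of intrusive_ptr<T>
 public:
  explicit Ref(T* t) : t_(t) { ++t_->refcount_; }
  Ref(const Ref& rhs) : t_(rhs.t_) { ++t_->refcount_; }
  Ref& operator=(const Ref&) = delete;
  ~Ref() {
    if (--t_->refcount_ == 0) delete t_;         // the last owner frees the object
  }
 private:
  T* t_;
};

struct MyStorage : RefTarget {};

int main() {
  Ref<MyStorage> a(new MyStorage());  // refcount 1
  Ref<MyStorage> b(a);                // refcount 2; both released on scope exit
}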
And indeed StorageImpl inherits from intrusive_ptr_target:
// c10/core/StorageImpl.h
struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
 public:
  StorageImpl(
      caffe2::TypeMeta data_type,
      int64_t numel,
      at::DataPtr data_ptr,
      at::Allocator* allocator,
      bool resizable);

 private:
  caffe2::TypeMeta data_type_;  // data type
  DataPtr data_ptr_;            // points to the memory block holding the data
  int64_t numel_;               // total number of elements
  bool resizable_;
  bool received_cuda_;
  Allocator* allocator_;        // the memory allocator
The actual data block is held as a DataPtr, which carries the deleter as well as the device the data lives on.
// c10/core/Allocator.h
class C10_API DataPtr {
 private:
  c10::detail::UniqueVoidPtr ptr_;
  Device device_;

 public:
  DataPtr() : ptr_(), device_(DeviceType::CPU) {}
  DataPtr(void* data, Device device) : ptr_(data), device_(device) {}
  DataPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter, Device device)
      : ptr_(data, ctx, ctx_deleter), device_(device) {}
Inside it sits a UniqueVoidPtr. This pointer is similar to unique_ptr, but differs in a few ways; for instance, it only deals with void pointers.
// c10/util/UniqueVoidPtr.h
class UniqueVoidPtr {
 private:
  // Lifetime tied to ctx_
  void* data_;
  std::unique_ptr<void, DeleterFnPtr> ctx_;

 public:
  UniqueVoidPtr() : data_(nullptr), ctx_(nullptr, &deleteNothing) {}
  explicit UniqueVoidPtr(void* data) : data_(data), ctx_(nullptr, &deleteNothing) {}
  UniqueVoidPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter)
      : data_(data), ctx_(ctx, ctx_deleter ? ctx_deleter : &deleteNothing) {}
  void* operator->() const { return data_; }
  void clear() {
    ctx_ = nullptr;
    data_ = nullptr;
  }
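The ctx_ member above is nothing more than a std::unique_ptr<void, DeleterFnPtr>, i.e. a unique_ptr whose deleter is a plain function pointer. Here is that standard-library mechanism in isolation, with a free() wrapper standing in for free_cpu (illustration only):

#include <cstdlib>
#include <memory>

using DeleterFnPtr = void (*)(void*);

void freeWrapper(void* p) { std::free(p); }  // plays the role of free_cpu

int main() {
  void* raw = std::malloc(64);
  // unique_ptr<void, ...> can own an untyped buffer as long as the deleter
  // knows how to release it; the deleter runs when ctx goes out of scope.
  std::unique_ptr<void, DeleterFnPtr> ctx(raw, &freeWrapper);
}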
…
Back in empty_cpu, once storage_impl is initialized, a TensorImpl is built: make_tensor is given the implementation type of the Tensor together with the constructor arguments:
// aten/src/ATen/native/TensorFactories.cpp
Tensor empty_cpu(IntArrayRef size, const TensorOptions& options) {
  ......
  auto tensor = detail::make_tensor<TensorImpl>(storage_impl, at::CPUTensorId());
make_tensor returns a Tensor, and with that a Tensor has been constructed:
// build/aten/src/ATen/core/TensorBody.h
template <typename T, typename... Args>
Tensor make_tensor(Args&&... args) {
  return Tensor(c10::make_intrusive<T>(std::forward<Args>(args)...));
}
This Tensor is a generic handle. It contains a pointer to a TensorImpl object, and the pointer to the actually allocated memory lives further down, inside the TensorImpl's storage_.
// build/aten/src/ATen/core/TensorBody.h
class CAFFE2_API Tensor {
 protected:
  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;

 public:
  int64_t dim() const {
    return impl_->dim();
  }
  int64_t storage_offset() const {
    return impl_->storage_offset();
  }
  Tensor abs() const;
  Tensor& abs_();
  Tensor add(const Tensor & other, Scalar alpha=1) const;
TensorImpl likewise inherits from intrusive_ptr_target, so it too gets the intrusive smart-pointer behaviour.
// c10/core/TensorImpl.h
struct C10_API TensorImpl : public c10::intrusive_ptr_target {
 public:
  virtual int64_t dim() const;
  virtual int64_t storage_offset() const;

 private:
  Storage storage_;
#ifdef NAMEDTENSOR_ENABLED
  std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta_ = nullptr;
#endif
  c10::VariableVersion version_counter_;
  PyObject* pyobj_ = nullptr;  // weak reference
  SmallVector<int64_t,5> sizes_;
  SmallVector<int64_t,5> strides_;
  int64_t storage_offset_ = 0;
  int64_t numel_ = 1;
  caffe2::TypeMeta data_type_;
  c10::optional<c10::Device> device_opt_;
  TensorTypeId type_id_;
  bool is_contiguous_ = true;
  bool is_wrapped_number_ = false;
  bool allow_tensor_metadata_change_ = true;
  bool reserved_ = false;
  ...
Let's recap the classes actually involved in creating a Tensor:
class CAFFE2_API Tensor {
  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;
  ...
struct C10_API TensorImpl : public c10::intrusive_ptr_target {
  Storage storage_;
  ...
struct C10_API Storage {
 protected:
  c10::intrusive_ptr<StorageImpl> storage_impl_;
  ...
struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
  DataPtr data_ptr_;
  ...
class C10_API DataPtr {
  c10::detail::UniqueVoidPtr ptr_;
  ...
class UniqueVoidPtr {
  std::unique_ptr<void, DeleterFnPtr> ctx_;
  ...
Now back to rand: after at::empty has constructed the empty Tensor, it is initialized with uniform_.
// aten/src/ATen/native/TensorFactories.cpp
Tensor rand(IntArrayRef size, Generator* generator, const TensorOptions& options) {
  auto result = at::empty(size, options);
  return result.uniform_(0, 1, generator);
}
Tensor::uniform_ is a method of the Tensor class that operates on the tensor's data.
// build/aten/src/ATen/core/TensorMethods.h
inline Tensor & Tensor::uniform_(double from, double to, Generator * generator) const {
    static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::uniform_", ""}).value();
    return c10::Dispatcher::singleton().callUnboxedOnly<Tensor &, Tensor &, double, double, Generator *>(
        op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)),
        const_cast<Tensor&>(*this), from, to, generator);
}
What Tensor::uniform_ actually calls, however, is a function looked up through the registration mechanism, and that function is generated at build time according to the entries in native_functions.yaml.
In native_functions.yaml, uniform_ is dispatched to two different backends (CPU and CUDA); here we focus on legacy::cpu::_th_uniform_.
// aten/src/ATen/native/native_functions.yaml
- func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)
  variants: method
  dispatch:
    CPU: legacy::cpu::_th_uniform_
    CUDA: uniform_cuda_
The generated code is shown below; this is the function that callUnboxedOnly actually resolves to and runs, based on the template and function arguments:
// build/aten/src/ATen/CPUType.cpp
Tensor & uniform_(Tensor & self, double from, double to, Generator * generator) {
    const OptionalDeviceGuard device_guard(device_of(self));
    return at::native::legacy::cpu::_th_uniform_(self, from, to, generator);
}
at::native::legacy::cpu::_th_uniform_ is itself auto-generated; its generation rule is declared as follows:
// aten/src/ATen/Declarations.cwrap
name: _th_uniform_
types:
  - floating_point
backends:
  - CPU
cname: uniform
variants: function
return: self
arguments:
  - THTensor* self
  - double from
  - double to
  - THGenerator* generator
Stepping into at::native::legacy::cpu::_th_uniform_, the default dtype means the ScalarType::Float branch is the one taken:
// build/aten/src/ATen/LegacyTHFunctionsCPU.cpp
Tensor & _th_uniform_(Tensor & self, double from, double to, Generator * generator) {
#ifdef BUILD_NAMEDTENSOR
#endif
    // DeviceGuard omitted
    auto dispatch_scalar_type = infer_scalar_type(self);
    switch (dispatch_scalar_type) {
        case ScalarType::Double: {
            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_uniform_", false, DeviceType::CPU, ScalarType::Double);
            THDoubleTensor_uniform(self_, from, to, generator);
            return self;
            break;
        }
        case ScalarType::Float: {
            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_uniform_", false, DeviceType::CPU, ScalarType::Float);
            THFloatTensor_uniform(self_, from, to, generator);
            return self;
            break;
        }
        default:
            AT_ERROR("_th_uniform_ not supported on CPUType for ", dispatch_scalar_type);
    }
}
Note that PyTorch uses C preprocessor macros to implement a form of polymorphism: THFloatTensor_uniform above is produced by macro expansion. In other words, the function below is expanded at compile time into THFloatTensor_uniform; see the references at the end for a more detailed explanation.
The TH_TENSOR_APPLY macro in that function behaves like a map: it applies the given expression to every element of the tensor. We won't go deeper into it here.
void THTensor_(uniform)(THTensor *self, double a, double b, at::Generator *_generator)
{
  auto gen = at::get_generator_or_default<at::CPUGenerator>(_generator, at::detail::getDefaultCPUGenerator());
  // See Note [Acquire lock when using random generators]
  std::lock_guard<std::mutex> lock(gen->mutex_);
  #if defined(TH_REAL_IS_FLOAT)
  at::uniform_real_distribution<float> uniform((float)a, (float)b);
  TH_TENSOR_APPLY(scalar_t, self, *self_data = (scalar_t)uniform(gen););
  #else
  at::uniform_real_distribution<double> uniform(a, b);
  TH_TENSOR_APPLY(scalar_t, self, *self_data = (scalar_t)uniform(gen););
  #endif
}
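The TH "generic" files rely on token pasting: a macro like THTensor_(NAME) glues a type-dependent prefix onto NAME, so the same source file is compiled once per scalar type and yields THFloatTensor_uniform, THDoubleTensor_uniform, and so on. A tiny self-contained illustration of the technique (these are not the actual TH macros):

#include <cstdio>

// Type-dependent prefix; in TH this is driven by the per-type build machinery.
#define Real Float
#define CONCAT3_EXPAND(a, b, c) a##b##c
#define CONCAT3(a, b, c) CONCAT3_EXPAND(a, b, c)
#define THTensor_(NAME) CONCAT3(TH, Real, Tensor_##NAME)

// This definition expands to: void THFloatTensor_uniform(float* data, int n)
void THTensor_(uniform)(float* data, int n) {
  for (int i = 0; i < n; ++i) data[i] = 0.5f;  // stand-in body
}

int main() {
  float buf[4];
  THFloatTensor_uniform(buf, 4);  // callers see the pasted name
  std::printf("%f\n", buf[0]);
}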
On to the next step. Now that the tensor has been initialized, we execute the trailing [0] index in torch.rand(3, 4)[0], which corresponds to:
_t1 = torch.rand(3, 4)
_t2 = _t1.__getitem__(0)   # <--- here
del _t1
_t3 = torch.rand(3, 4)
r = _t2.__add__(_t3)
del _t2
del _t3
The remaining steps follow the same principles as before, so from here on only the code path is shown, without a detailed walkthrough:
# torch/tensor.py
class Tensor(torch._C._TensorBase):
// torch/csrc/autograd/python_variable.cpp
PyTypeObject THPVariableType = {
  PyVarObject_HEAD_INIT(nullptr, 0)
  "torch._C._TensorBase",               /* tp_name */
  sizeof(THPVariable),                  /* tp_basicsize */
  (destructor)THPVariable_dealloc,      /* tp_dealloc */
  &THPVariable_as_mapping,              /* tp_as_mapping */
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */
  (traverseproc)THPVariable_traverse,   /* tp_traverse */
  (inquiry)THPVariable_clear,           /* tp_clear */
  THPVariable_properties,               /* tp_getset */
  THPVariable_pynew                     /* tp_new */
};
PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
  if (index == Py_None) {
    return wrap(self_.unsqueeze(0));
  } else if (index == Py_Ellipsis) {
    return wrap(at::alias(self_));
  } else if (THPUtils_checkLong(index)) {
    return wrap(applySelect(self_, 0, THPUtils_unpackLong(index)));
  } else if (PySlice_Check(index)) {
    return wrap(applySlice(self_, 0, index, true));
  }

  // wrap index in a tuple if it's not already one
  THPObjectPtr holder = wrapTuple(index);
  variable_list variableIndices;
  Variable sliced = applySlicing(self_, holder.get(), variableIndices);
  ...

static Variable applySelect(const Variable& self, int64_t dim, int64_t index, int64_t real_dim=0) {
  int64_t size = self.size(dim);
  return self.select(dim, index);
}
// aten/src/ATen/core/TensorMethods.h
inline Tensor Tensor::select(int64_t dim, int64_t index) const {
    static auto table = globalATenDispatch().getOpTable("aten::select(Tensor(a) self, int dim, int index) -> Tensor(a)");
    return table->getOp<Tensor (const Tensor &, int64_t, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim, index);
}
aten/src/ATen/native/native_functions.yaml
- func: select(Tensor(a) self, int dim, int index) -> Tensor(a)
  variants: function, method
  device_guard: False
  named_guard: False
build/aten/src/ATen/TypeDefault.cpp
auto registerer = torch::RegisterOperators()
  .op(torch::RegisterOperators::options()
    .schema("aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)")
    .impl_unboxedOnlyC10CatchAllKernel<Tensor (const Tensor &, int64_t, int64_t), &TypeDefault::select>()
    .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))
  ...

Tensor TypeDefault::select(const Tensor & self, int64_t dim, int64_t index) {
    return at::native::select(self, dim, index);
}
aten/src/ATen/native/TensorShape.cpp
Tensor select(const Tensor& self, int64_t dim, int64_t index) {
  auto sizes = self.sizes().vec();
  auto strides = self.strides().vec();
  auto storage_offset = self.storage_offset() + index * strides[dim];
  sizes.erase(sizes.begin() + dim);
  strides.erase(strides.begin() + dim);
  auto result = self.as_strided(sizes, strides, storage_offset);
build/aten/src/ATen/core/TensorMethods.h
inline Tensor Tensor::as_strided(IntArrayRef size, IntArrayRef stride, c10::optional<int64_t> storage_offset) const {
    static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::as_strided", ""}).value();
    return c10::Dispatcher::singleton().callUnboxedOnly<Tensor, const Tensor &, IntArrayRef, IntArrayRef, c10::optional<int64_t>>(
        op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)),
        const_cast<Tensor&>(*this), size, stride, storage_offset);
}
aten/src/ATen/native/native_functions.yaml
- func: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
  variants: function, method
  dispatch:
    CPU: as_strided_tensorimpl
    CUDA: as_strided_tensorimpl
aten/src/ATen/native/TensorShape.cpp
Tensor as_strided_tensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional<int64_t> storage_offset_) {
  auto storage_offset = storage_offset_.value_or(self.storage_offset());
  auto result = detail::make_tensor<TensorImpl>(Storage(self.storage()), self.type_set());
  setStrided(result, size, stride, storage_offset);
  return result;
}
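It is worth doing the arithmetic once. For the contiguous 3x4 tensor _t1, sizes are {3, 4} and strides are {4, 1}; select(dim=0, index=0) therefore yields a view with storage_offset 0 + 0 * 4 = 0, sizes {4} and strides {1}, and Storage(self.storage()) means the view wraps the very same StorageImpl, only bumping its reference count. A numbers-only sketch of that computation (assumed values, not ATen code):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> sizes   = {3, 4};   // _t1 = torch.rand(3, 4)
  std::vector<int64_t> strides = {4, 1};   // contiguous row-major strides
  int64_t storage_offset = 0;

  int64_t dim = 0, index = 0;              // _t1.__getitem__(0) -> select(0, 0)
  storage_offset += index * strides[dim];  // offset into the shared storage
  sizes.erase(sizes.begin() + dim);        // {4}
  strides.erase(strides.begin() + dim);    // {1}

  std::printf("offset=%lld size=%lld stride=%lld\n",
              (long long)storage_offset, (long long)sizes[0], (long long)strides[0]);
}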
c10/core/Storage.h
struct C10_API Storage {
 protected:
  c10::intrusive_ptr<StorageImpl> storage_impl_;
The next step releases _t1. Keep in mind that _t2 is a view sharing the same StorageImpl, so deleting _t1 tears down its Python wrapper and TensorImpl, while the storage itself (and the free_cpu call at the bottom of the chain below) is only reached once the storage's reference count actually drops to zero.
_t1 = torch.rand(3, 4)
_t2 = _t1.__getitem__(0)
del _t1                    # <--- here
_t3 = torch.rand(3, 4)
r = _t2.__add__(_t3)
del _t2
del _t3
torch/tensor.py
class Tensor(torch._C._TensorBase):
torch/csrc/autograd/python_variable.cpp
PyTypeObject THPVariableType = {
  PyVarObject_HEAD_INIT(nullptr, 0)
  "torch._C._TensorBase",               /* tp_name */
  sizeof(THPVariable),                  /* tp_basicsize */
  (destructor)THPVariable_dealloc,      /* tp_dealloc */
  &THPVariable_as_mapping,              /* tp_as_mapping */
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */
  (traverseproc)THPVariable_traverse,   /* tp_traverse */
  (inquiry)THPVariable_clear,           /* tp_clear */
  THPVariable_properties,               /* tp_getset */
  THPVariable_pynew                     /* tp_new */
};
static void THPVariable_dealloc(THPVariable* self)
{
  PyObject_GC_UnTrack(self);
  THPVariable_clear(self);
  self->cdata.~Variable();
  Py_TYPE(self)->tp_free((PyObject*)self);
}
torch/csrc/autograd/python_variable.h
struct THPVariable {
  PyObject_HEAD
  torch::autograd::Variable cdata;
  PyObject* backward_hooks = nullptr;
};
torch/csrc/autograd/variable.h
struct TORCH_API Variable : public at::Tensor {
  ...
class CAFFE2_API Tensor {
  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;
  ...
struct C10_API TensorImpl : public c10::intrusive_ptr_target {
  Storage storage_;
  ...
struct C10_API Storage {
 protected:
  c10::intrusive_ptr<StorageImpl> storage_impl_;
  ...
struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
  DataPtr data_ptr_;
  ...
class C10_API DataPtr {
  c10::detail::UniqueVoidPtr ptr_;
  ...
class UniqueVoidPtr {
  std::unique_ptr<void, DeleterFnPtr> ctx_;
  ...
void free_cpu(void* data) {
#ifdef _MSC_VER
  _aligned_free(data);
#else
  free(data);
#endif
}
The final step: the addition.
_t1 = torch.rand(3, 4)
_t2 = _t1.__getitem__(0)
del _t1
_t3 = torch.rand(3, 4)
r = _t2.__add__(_t3)      # <--- here
del _t2
del _t3
tools/autograd/templates/python_variable_methods.cpp
PyMethodDef variable_methods[] = {
  {"__add__", (PyCFunction)THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL},
  {"__radd__", (PyCFunction)THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL},
  {"__iadd__", (PyCFunction)THPVariable_add_, METH_VARARGS | METH_KEYWORDS, NULL},
bool THPVariable_initModule(PyObject *module)
{
  static std::vector<PyMethodDef> methods;
  THPUtils_addPyMethodDefs(methods, torch::autograd::variable_methods);
  PyModule_AddObject(module, "_TensorBase", (PyObject *)&THPVariableType);
aten/src/ATen/native/native_functions.yaml
- func: add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  variants: function, method
  dispatch:
    CPU: add
    CUDA: add
    SparseCPU: add
    SparseCUDA: add
    MkldnnCPU: mkldnn_add
torch/csrc/autograd/generated/python_variable_methods.cpp
static PyObject * THPVariable_add(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "add(Scalar alpha, Tensor other)|deprecated",
    "add(Tensor other, *, Scalar alpha=1)",
  }, /*traceable=*/true);
  auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;
  ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.idx == 0) {
    return wrap(dispatch_add(self, r.scalar(0), r.tensor(1)));
  } else if (r.idx == 1) {
    return wrap(dispatch_add(self, r.tensor(0), r.scalar(1)));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
torch/csrc/autograd/generated/python_variable_methods_dispatch.h
inline Tensor dispatch_add(Tensor & self, const Tensor & other, Scalar alpha) {
  AutoNoGIL no_gil;
  return self.add(other, alpha);
}
build/aten/src/ATen/core/TensorMethods.h
inline Tensor Tensor::add(const Tensor & other, Scalar alpha) const {
    static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::add", "Tensor"}).value();
    return c10::Dispatcher::singleton().callUnboxed<Tensor, const Tensor &, const Tensor &, Scalar>(
        op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)),
        const_cast<Tensor&>(*this), other, alpha);
}
aten/src/ATen/native/BinaryOps.cpp
namespace at {
namespace native {

Tensor add(const Tensor& self, const Tensor& other, Scalar alpha) {
  Tensor result;
  auto iter = TensorIterator::binary_op(result, self, other);
  add_stub(iter->device_type(), *iter, alpha);
  return iter->output();
}
aten/src/ATen/native/TensorIterator.cpp
TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, const Tensor& b,
                                         bool check_mem_overlap) {
  auto iter = TensorIterator();
  iter.set_check_mem_overlap(check_mem_overlap);
  iter.add_output(out);
  iter.add_input(a);
  iter.add_input(b);
  iter.allow_cpu_scalars_ = true;
  iter.build();
  return iter;
}
void TensorIterator::build() {
  // set is_output and is_read_write flags on appropriate tensors
  mark_outputs();
  // Check that the outputs have no internal overlap
  // and do not share memory with inputs.
  check_mem_overlaps();
  // compute the broadcasted shape
  compute_shape();
  // compute each tensor's stride after broadcasting
  compute_strides();
  // re-order dimensions to improve coalescing
  reorder_dimensions();
  // compute the result dtype and device
  compute_types();
  // allocate the output tensor if it's not provided
  allocate_outputs();
  // coalesce adjacent dimensions when possible
  coalesce_dimensions();

  for (auto& op : operands_) {
    TORCH_INTERNAL_ASSERT(op.tensor.defined());
    op.data = op.tensor.data_ptr();
  }
}
void TensorIterator::allocate_outputs() {
  for (int i = 0; i < num_outputs_; i++) {
    auto& op = operands_[i];
    if (!op.tensor.defined()) {
      TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
      int element_size = elementSize(op.dtype);
      op.stride_bytes = compatible_stride(element_size);

      auto tensor_shape = invert_perm(shape_);
      auto tensor_stride = invert_perm(op.stride_bytes);
      for (int dim = 0; dim < ndim(); dim++) {
        tensor_stride[dim] /= element_size;
      }
      op.tensor = at::empty_strided(tensor_shape, tensor_stride, op.options());
    }
  }
}
aten/src/ATen/native/BinaryOps.h
using binary_fn_alpha = void(*)(TensorIterator&, Scalar alpha);
DECLARE_DISPATCH(binary_fn_alpha, add_stub);
aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
REGISTER_DISPATCH(add_stub, &add_kernel);
aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
void add_kernel(TensorIterator& iter, Scalar alpha_scalar) {
  if (iter.dtype() == ScalarType::Bool) {
    using scalar_t = bool;
    auto alpha = alpha_scalar.to<scalar_t>();
    cpu_kernel(iter, [=](scalar_t a, scalar_t b) -> scalar_t { return a + alpha * b; });
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "add_cpu/sub_cpu", [&]() {
      auto alpha = alpha_scalar.to<scalar_t>();
      auto alpha_vec = Vec256<scalar_t>(alpha);
      cpu_kernel_vec(iter,
        [=](scalar_t a, scalar_t b) -> scalar_t { return a + alpha * b; },
        [=](Vec256<scalar_t> a, Vec256<scalar_t> b) {
          return vec256::fmadd(b, alpha_vec, a);
        });
    });
  }
}
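Stripped of the TensorIterator bookkeeping and the Vec256 vectorization, what add_kernel computes per element is simply out = a + alpha * b. A scalar reference loop with the same semantics (illustration only, not the ATen kernel):

#include <cstddef>

// Reference version of the fused "a + alpha * b" that add_kernel vectorizes
// with vec256::fmadd; the real kernel walks the data through TensorIterator.
void add_reference(const float* a, const float* b, float* out,
                   std::size_t n, float alpha) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = a[i] + alpha * b[i];
  }
}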
The operations that follow have already been covered above, so we won't repeat them.
_t1 = torch.rand(3, 4)
_t2 = _t1.__getitem__(0)
del _t1
_t3 = torch.rand(3, 4)
r = _t2.__add__(_t3)
del _t2                    # <--- here
del _t3
And with that, the whole sequence of operations and its path through the source code is complete.
References
https://www.52coding.com.cn/2019/05/05/PyTorch5/
https://github.com/Microsoft/vscode-cpptools/issues/891
https://github.com/pytorch/pytorch/wiki/Life-of-a-Tensor