The Life of a Tensor (PyTorch Edition)

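This article walks through a large amount of PyTorch C++ source code, at version 1.4.0a. It is aimed at readers who already have some familiarity with the PyTorch source, and it also touches on the basics of Python C/C++ extensions. The first line of each code listing gives the file the code comes from. Note that some of these files are generated automatically: they are not in the original repository and only appear after a build.

One more caveat: PyTorch is under active development and its interfaces change fairly often, so by the time you read this the code shown here may differ slightly from master. Most of the logic has changed little, though, and the goal here is only to understand the core mechanics.

Let's begin!

Take a Tensor. Actually, take two: create two tensors with rand and add them together.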

import torch
res = torch.rand(3, 4)[0] + torch.rand(3, 4)

Printing res gives something like:

tensor([[0.3091, 0.5503, 1.0780, 0.9044],
        [0.5770, 0.5245, 0.3225, 1.4672],
        [0.1581, 1.0439, 0.3313, 0.9924]])

The values themselves are unimportant. Let's first break the statement above into smaller steps:

_t1 = torch.rand(3, 4)
_t2 = _t1.__getitem__(0)
del _t1
_t3 = torch.rand(3, 4)
res = _t2.__add__(_t3)
del _t2
del _t3
# in the end only res remains
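
To see this desugaring without PyTorch installed, here is a minimal sketch. Traced is a hypothetical class (not part of PyTorch) that only logs which special methods Python invokes for an expression of the same shape.

```python
# Traced is a hypothetical stand-in for a tensor; it only records calls.
calls = []

class Traced:
    def __init__(self, name):
        self.name = name

    def __getitem__(self, index):
        calls.append(f"{self.name}.__getitem__({index!r})")
        return Traced(f"{self.name}[{index!r}]")

    def __add__(self, other):
        calls.append(f"{self.name}.__add__({other.name})")
        return Traced(f"({self.name} + {other.name})")

# Same shape as torch.rand(3, 4)[0] + torch.rand(3, 4)
res = Traced("t1")[0] + Traced("t3")
print(calls)
```

The log shows the indexing call first and the addition second, which is the order the steps above follow.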

Let's see what happens in the first statement:

_t1 = torch.rand(3, 4)  # <--  
_t2 = _t1.__getitem__(0)
del _t1
_t3 = torch.rand(3, 4)
res = _t2.__add__(_t3)
del _t2
del _t3

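torch.rand actually lives in torch._C._VariableFunctions. It is not a function written in Python; it is just the name of a method in that module, and calling torch.rand invokes the rand method that the torch module exposes. That module is built with Python's C/C++ extension mechanism, and the code behind torch.rand is generated automatically from a YAML file.

The file below is the list of function signatures that drives the code generation. A great deal of PyTorch's source is generated by gen.py, because many functions are so similar that writing them by hand would be highly repetitive; generating them avoids most of that duplicated work.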

// aten/src/ATen/native/native_functions.yaml

- func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, 
    Device? device=None, bool? pin_memory=None) -> Tensor

- func: rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, 
    Device? device=None, bool? pin_memory=None) -> Tensor

- func: rand(int[] size, *, Generator? generator, ScalarType? dtype=None, 
    Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

- func: rand(int[] size, *, Tensor(a!) out) -> Tensor(a!)

- func: rand(int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)

- func: rand_like(Tensor self) -> Tensor

- func: rand_like(Tensor self, *, ScalarType dtype, Layout layout, 
    Device device, bool pin_memory=False) -> Tensor

From the signature file above, the methods for rand and the other functions are generated at the ${py_method_defs} placeholder in the following code.

// tools/autograd/templates/python_torch_functions.cpp

static PyMethodDef torch_functions[] = {
  {"arange", (PyCFunction)THPVariable_arange, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"as_tensor", (PyCFunction)THPVariable_as_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"dsmm", (PyCFunction)THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"from_numpy", (PyCFunction)THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
  {"hsmm", (PyCFunction)THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"_promote_types", (PyCFunction)THPVariable__promote_types, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"nonzero", (PyCFunction)THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"randint", (PyCFunction)THPVariable_randint, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"range", (PyCFunction)THPVariable_range, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"saddmm", (PyCFunction)THPVariable_sspaddmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"sparse_coo_tensor", (PyCFunction)THPVariable_sparse_coo_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"spmm", (PyCFunction)THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"tensor", (PyCFunction)THPVariable_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"get_device", (PyCFunction)THPVariable_get_device, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  ${py_method_defs}
  {NULL}
};

The generator expands the function signatures from native_functions.yaml at the ${py_method_defs} placeholder above, producing new code in a new file, where we can find our "rand":

//torch/csrc/autograd/generated/python_torch_functions.cpp

static PyMethodDef torch_functions[] = {
  {"arange", (PyCFunction)(void(*)(void))THPVariable_arange, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"as_tensor", (PyCFunction)(void(*)(void))THPVariable_as_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"dsmm", (PyCFunction)(void(*)(void))THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"from_numpy", (PyCFunction)THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
  {"hsmm", (PyCFunction)(void(*)(void))THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"nonzero", (PyCFunction)(void(*)(void))THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"randint", (PyCFunction)(void(*)(void))THPVariable_randint, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"range", (PyCFunction)(void(*)(void))THPVariable_range, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"saddmm", (PyCFunction)(void(*)(void))THPVariable_sspaddmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"sparse_coo_tensor", (PyCFunction)(void(*)(void))THPVariable_sparse_coo_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"spmm", (PyCFunction)(void(*)(void))THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"tensor", (PyCFunction)(void(*)(void))THPVariable_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"get_device", (PyCFunction)(void(*)(void))THPVariable_get_device, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  // everything above matches the template; the entries below are generated
  {"numel", (PyCFunction)(void(*)(void))THPVariable_numel, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"__and__", (PyCFunction)(void(*)(void))THPVariable___and__, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  ...
   {"quantized_rnn_tanh_cell", (PyCFunction)(void(*)(void))THPVariable_quantized_rnn_tanh_cell, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"rand", (PyCFunction)(void(*)(void))THPVariable_rand, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"rand_like", (PyCFunction)(void(*)(void))THPVariable_rand_like, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"randint_like", (PyCFunction)(void(*)(void))THPVariable_randint_like, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"randn", (PyCFunction)(void(*)(void))THPVariable_randn, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"randn_like", (PyCFunction)(void(*)(void))THPVariable_randn_like, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"randperm", (PyCFunction)(void(*)(void))THPVariable_randperm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  ...
  {"zeros", (PyCFunction)(void(*)(void))THPVariable_zeros, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"zeros_like", (PyCFunction)(void(*)(void))THPVariable_zeros_like, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {NULL}
};
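
The placeholder substitution can be sketched in a few lines of Python. This is only an illustration: it assumes string.Template as a stand-in for PyTorch's own template code, and the generate helper and the shortened template are made up for the example.

```python
from string import Template

# A cut-down template; the real one is python_torch_functions.cpp.
TEMPLATE = Template(
    "static PyMethodDef torch_functions[] = {\n"
    "${py_method_defs}\n"
    "  {NULL}\n"
    "};\n"
)

ROW = ('  {"%s", (PyCFunction)(void(*)(void))THPVariable_%s, '
       "METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},")

def generate(op_names):
    """Render one PyMethodDef row per distinct op name (hypothetical helper)."""
    unique = sorted(set(op_names))  # overloads share a single Python entry
    rows = "\n".join(ROW % (name, name) for name in unique)
    return TEMPLATE.substitute(py_method_defs=rows)

# Names as they would be read from the `func:` lines of the YAML file.
source = generate(["rand", "rand", "rand_like", "zeros"])
print(source)
```

Several YAML overloads of rand collapse into one table row here, matching the single "rand" entry in the generated table above.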

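As the code above shows, "rand" is bound to the function THPVariable_rand. Before looking into that function, some initialization is needed, because the function has to be exposed to Python: the method table above (tp_methods) is attached to a type object (PyTypeObject):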

// tools/autograd/templates/python_torch_functions.cpp

static PyTypeObject THPVariableFunctions = {
  PyVarObject_HEAD_INIT(NULL, 0)
  "torch._C._VariableFunctions",         /* tp_name */
  Py_TPFLAGS_DEFAULT,                    /* tp_flags */
  NULL,                                  /* tp_doc */
  torch_functions,                       /* tp_methods */
...
};

Then comes the initialization, which readies the type object and adds it to the Python module under the name _VariableFunctions:

void initTorchFunctions(PyObject* module) {
  if (PyType_Ready(&THPVariableFunctions) < 0) {
    throw python_error();
  }
  Py_INCREF(&THPVariableFunctions);
  if (PyModule_AddObject(module, "_VariableFunctions", (PyObject*)&THPVariableFunctions) < 0) {
    throw python_error();
  }
}

So when we call it from Python, the method is looked up in the generated torch._C._VariableFunctions, whose methods are copied into the torch namespace:

for name in dir(_C._VariableFunctions):    
    if name.startswith('__'):    
        continue    
    globals()[name] = getattr(_C._VariableFunctions, name)
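
The same re-export loop can be tried without building PyTorch. In this sketch the _VariableFunctions class and the namespace dict are hypothetical stand-ins for torch._C._VariableFunctions and for the globals() of the torch package.

```python
# Hypothetical stand-in for torch._C._VariableFunctions.
class _VariableFunctions:
    @staticmethod
    def rand(*size):
        return f"rand{size}"

    @staticmethod
    def zeros(*size):
        return f"zeros{size}"

namespace = {}  # plays the role of globals() in the loop above
for name in dir(_VariableFunctions):
    if name.startswith("__"):
        continue  # skip __class__, __doc__ and other dunder attributes
    namespace[name] = getattr(_VariableFunctions, name)

print(sorted(namespace))
print(namespace["rand"](3, 4))
```

After the loop, the static methods are reachable as plain names, which is how torch.rand ends up pointing at the C++ binding.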

Now let's look at the function behind the entry {"rand", (PyCFunction)(void(*)(void))THPVariable_rand, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}.

// torch/csrc/autograd/generated/python_torch_functions.cpp

static PyObject * THPVariable_rand(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "rand(IntArrayRef size, *, DimnameList? names, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
    "rand(IntArrayRef size, *, Generator generator, DimnameList? names, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
    "rand(IntArrayRef size, *, Generator generator, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
    "rand(IntArrayRef size, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
  }, /*traceable=*/true);

  ParsedArgs<9> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.idx == 0) {
    auto size = r.intlist(0);
    auto __names = r.toDimnameListOptional(1);
    c10::optional<DimnameList> names = __names ? c10::make_optional(DimnameList(__names.value())) : c10::nullopt;
    auto dtype = r.scalartype(2);
    auto device = r.device(4);
    const auto options = TensorOptions()
        .dtype(dtype)
        .device(device)
        .layout(r.layout(3).layout)
        .requires_grad(r.toBool(6))
        .pinned_memory(r.toBool(5));
    return wrap(dispatch_rand(size, names, options));
  ...
  } else if (r.idx == 3) { // this is the branch that ends up being taken
    if (r.isNone(1)) {
      auto size = r.intlist(0);
      auto dtype = r.scalartype(2);
      auto device = r.device(4);
      const auto options = TensorOptions()
          .dtype(dtype)
          .device(device)
          .layout(r.layout(3).layout)
          .requires_grad(r.toBool(6))
          .pinned_memory(r.toBool(5));
      return wrap(dispatch_rand(size, options));
    } else {
      check_out_type_matches(r.tensor(1), r.scalartype(2), r.isNone(2),
                             r.layout(3), r.isNone(3),
                             r.device(4), r.isNone(4));
      return wrap(dispatch_rand(r.intlist(0), r.tensor(1)).set_requires_grad(r.toBool(6)));
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
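
Why does torch.rand(3, 4) land on r.idx == 3? A rough sketch of the overload selection follows. It is a simplification that assumes the first matching signature wins and looks only at keyword names; the real PythonArgParser also checks types and positional arguments. OVERLOADS and pick_overload are names invented for this example.

```python
# Keyword-only parameters of the four rand signatures listed above.
# "required" have no default in the signature string; the rest are optional.
OPTIONAL = {"dtype", "layout", "device", "pin_memory", "requires_grad"}
OVERLOADS = [
    {"required": {"names"}, "optional": OPTIONAL},
    {"required": {"generator", "names"}, "optional": OPTIONAL},
    {"required": {"generator"}, "optional": OPTIONAL | {"out"}},
    {"required": set(), "optional": OPTIONAL | {"out"}},
]

def pick_overload(kwargs):
    """Return the index of the first overload that accepts these keywords."""
    given = set(kwargs)
    for idx, sig in enumerate(OVERLOADS):
        allowed = sig["required"] | sig["optional"]
        if sig["required"] <= given <= allowed:
            return idx
    raise TypeError(f"invalid combination of keyword arguments: {sorted(given)}")

# torch.rand(3, 4) passes no keyword arguments at all.
print(pick_overload({}))
```

With no keywords given, the first three signatures are rejected because names or generator is missing, so index 3 is chosen.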

The function above ultimately calls dispatch_rand. Note that dispatch_rand releases the GIL, so other Python threads can keep running while the C++ code executes:

// torch/csrc/autograd/generated/python_torch_functions_dispatch.h
inline Tensor dispatch_rand(IntArrayRef size, const TensorOptions & options) {
  maybe_initialize_cuda(options);
  /* release the GIL */
  AutoNoGIL no_gil;
  return torch::rand(size, options);
}
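
AutoNoGIL is a scope guard: constructing it releases the lock and destroying it takes the lock back. The pattern can be sketched in Python with a context manager. Here fake_gil is an ordinary threading.Lock used purely for illustration; it is not the interpreter's real GIL.

```python
import threading
from contextlib import contextmanager

# A plain lock standing in for the interpreter's GIL (illustration only).
fake_gil = threading.Lock()

@contextmanager
def auto_no_gil(lock):
    """Release the lock on entry and re-acquire it on exit, like a scope guard."""
    lock.release()
    try:
        yield
    finally:
        lock.acquire()

states = []
fake_gil.acquire()                    # "Python code" is running and holds the lock
with auto_no_gil(fake_gil):
    states.append(fake_gil.locked())  # heavy native work would run here
states.append(fake_gil.locked())      # the lock is held again afterwards
fake_gil.release()
print(states)
```

Inside the with block the lock is free for other threads; once the block ends it is held again, even if the body raised.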

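Next we step into torch::rand. One thing to note: torch::rand returns the tensor only after passing it through autograd::make_variable. In other words, if we do not need a differentiable tensor, we can call at::rand directly and return its result.

This is also why the PyTorch C++ frontend documentation notes that a Tensor built directly with at::rand has no autograd support: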

// torch/csrc/autograd/generated/variable_factories.h

inline at::Tensor rand(at::IntArrayRef size, const at::TensorOptions & options = {}) {
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) { // not taken here, because we are not using the JIT
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::rand");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "size", size);
    jit::tracer::addInputs(node, "options", options);
    tracer_state->graph->insertNode(node);
  
    jit::tracer::setTracingState(nullptr);
  }
  at::Tensor tensor = ([&]() {
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    return at::rand(size, at::TensorOptions(options).is_variable(false));
  })();
  at::Tensor result = 
    autograd::make_variable(std::move(tensor), /*requires_grad=*/options.requires_grad());
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}

Then we continue into at::rand:

// build/aten/src/ATen/Functions.h

static inline Tensor rand(IntArrayRef size, const TensorOptions & options) {
#ifdef USE_STATIC_DISPATCH
    return TypeDefault::rand(size, options);
#else // execution continues from here
    globalLegacyTypeDispatch().initForTensorTypeSet(at::detail::multi_dispatch_tensor_type_set(options));
    static auto table = globalATenDispatch().getOpTable("aten::rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor");
    return table->callUnboxed<Tensor, IntArrayRef, const TensorOptions &>(size, options);
#endif
}

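In the code above, the argument to getOpTable is a string describing the function being called, which means getOpTable finds the matching function from its string schema. First, let's see what globalATenDispatch() is: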

// aten/src/ATen/core/ATenDispatch.cpp
// Looks like a singleton, doesn't it?
ATenDispatch & globalATenDispatch() {
  static ATenDispatch singleton;
  return singleton;
}

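It does look like the singleton design pattern. So what is ATenDispatch? Notice the registerOp method, a familiar sight: this is clearly a class that implements an op registration mechanism.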

// aten/src/ATen/core/ATenDispatch.h

class CAFFE2_API ATenDispatch {
 public:
  template<class FuncType>
  ATenDispatch& registerOp(TensorTypeId id, const char* schema, FuncType* fn) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (op_tables_.find(schema) == op_tables_.end()) {
      op_tables_.insert(std::make_pair(schema, ATenOpTable(schema)));
    }
    op_tables_.at(schema).registerOp(id, reinterpret_cast<void*>(fn));
    return *this;
  }

  ATenDispatch& registerFallbackBoxedOp(TensorTypeId id, FallbackBoxedFunction* fn) {
    std::lock_guard<std::mutex> lock(mutex_);
    boxed_fallback_table_[static_cast<size_t>(id)] = fn;
    return *this;
  }

  const ATenOpTable* getOpTable(const char* schema) const {
    auto iter = op_tables_.find(schema);
    TORCH_CHECK(iter != op_tables_.end(),
        "No functions are registered for schema ", schema);
    return &iter->second;
  }

  FallbackBoxedFunction* getFallbackBoxedOp(TensorTypeId tid) const {
    return boxed_fallback_table_[static_cast<size_t>(tid)];
  }

 private:
  std::unordered_map<std::string, ATenOpTable> op_tables_;
  FallbackBoxedFunction* boxed_fallback_table_[static_cast<int64_t>(TensorTypeId::NumTensorIds)] = {nullptr};
  std::mutex mutex_;
};

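And what is the ATenOpTable that the map pairs with each std::string? The comment below explains it well: the class stores the implementation for each backend, plus an implementation for Variables.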

// ATenOpTable stores the implementations for each backend, in addition to
// an implementation for variables.

// aten/src/ATen/core/ATenDispatch.h
class CAFFE2_API ATenOpTable {
 public:
  ATenOpTable(std::string schema)
    : schema_(std::move(schema)) {}

  // NB: No universal forwarding
  template<class Result, class... Args>
  Result callUnboxed(Args... args) const;

 private:

  void registerOp(TensorTypeId tid, void* fn) {
    TORCH_CHECK(function_table_[static_cast<int64_t>(tid)] == nullptr,
        "Attempting to register function for schema ", schema_,
        " and tensor type ", toString(tid),
        " but there is already a function registered");
    function_table_[static_cast<int64_t>(tid)] = fn;
  }

  C10_NORETURN void reportError(TensorTypeId tid) const;

  friend class ATenDispatch;

  std::string schema_;
  void* function_table_[static_cast<int64_t>(TensorTypeId::NumTensorIds)] = {nullptr};
};
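
The two-level lookup (schema string to op table, then backend to function) can be modelled in a few lines. This is a sketch under simplifying assumptions: a dict replaces the fixed-size array indexed by TensorTypeId, there is no locking, and the schema string and kernels are shortened, hypothetical examples.

```python
class OpTable:
    """Per-schema table: one implementation slot per backend."""
    def __init__(self, schema):
        self.schema = schema
        self.function_table = {}

    def register_op(self, backend, fn):
        if backend in self.function_table:
            raise RuntimeError(f"{self.schema}: {backend} is already registered")
        self.function_table[backend] = fn

    def call(self, backend, *args):
        return self.function_table[backend](*args)

class Dispatch:
    """Maps a schema string to its OpTable."""
    def __init__(self):
        self.op_tables = {}

    def register_op(self, backend, schema, fn):
        self.op_tables.setdefault(schema, OpTable(schema)).register_op(backend, fn)
        return self  # allows chained registration

    def get_op_table(self, schema):
        if schema not in self.op_tables:
            raise KeyError(f"No functions are registered for schema {schema}")
        return self.op_tables[schema]

_singleton = Dispatch()

def global_dispatch():
    return _singleton

# Hypothetical kernels registered under a shortened schema string.
SCHEMA = "aten::empty(int[] size)"
registry = global_dispatch()
registry.register_op("CPU", SCHEMA, lambda size: ("cpu", size))
registry.register_op("CUDA", SCHEMA, lambda size: ("cuda", size))

table = global_dispatch().get_op_table(SCHEMA)
print(table.call("CPU", [3, 4]))
```

Registering the same backend twice for one schema raises an error, mirroring the TORCH_CHECK in ATenOpTable::registerOp.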

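Now back to the last statement of the rand function above: return table->callUnboxed<Tensor, IntArrayRef, const TensorOptions &>(size, options);. Here table is an instance of ATenOpTable and callUnboxed is one of its methods; based on the template arguments it is given, it looks up and calls the specific function: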

// build/aten/src/ATen/TypeDefault.cpp

Tensor rand(IntArrayRef size, const TensorOptions & options) {
    const DeviceGuard device_guard(options.device());
    return at::native::rand(size, options);
}

Step into at::native::rand:

// aten/src/ATen/native/TensorFactories.cpp

Tensor rand(IntArrayRef size, const TensorOptions& options) {
  return native::rand(size, nullptr, options);
}

Step into native::rand:

// aten/src/ATen/native/TensorFactories.cpp

Tensor rand(IntArrayRef size, Generator* generator, const TensorOptions& options) {
  auto result = at::empty(size, options);
  return result.uniform_(0, 1, generator);
}
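
The two-step shape of rand (allocate first, then fill in place) can be sketched with the standard library alone. The empty, uniform_ and rand functions below are toy stand-ins that work on a flat array. Python cannot hand out uninitialized memory, so the buffer starts as zeros.

```python
import random
from array import array

def empty(numel):
    """Allocate storage for numel floats; the zeros are placeholders to overwrite."""
    return array("f", [0.0] * numel)

def uniform_(buf, low, high, generator=None):
    """Fill buf in place with samples drawn from [low, high) and return it."""
    gen = generator if generator is not None else random
    for i in range(len(buf)):
        buf[i] = low + (high - low) * gen.random()
    return buf

def rand(numel, generator=None):
    result = empty(numel)
    return uniform_(result, 0.0, 1.0, generator)

values = rand(12, random.Random(0))
print(len(values))
```

Passing a seeded generator makes the result reproducible, which is the role the Generator* argument plays in the C++ signature.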

Step into at::empty:

// build/aten/src/ATen/Functions.h

static inline Tensor empty(IntArrayRef size, const TensorOptions & options, c10::optional<MemoryFormat> memory_format) {
#ifdef USE_STATIC_DISPATCH
    switch(tensorTypeIdToBackend(impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(options)))) {
        case Backend::CPU:
            return CPUType::empty(size, options, memory_format);
            break;
        case Backend::SparseCPU:
            return SparseCPUType::empty(size, options, memory_format);
            break;
        default:
            AT_ERROR("empty not implemented for ", at::toString(at::detail::multi_dispatch_tensor_type_set(options)));
    }
#else
    globalLegacyTypeDispatch().initForTensorTypeSet(at::detail::multi_dispatch_tensor_type_set(options));
    static auto table = globalATenDispatch().getOpTable("aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor");
    return table->callUnboxed<Tensor, IntArrayRef, const TensorOptions &, c10::optional<MemoryFormat>>(size, options, memory_format);
#endif
}

Following the same approach as before, we look up the op "aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor". Note that this op function is also generated, with one version per backend.

// aten/src/ATen/native/native_functions.yaml

- func: empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  dispatch:
    CPU: empty_cpu
    CUDA: empty_cuda
    MkldnnCPU: empty_mkldnn
    SparseCPU: empty_sparse
    SparseCUDA: empty_sparse

So the function that the table finally resolves to through callUnboxed is:

// build/aten/src/ATen/CPUType.cpp

Tensor empty(IntArrayRef size, const TensorOptions & options, c10::optional<MemoryFormat> memory_format) {
    const DeviceGuard device_guard(options.device());
    return at::native::empty_cpu(size, options, memory_format);
}

We step into at::native::empty_cpu:

// aten/src/ATen/native/TensorFactories.cpp

Tensor empty_cpu(IntArrayRef size, const TensorOptions& options, c10::optional<c10::MemoryFormat> optional_memory_format) {
  AT_ASSERT(options.device().type() == DeviceType::CPU);
  AT_ASSERT(!options.is_variable());  // is_variable should have been 'unpacked'  // TODO: remove this when Variable and Tensor are merged
  check_size_nonnegative(size);

  c10::Allocator* allocator;
  if (options.pinned_memory()) {
    allocator = detail::getCUDAHooks().getPinnedMemoryAllocator();
  } else {
    allocator = at::getCPUAllocator(); // this branch is taken
  }

  int64_t nelements = prod_intlist(size);
  auto dtype = options.dtype();
  auto storage_impl = c10::make_intrusive<StorageImpl>(
    dtype,
    nelements,
    allocator->allocate(nelements * dtype.itemsize()),
    allocator,
    /*resizeable=*/true);

  auto tensor = detail::make_tensor<TensorImpl>(std::move(storage_impl), at::TensorTypeId::CPUTensorId);
  // Default TensorImpl has size [0]
  if (size.size() != 1 || size[0] != 0) {
    tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
  }

  auto memory_format = optional_memory_format.value_or(MemoryFormat::Contiguous);
  tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
  return tensor;
}
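
The arithmetic in empty_cpu is easy to check by hand for our (3, 4) float tensor. The sketch below assumes 4 bytes per float32 element and ignores zero-sized dimensions; contiguous_strides is a hypothetical helper showing the layout that a contiguous restride produces.

```python
from math import prod

def contiguous_strides(size):
    """Row-major strides, in elements, for a tensor of the given size."""
    strides = [1] * len(size)
    for i in range(len(size) - 2, -1, -1):
        strides[i] = strides[i + 1] * size[i + 1]
    return strides

size = [3, 4]
itemsize = 4                      # bytes per float32 element
nelements = prod(size)            # plays the role of prod_intlist(size)
nbytes = nelements * itemsize     # the argument passed to allocator->allocate
print(nelements, nbytes, contiguous_strides(size))
```

So torch.rand(3, 4) asks the allocator for 48 bytes, and moving one row forward skips 4 elements.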

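Because we are calling the CPU version of empty, memory has to be allocated on the CPU. The first step is to obtain the right allocator and assign it to the c10::Allocator* variable.

It starts here: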

// aten/src/ATen/Context.cpp

Allocator* getCPUAllocator() {
  return getTHDefaultAllocator();
}

One level deeper:

// aten/src/TH/THAllocator.cpp

at::Allocator* getTHDefaultAllocator() {
  return c10::GetCPUAllocator();
}

Deeper still:

// c10/core/CPUAllocator.cpp

at::Allocator* GetCPUAllocator() {
  return GetAllocator(DeviceType::CPU);
}

Deeper again, we find that the allocator is one entry of allocator_array, which the GetAllocator function below retrieves by index:

// c10/core/Allocator.cpp

at::Allocator* GetAllocator(const at::DeviceType& t) {
  auto* alloc = allocator_array[static_cast<int>(t)];
  AT_ASSERTM(alloc, "Allocator for ", t, " is not set.");
  return alloc;
}

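Where does allocator_array come from? It is a global variable that stores the various allocators, and it comes with SetAllocator and GetAllocator for setting and getting the allocator for a device type: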

// c10/core/Allocator.cpp

C10_API at::Allocator* allocator_array[at::COMPILE_TIME_MAX_DEVICE_TYPES];

// SetAllocator stores alloc in the array slot indexed by the device type
void SetAllocator(at::DeviceType t, at::Allocator* alloc) {
  allocator_array[static_cast<int>(t)] = alloc;
}

Then the REGISTER_ALLOCATOR macro is used to register these allocators.

// c10/core/Allocator.h

template <DeviceType t>
struct AllocatorRegisterer {
  explicit AllocatorRegisterer(Allocator* alloc) {
    SetAllocator(t, alloc);
  }
};

#define REGISTER_ALLOCATOR(t, f)                    \
  namespace {                                       \
  static AllocatorRegisterer<t> g_allocator_d(f); \
  }

For example, DefaultCPUAllocator is registered under DeviceType::CPU, and DeviceType::CPU is simply an enum member that corresponds to an integer.

// c10/core/CPUAllocator.cpp

static DefaultCPUAllocator g_cpu_alloc;

REGISTER_ALLOCATOR(DeviceType::CPU, &g_cpu_alloc);
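
The registry is small enough to model directly. In this sketch the enum values, the bytearray-based allocator and the register_allocator helper are illustrative assumptions; only the overall shape (a global array indexed by device type, filled as a side effect of a definition) follows the code above.

```python
from enum import IntEnum

class DeviceType(IntEnum):
    CPU = 0    # illustrative values
    CUDA = 1

# One slot per device type, indexed by the enum's integer value.
allocator_array = [None] * len(DeviceType)

def set_allocator(device_type, alloc):
    allocator_array[int(device_type)] = alloc

def get_allocator(device_type):
    alloc = allocator_array[int(device_type)]
    assert alloc is not None, f"Allocator for {device_type.name} is not set."
    return alloc

class DefaultCPUAllocator:
    def allocate(self, nbytes):
        return bytearray(nbytes)

def register_allocator(device_type, alloc):
    """Mirrors REGISTER_ALLOCATOR: registration is a side effect of the definition."""
    set_allocator(device_type, alloc)
    return alloc

g_cpu_alloc = register_allocator(DeviceType.CPU, DefaultCPUAllocator())
print(len(get_allocator(DeviceType.CPU).allocate(48)))
```

Because the slot for CUDA was never filled, asking for it trips the assertion, just as AT_ASSERTM does in GetAllocator.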

DefaultCPUAllocator is the allocator class actually invoked to allocate memory on the CPU. It inherits from at::Allocator:

// c10/core/CPUAllocator.cpp

struct C10_API DefaultCPUAllocator final : at::Allocator {
  DefaultCPUAllocator() {}
  ~DefaultCPUAllocator() override {}
  at::DataPtr allocate(size_t nbytes) const override {
    void* data = alloc_cpu(nbytes);
    if (FLAGS_caffe2_report_cpu_memory_usage && nbytes > 0) {
      getMemoryAllocationReporter().New(data, nbytes);
      return {data, data, &ReportAndDelete, at::Device(at::DeviceType::CPU)};
    }
    return {data, data, &free_cpu, at::Device(at::DeviceType::CPU)};
  }

  static void ReportAndDelete(void* ptr) {
    if (!ptr) {
      return;
    }
    getMemoryAllocationReporter().Delete(ptr);
    free_cpu(ptr);
  }

  at::DeleterFnPtr raw_deleter() const override {
    if (FLAGS_caffe2_report_cpu_memory_usage) {
      return &ReportAndDelete;
    }
    return &free_cpu;
  }

 protected:
  static MemoryAllocationReporter& getMemoryAllocationReporter() {
    static MemoryAllocationReporter reporter_;
    return reporter_;
  }

};

The actual work is done by alloc_cpu and free_cpu, which are called when memory is allocated and released:

// c10/core/CPUAllocator.cpp

void* alloc_cpu(size_t nbytes) {
  if (nbytes == 0) {
    return nullptr;
  }
  // We might have clowny upstream code that tries to alloc a negative number
  // of bytes. Let's catch it early.
  CAFFE_ENFORCE(
    ((ptrdiff_t)nbytes) >= 0,
    "alloc_cpu() seems to have been called with negative number: ", nbytes);

  void* data;
#ifdef __ANDROID__
  data = memalign(gAlignment, nbytes);
#elif defined(_MSC_VER)
  data = _aligned_malloc(nbytes, gAlignment);
#else
  int err = posix_memalign(&data, gAlignment, nbytes);
  if (err != 0) {
    CAFFE_THROW(
        "DefaultCPUAllocator: can't allocate memory: you tried to allocate ",
        nbytes,
        " bytes. Error code ",
        err,
        " (",
        strerror(err),
        ")");
  }
#endif

  CAFFE_ENFORCE(
      data,
      "DefaultCPUAllocator: not enough memory: you tried to allocate ",
      nbytes,
      " bytes. Buy new RAM!");

  // move data to a thread's NUMA node
  NUMAMove(data, nbytes, GetCurrentNUMANode());
  CHECK(
      !FLAGS_caffe2_cpu_allocator_do_zero_fill ||
      !FLAGS_caffe2_cpu_allocator_do_junk_fill)
    << "Cannot request both zero-fill and junk-fill at the same time";
  if (FLAGS_caffe2_cpu_allocator_do_zero_fill) {
    memset(data, 0, nbytes);
  } else if (FLAGS_caffe2_cpu_allocator_do_junk_fill) {
    memset_junk(data, nbytes);
  }

  return data;
}


void free_cpu(void* data) {
#ifdef _MSC_VER
  _aligned_free(data);
#else
  free(data);
#endif
}
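
What does an aligned allocation guarantee? The sketch below reproduces the guarantee with ctypes by over-allocating and rounding the address up. This is a different technique from posix_memalign, used here only because it is easy to run. The alignment of 64 is an assumed value, since the listing above only names gAlignment.

```python
import ctypes

ALIGNMENT = 64  # assumed value for illustration; the source only names gAlignment

def alloc_aligned(nbytes, alignment=ALIGNMENT):
    """Return (owner, address) with address % alignment == 0, or (None, None) for 0 bytes."""
    if nbytes == 0:
        return None, None
    if nbytes < 0:
        raise ValueError(f"called with negative number: {nbytes}")
    # Over-allocate, then round the start address up to the next multiple.
    owner = ctypes.create_string_buffer(nbytes + alignment)
    base = ctypes.addressof(owner)
    address = (base + alignment - 1) // alignment * alignment
    return owner, address  # keep owner alive for as long as address is used

owner, address = alloc_aligned(48)
print(address % ALIGNMENT)
```

As in alloc_cpu, a request for zero bytes returns no memory at all rather than a tiny block.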

Now back to at::native::empty_cpu. empty_cpu has to build a tensor, and a tensor first needs its storage, that is, the memory that actually holds the Tensor's data. StorageImpl is a subclass of intrusive_ptr_target, and the code builds storage_impl with c10::make_intrusive:

Tensor empty_cpu(IntArrayRef size, const TensorOptions& options) {
  ......
  int64_t nelements = prod_intlist(size);
  auto dtype = options.dtype();
  auto storage_impl = c10::make_intrusive<StorageImpl>(
    dtype,
    nelements,
    allocator->allocate(nelements * dtype.itemsize()),
    allocator,
    /*resizeable=*/true);

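make_intrusive is a function template. TTarget is the class passed in, here StorageImpl, and Args&&... args in the parameter list corresponds to class... Args in the template, a variadic parameter pack. The arguments of c10::make_intrusive<StorageImpl>(dtype, nelements, allocator->allocate(nelements * dtype.itemsize()), allocator, /*resizeable=*/true); are captured by Args and forwarded as args: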

// c10/util/intrusive_ptr.h

template <
    class TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>,
    class... Args>
inline intrusive_ptr<TTarget, NullType> make_intrusive(Args&&... args) {
  return intrusive_ptr<TTarget, NullType>::make(std::forward<Args>(args)...);
}

The make function finally returns an object of type TTarget wrapped in an intrusive_ptr, where TTarget is StorageImpl:

  template <class... Args>
  static intrusive_ptr make(Args&&... args) {
    auto result = intrusive_ptr(new TTarget(std::forward<Args>(args)...));
    // We can't use retain_(), because we also have to increase weakcount
    // and because we allow raising these values from 0, which retain_()
    // has an assertion against.
    ++result.target_->refcount_;
    ++result.target_->weakcount_;

    return result;
  }

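intrusive_ptr is a smart pointer that works together with intrusive_ptr_target: only classes that inherit from intrusive_ptr_target can be used with intrusive_ptr<T>. Unlike shared_ptr<T>, which keeps its reference counts in a separate control block, intrusive_ptr<T> keeps them inside the target object itself.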

// c10/util/intrusive_ptr.h

template <    
    class TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>>
class intrusive_ptr final {
 public:    
  intrusive_ptr(const intrusive_ptr& rhs) : target_(rhs.target_) {    
    retain_();    
  }    
    
  ~intrusive_ptr() noexcept {    
    reset_();    
  }

 private:    
  TTarget* target_;    
    
  void retain_() {    
    size_t new_refcount = ++target_->refcount_;    
  }    
    
  void reset_() noexcept {    
    if (target_ != NullType::singleton() && --target_->refcount_ == 0) {    
      auto weak_count = --target_->weakcount_;    
      const_cast<c10::guts::remove_const_t<TTarget>*>(target_)->release_resources();    
      if (weak_count == 0) {    
        delete target_;    
      }    
    }

The two core members of intrusive_ptr_target are its reference counts, both of which are atomic:

class C10_API intrusive_ptr_target {
  mutable std::atomic<size_t> refcount_;    
  mutable std::atomic<size_t> weakcount_;
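
The counting rules in make, retain_ and reset_ can be traced with a small model. Target and IntrusivePtr below are hypothetical Python stand-ins: plain integers replace the atomics, and "delete" is only logged, since Python manages memory itself.

```python
class Target:
    """Stand-in for intrusive_ptr_target: the counts live inside the object."""
    def __init__(self):
        self.refcount = 0
        self.weakcount = 0
        self.events = []

    def release_resources(self):
        self.events.append("release_resources")

class IntrusivePtr:
    def __init__(self, target):
        self.target = target

    @classmethod
    def make(cls, target):
        ptr = cls(target)
        target.refcount += 1
        target.weakcount += 1   # make() raises both counts from zero
        return ptr

    def copy(self):
        self.target.refcount += 1          # retain_()
        return IntrusivePtr(self.target)

    def reset(self):
        target, self.target = self.target, None
        if target is None:
            return
        target.refcount -= 1
        if target.refcount == 0:
            target.weakcount -= 1
            target.release_resources()
            if target.weakcount == 0:
                target.events.append("delete")

obj = Target()
a = IntrusivePtr.make(obj)
b = a.copy()
a.reset()
after_first = (obj.refcount, list(obj.events))  # one strong reference is left
b.reset()
print(after_first, obj.events)
```

Nothing is released while one strong reference remains; the second reset drops the count to zero and triggers both release_resources and the delete.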

StorageImpl indeed inherits from intrusive_ptr_target:

// c10/core/StorageImpl.h

struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
 public:
  StorageImpl(caffe2::TypeMeta data_type, int64_t numel, at::DataPtr data_ptr,
      at::Allocator* allocator, bool resizable);

  private:
    caffe2::TypeMeta data_type_;  // element data type
    DataPtr data_ptr_;            // points to the memory block holding the data
    int64_t numel_;               // total number of elements
    bool resizable_;
    bool received_cuda_;
    Allocator* allocator_;        // memory allocator

The actual data block has type DataPtr, which carries the deleter and the device the data lives on.

// c10/core/Allocator.h

class C10_API DataPtr {
 private:
  c10::detail::UniqueVoidPtr ptr_;
  Device device_;

 public:
  DataPtr() : ptr_(), device_(DeviceType::CPU) {}
  DataPtr(void* data, Device device) : ptr_(data), device_(device) {}
  DataPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter, Device device)
      : ptr_(data, ctx, ctx_deleter), device_(device) {}

Now look at UniqueVoidPtr inside it. It resembles unique_ptr but differs in a few ways; for example, it only deals with void pointers.

// c10/util/UniqueVoidPtr.h

class UniqueVoidPtr {
 private:
  // Lifetime tied to ctx_
  void* data_;
  std::unique_ptr<void, DeleterFnPtr> ctx_;

 public:
  UniqueVoidPtr() : data_(nullptr), ctx_(nullptr, &deleteNothing) {}
  explicit UniqueVoidPtr(void* data)
      : data_(data), ctx_(nullptr, &deleteNothing) {}
  UniqueVoidPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter)
      : data_(data), ctx_(ctx, ctx_deleter ? ctx_deleter : &deleteNothing) {}
  void* operator->() const {
    return data_;
  }
  void clear() {
    ctx_ = nullptr;
    data_ = nullptr;
  }
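
The key idea is that the data handle and the context that owns it are stored separately, and the deleter runs on the context exactly once. A toy Python model, with a hypothetical get method and a list append standing in for free_cpu:

```python
class UniqueVoidPtr:
    """Holds a data handle plus a context that is destroyed exactly once."""
    def __init__(self, data=None, ctx=None, ctx_deleter=None):
        self.data = data
        self._ctx = ctx
        # Fall back to a no-op deleter, like deleteNothing in the listing above.
        self._deleter = ctx_deleter if ctx_deleter is not None else (lambda ctx: None)

    def get(self):
        return self.data

    def clear(self):
        ctx, self._ctx = self._ctx, None
        self.data = None
        if ctx is not None:
            self._deleter(ctx)   # a unique_ptr runs its deleter when it is reset

freed = []
buf = bytearray(16)                          # pretend this came from alloc_cpu
ptr = UniqueVoidPtr(buf, buf, freed.append)  # data and ctx are the same block here
ptr.clear()
ptr.clear()                                  # the second clear is a no-op
print(len(freed), ptr.get())
```

Passing the same block as both data and ctx matches what DefaultCPUAllocator::allocate does with {data, data, &free_cpu, ...}.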


Back in empty_cpu: once storage_impl is initialized, the TensorImpl is built by passing the Tensor's type and the relevant arguments to make_tensor:

// aten/src/ATen/native/TensorFactories.cpp

Tensor empty_cpu(IntArrayRef size, const TensorOptions& options) {
  ......
  auto tensor = detail::make_tensor<TensorImpl>(storage_impl, at::CPUTensorId());

make_tensor returns a Tensor object, and with that a Tensor has been constructed.

// build/aten/src/ATen/core/TensorBody.h

template <typename T, typename... Args>
Tensor make_tensor(Args&&... args) {
  return Tensor(c10::make_intrusive<T>(std::forward<Args>(args)...));
}

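This Tensor is a generic handle that holds a pointer to a TensorImpl object; the pointer to the memory that was actually allocated lives further down, in TensorImpl's storage_ member.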

// build/aten/src/ATen/core/TensorBody.h

class CAFFE2_API Tensor {
 protected:
  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;

 public:
  int64_t dim() const {
    return impl_->dim();
  }
  int64_t storage_offset() const {
    return impl_->storage_offset();
  }

  Tensor abs() const;
  Tensor& abs_();
  Tensor add(const Tensor & other, Scalar alpha=1) const;

TensorImpl also inherits from intrusive_ptr_target, so it can be managed by the intrusive smart pointer.

// c10/core/TensorImpl.h

struct C10_API TensorImpl : public c10::intrusive_ptr_target {    
 public:
  virtual int64_t dim() const;
  virtual int64_t storage_offset() const;

 private:
  Storage storage_;
#ifdef NAMEDTENSOR_ENABLED
  std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta_ = nullptr;    
#endif        
  c10::VariableVersion version_counter_;        
  PyObject* pyobj_ = nullptr; // weak reference    
  SmallVector<int64_t,5> sizes_;    
  SmallVector<int64_t,5> strides_;        
  int64_t storage_offset_ = 0;
  int64_t numel_ = 1;
  caffe2::TypeMeta data_type_;
  c10::optional<c10::Device> device_opt_;    
  TensorTypeId type_id_;    
  bool is_contiguous_ = true;    
  bool is_wrapped_number_ = false;    
  bool allow_tensor_metadata_change_ = true;    
  bool reserved_ = false;
...

To recap, these are the classes involved in creating a Tensor:

class CAFFE2_API Tensor {
    c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;
...

struct C10_API TensorImpl : public c10::intrusive_ptr_target {
  Storage storage_;
...

struct C10_API Storage {
 protected:
  c10::intrusive_ptr<StorageImpl> storage_impl_;
...

struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
  DataPtr data_ptr_;
...

class C10_API DataPtr { 
  c10::detail::UniqueVoidPtr ptr_;
...

class UniqueVoidPtr {
 std::unique_ptr<void, DeleterFnPtr> ctx_;
...
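
The nesting is easier to hold in mind as a chain of small records. The dataclasses below are a rough Python model with most fields left out; the field names are simplified and a bytearray stands in for the raw memory.

```python
from dataclasses import dataclass

@dataclass
class DataPtr:
    data: bytearray          # the raw bytes (held via UniqueVoidPtr in the real code)
    device: str = "cpu"

@dataclass
class StorageImpl:
    data_ptr: DataPtr
    numel: int

@dataclass
class Storage:
    storage_impl: StorageImpl

@dataclass
class TensorImpl:
    storage: Storage
    sizes: list
    strides: list
    storage_offset: int = 0

@dataclass
class Tensor:
    impl: TensorImpl

# A 3x4 float32 tensor: 12 elements, 4 bytes each.
storage = Storage(StorageImpl(DataPtr(bytearray(12 * 4)), numel=12))
t = Tensor(TensorImpl(storage, sizes=[3, 4], strides=[4, 1]))

# Walking from the handle down to the bytes crosses every layer in the recap.
raw = t.impl.storage.storage_impl.data_ptr.data
print(len(raw))
```

Shape information sits in TensorImpl while the bytes sit four levels down, which is why the Tensor handle itself stays cheap to copy.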

Now back to rand: after at::empty has produced the empty Tensor, it has to be initialized with uniform_.

// aten/src/ATen/native/TensorFactories.cpp

Tensor rand(IntArrayRef size, Generator* generator, const TensorOptions& options) {    
  auto result = at::empty(size, options);    
  return result.uniform_(0, 1, generator);    
}

Tensor::uniform_ is a method of the Tensor class that operates on the Tensor's data.

// build/aten/src/ATen/core/TensorMethods.h

inline Tensor & Tensor::uniform_(double from, double to, Generator * generator) const {
    static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::uniform_", ""}).value();
    return c10::Dispatcher::singleton().callUnboxedOnly<Tensor &, Tensor &, double, double, Generator *>(
        op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast<Tensor&>(*this), from, to, generator);
}

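What Tensor::uniform_ actually calls, however, is a function looked up through the registration mechanism, and that function is generated at build time according to native_functions.yaml.

As you can see, the uniform_ entry in native_functions.yaml maps to two different implementations, one per platform (CPU and GPU). Here we follow legacy::cpu::_th_uniform_: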

// aten/src/ATen/native/native_functions.yaml

- func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)
  variants: method
  dispatch:
    CPU: legacy::cpu::_th_uniform_
    CUDA: uniform_cuda_

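The generated code is below. This is the function that callUnboxedOnly resolves to and executes, based on its template arguments and function arguments: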

// build/aten/src/ATen/CPUType.cpp

Tensor & uniform_(Tensor & self, double from, double to, Generator * generator) {
    const OptionalDeviceGuard device_guard(device_of(self));
    return at::native::legacy::cpu::_th_uniform_(self, from, to, generator);
}

at::native::legacy::cpu::_th_uniform_ is itself generated code, produced from the following rule:

// aten/src/ATen/Declarations.cwrap

  name: _th_uniform_
  types:
    - floating_point
  backends:
    - CPU
  cname: uniform
  variants: function
  return: self
  arguments:
    - THTensor* self
    - double from
    - double to
    - THGenerator* generator

Step into at::native::legacy::cpu::_th_uniform_. Since the default dtype is float32, the ScalarType::Float branch is the one taken:

// build/aten/src/ATen/LegacyTHFunctionsCPU.cpp

Tensor & _th_uniform_(Tensor & self, double from, double to, Generator * generator) {
#ifdef BUILD_NAMEDTENSOR

#endif
    // DeviceGuard omitted
    auto dispatch_scalar_type = infer_scalar_type(self);
    switch (dispatch_scalar_type) {
        case ScalarType::Double: {
            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_uniform_", false, DeviceType::CPU, ScalarType::Double);
            THDoubleTensor_uniform(self_, from, to, generator);
            return self;
            break;
        }
        case ScalarType::Float: {
            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_uniform_", false, DeviceType::CPU, ScalarType::Float);
            THFloatTensor_uniform(self_, from, to, generator);
            return self;
            break;
        }
        default:
            AT_ERROR("_th_uniform_ not supported on CPUType for ", dispatch_scalar_type);
    }
}

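Note that PyTorch uses C preprocessor macros here to implement a form of polymorphism. THFloatTensor_uniform above is not written by hand: it is produced at compile time by macro expansion of the function below (see the linked explanation for details).

The TH_TENSOR_APPLY macro used in the function below works like a map function: it applies the given operation to every element of the tensor. We will not go into its details here.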

void THTensor_(uniform)(THTensor *self, double a, double b, at::Generator *_generator)
{
  auto gen = at::get_generator_or_default<at::CPUGenerator>(_generator, at::detail::getDefaultCPUGenerator());
  // See Note [Acquire lock when using random generators]
  std::lock_guard<std::mutex> lock(gen->mutex_);

  #if defined(TH_REAL_IS_FLOAT)
  at::uniform_real_distribution<float> uniform((float)a, (float)b);
  TH_TENSOR_APPLY(scalar_t, self, *self_data = (scalar_t)uniform(gen););
  #else
  at::uniform_real_distribution<double> uniform(a, b);
  TH_TENSOR_APPLY(scalar_t, self, *self_data = (scalar_t)uniform(gen););
  #endif
}
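The "one generic body, one function per scalar type" idea can be sketched in Python as follows. This is an analogy only: the factory `make_uniform` is a hypothetical name, a closure stands in for the preprocessor, and the plain loop stands in for TH_TENSOR_APPLY.

```python
import random
from array import array

def make_uniform(type_name, typecode):
    """Build a type-specific `uniform` from one generic body.

    Loosely mirrors how a single THTensor_(uniform) definition is expanded
    once per scalar type; here a closure stands in for the preprocessor.
    """
    def uniform(buf, a, b, rng):
        if buf.typecode != typecode:
            raise TypeError(f"expected array('{typecode}')")
        # Plays the role of TH_TENSOR_APPLY: visit every element in turn.
        for i in range(len(buf)):
            buf[i] = rng.uniform(a, b)
        return buf
    uniform.__name__ = f"TH{type_name}Tensor_uniform"
    return uniform

# One generic definition, two concrete functions.
THFloatTensor_uniform = make_uniform("Float", "f")
THDoubleTensor_uniform = make_uniform("Double", "d")

data = array("f", [0.0] * 12)   # 3 x 4 float32 values, flattened
THFloatTensor_uniform(data, 0.0, 1.0, random.Random(0))
```

Passing the float32 buffer to `THDoubleTensor_uniform` raises a `TypeError`, which is roughly the job the type check in `checked_dense_tensor_unwrap` does in the generated C++.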

Moving on: once the tensor has been initialized, the next step is the trailing index operation [0] in torch.rand(3, 4)[0], which corresponds to:

_t1 = torch.rand(3, 4)
_t2 = _t1.__getitem__(0)    # <--- here
del _t1
_t3 = torch.rand(3, 4)
res = _t2.__add__(_t3)
del _t2
del _t3

The remaining steps follow the same principle as before, so from here on only the code path is shown, without detailed commentary:

# torch/tensor.py

class Tensor(torch._C._TensorBase):

// torch/csrc/autograd/python_variable.cpp

PyTypeObject THPVariableType = {
  PyVarObject_HEAD_INIT(nullptr, 0)
  "torch._C._TensorBase",                /* tp_name */
  sizeof(THPVariable),                   /* tp_basicsize */
  (destructor)THPVariable_dealloc,       /* tp_dealloc */
  &THPVariable_as_mapping,               /* tp_as_mapping */
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */
  (traverseproc)THPVariable_traverse,    /* tp_traverse */
  (inquiry)THPVariable_clear,            /* tp_clear */
  THPVariable_properties,                /* tp_getset */
  THPVariable_pynew                      /* tp_new */
};

PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
  if (index == Py_None) {
    return wrap(self_.unsqueeze(0));
  } else if (index == Py_Ellipsis) {
    return wrap(at::alias(self_));
  } else if (THPUtils_checkLong(index)) {
    return wrap(applySelect(self_, 0, THPUtils_unpackLong(index)));
  } else if (PySlice_Check(index)) {
    return wrap(applySlice(self_, 0, index, true));
  }

  // wrap index in a tuple if it's not already one
  THPObjectPtr holder = wrapTuple(index);

  variable_list variableIndices;
  Variable sliced = applySlicing(self_, holder.get(), variableIndices);

...

static Variable applySelect(const Variable& self, int64_t dim, int64_t index, 
    int64_t real_dim=0) {
  int64_t size = self.size(dim);
  return self.select(dim, index);
}

// aten/src/ATen/core/TensorMethods.h

inline Tensor Tensor::select(int64_t dim, int64_t index) const {
    static auto table = globalATenDispatch().getOpTable("aten::select(Tensor(a) self, int dim, int index) -> Tensor(a)");
    return table->getOp<Tensor (const Tensor &, int64_t, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim, index);
}

aten/src/ATen/native/native_functions.yaml

- func: select(Tensor(a) self, int dim, int index) -> Tensor(a)
  variants: function, method
  device_guard: False
  named_guard: False

build/aten/src/ATen/TypeDefault.cpp

auto registerer = torch::RegisterOperators()
  .op(torch::RegisterOperators::options()
    .schema("aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)")
    .impl_unboxedOnlyC10CatchAllKernel<Tensor (const Tensor &, int64_t, int64_t), &TypeDefault::select>()
    .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))

...

Tensor TypeDefault::select(const Tensor & self, int64_t dim, int64_t index) {
    return at::native::select(self, dim, index);
}

aten/src/ATen/native/TensorShape.cpp

Tensor select(const Tensor& self, int64_t dim, int64_t index) {
  auto sizes = self.sizes().vec();
  auto strides = self.strides().vec();
  auto storage_offset = self.storage_offset() + index * strides[dim];
  sizes.erase(sizes.begin() + dim);
  strides.erase(strides.begin() + dim);
  auto result = self.as_strided(sizes, strides, storage_offset);

build/aten/src/ATen/core/TensorMethods.h

inline Tensor Tensor::as_strided(IntArrayRef size, IntArrayRef stride, c10::optional<int64_t> storage_offset) const {
    static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::as_strided", ""}).value();
    return c10::Dispatcher::singleton().callUnboxedOnly<Tensor, const Tensor &, IntArrayRef, IntArrayRef, c10::optional<int64_t>>(
        op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast<Tensor&>(*this), size, stride, storage_offset);
}

aten/src/ATen/native/native_functions.yaml

- func: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
  variants: function, method
  dispatch:
    CPU: as_strided_tensorimpl
    CUDA: as_strided_tensorimpl

aten/src/ATen/native/TensorShape.cpp

Tensor as_strided_tensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional<int64_t> storage_offset_) {
  auto storage_offset = storage_offset_.value_or(self.storage_offset());
  auto result = detail::make_tensor<TensorImpl>(Storage(self.storage()), self.type_set());
  setStrided(result, size, stride, storage_offset);
  return result;
}

c10/core/Storage.h

struct C10_API Storage {
 protected:
  c10::intrusive_ptr<StorageImpl> storage_impl_;
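The arithmetic in `select` (drop one dimension, shift the storage offset by `index * strides[dim]`, keep the same storage) can be modelled in a few lines of Python. The `View` class below is a hypothetical toy and not how TensorImpl is implemented. The negative-index wrap is an added assumption, since the excerpt above does not show how negative indices are handled.

```python
class View:
    """A toy strided view: shared flat storage plus sizes/strides/offset."""

    def __init__(self, storage, sizes, strides, offset=0):
        self.storage = storage      # shared, never copied
        self.sizes = list(sizes)
        self.strides = list(strides)
        self.offset = offset

    def select(self, dim, index):
        """Drop dimension `dim` at position `index` without copying data."""
        if index < 0:
            index += self.sizes[dim]   # assumption: simple negative-index wrap
        if not 0 <= index < self.sizes[dim]:
            raise IndexError("index out of range")
        sizes = self.sizes[:dim] + self.sizes[dim + 1:]
        strides = self.strides[:dim] + self.strides[dim + 1:]
        offset = self.offset + index * self.strides[dim]
        return View(self.storage, sizes, strides, offset)

    def tolist(self):
        """Materialize the view as nested lists (for inspection only)."""
        def walk(dim, pos):
            if dim == len(self.sizes):
                return self.storage[pos]
            return [walk(dim + 1, pos + i * self.strides[dim])
                    for i in range(self.sizes[dim])]
        return walk(0, self.offset)

storage = list(range(12))       # stands in for a 3 x 4 tensor's data
t1 = View(storage, sizes=[3, 4], strides=[4, 1])
t2 = t1.select(0, 0)            # analogous to _t1[0]
```

Here `t2` has sizes `[4]` and strides `[1]`, and it refers to the very same `storage` object as `t1`. That sharing is why the next step, releasing `_t1`, matters.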

Next step: release _t1.

_t1 = torch.rand(3, 4)
_t2 = _t1.__getitem__(0)
del _t1                    # <--- here
_t3 = torch.rand(3, 4)
res = _t2.__add__(_t3)
del _t2
del _t3

torch/tensor.py

class Tensor(torch._C._TensorBase):

torch/csrc/autograd/python_variable.cpp

PyTypeObject THPVariableType = {
  PyVarObject_HEAD_INIT(nullptr, 0)
  "torch._C._TensorBase",                /* tp_name */
  sizeof(THPVariable),                   /* tp_basicsize */
  (destructor)THPVariable_dealloc,       /* tp_dealloc */
  &THPVariable_as_mapping,               /* tp_as_mapping */
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */
  (traverseproc)THPVariable_traverse,    /* tp_traverse */
  (inquiry)THPVariable_clear,            /* tp_clear */
  THPVariable_properties,                /* tp_getset */
  THPVariable_pynew                      /* tp_new */
};

static void THPVariable_dealloc(THPVariable* self)
{
  PyObject_GC_UnTrack(self);
  THPVariable_clear(self);
  self->cdata.~Variable();
  Py_TYPE(self)->tp_free((PyObject*)self);
}

torch/csrc/autograd/python_variable.h

struct THPVariable {        
    PyObject_HEAD        
    torch::autograd::Variable cdata;        
    PyObject* backward_hooks = nullptr;        
};

torch/csrc/autograd/variable.h

struct TORCH_API Variable : public at::Tensor {

  ...

class CAFFE2_API Tensor {
    c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;

    ...

struct C10_API TensorImpl : public c10::intrusive_ptr_target {
  Storage storage_;

  ...

struct C10_API Storage {
 protected:
  c10::intrusive_ptr<StorageImpl> storage_impl_;

  ...

struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
  DataPtr data_ptr_;
  
  ...

class C10_API DataPtr { 
  c10::detail::UniqueVoidPtr ptr_;

  ...

class UniqueVoidPtr {
 std::unique_ptr<void, DeleterFnPtr> ctx_;

 ...

void free_cpu(void* data) {        
#ifdef _MSC_VER        
  _aligned_free(data);        
#else        
  free(data);        
#endif        
}
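The chain above is a sequence of owning pointers, so the data buffer is freed only when the last owner goes away. Because `_t2` is a view that shares storage with `_t1`, deleting `_t1` alone should not reach `free_cpu`. The following Python sketch models that with an explicit reference count. The `Storage` and `Tensor` classes here are hypothetical stand-ins for the intrusive_ptr machinery, not the real classes.

```python
class Storage:
    """Toy reference-counted buffer; `freed` records when the deleter ran."""

    def __init__(self, nbytes):
        self.data = bytearray(nbytes)
        self.refcount = 0
        self.freed = False

    def retain(self):
        self.refcount += 1
        return self

    def release(self):
        self.refcount -= 1
        if self.refcount == 0:
            # Stands in for the deleter (free_cpu in the walk-through).
            self.data = None
            self.freed = True


class Tensor:
    """Toy tensor handle that owns one reference to its storage."""

    def __init__(self, storage):
        self.storage = storage.retain()

    def view(self):
        # A view is a new handle on the same storage (like _t1[0]).
        return Tensor(self.storage)

    def destroy(self):
        # Stands in for the destructor chain triggered by `del`.
        self.storage.release()
        self.storage = None


storage = Storage(3 * 4 * 4)    # 3 x 4 float32 elements
t1 = Tensor(storage)
t2 = t1.view()
t1.destroy()                    # like `del _t1`: the buffer must survive
still_alive = not storage.freed
t2.destroy()                    # like `del _t2`: the last reference is gone
```

In this model the buffer outlives `t1` and is released only after `t2` is destroyed.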

The last step: the addition.

_t1 = torch.rand(3, 4)
_t2 = _t1.__getitem__(0)
del _t1
_t3 = torch.rand(3, 4)
res = _t2.__add__(_t3)     # <--- here
del _t2
del _t3

tools/autograd/templates/python_variable_methods.cpp

PyMethodDef variable_methods[] = {
  {"__add__", (PyCFunction)THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL},
  {"__radd__", (PyCFunction)THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL},
  {"__iadd__", (PyCFunction)THPVariable_add_, METH_VARARGS | METH_KEYWORDS, NULL},
  ...

bool THPVariable_initModule(PyObject *module)
{
  static std::vector<PyMethodDef> methods;
  THPUtils_addPyMethodDefs(methods, torch::autograd::variable_methods);
  PyModule_AddObject(module, "_TensorBase",   (PyObject *)&THPVariableType);

aten/src/ATen/native/native_functions.yaml

- func: add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  variants: function, method
  dispatch:
    CPU: add
    CUDA: add
    SparseCPU: add
    SparseCUDA: add
    MkldnnCPU: mkldnn_add

torch/csrc/autograd/generated/python_variable_methods.cpp

static PyObject * THPVariable_add(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "add(Scalar alpha, Tensor other)|deprecated",
    "add(Tensor other, *, Scalar alpha=1)",
  }, /*traceable=*/true);
  auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;
  ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.idx == 0) {
    return wrap(dispatch_add(self, r.scalar(0), r.tensor(1)));
  } else if (r.idx == 1) {
    return wrap(dispatch_add(self, r.tensor(0), r.scalar(1)));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

torch/csrc/autograd/generated/python_variable_methods_dispatch.h

inline Tensor dispatch_add(Tensor & self, const Tensor & other, Scalar alpha) {
  AutoNoGIL no_gil;
  return self.add(other, alpha);
}

build/aten/src/ATen/core/TensorMethods.h

inline Tensor Tensor::add(const Tensor & other, Scalar alpha) const {
    static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::add", "Tensor"}).value();
    return c10::Dispatcher::singleton().callUnboxed<Tensor, const Tensor &, const Tensor &, Scalar>(
        op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast<Tensor&>(*this), other, alpha);
}

aten/src/ATen/native/BinaryOps.cpp

namespace at {
namespace native {
Tensor add(const Tensor& self, const Tensor& other, Scalar alpha) { 
 Tensor result; 
 // binary_op returns the iterator by value (see below), so use `.` rather than `->`
 auto iter = TensorIterator::binary_op(result, self, other);
 add_stub(iter.device_type(), iter, alpha);
 return iter.output();
}

aten/src/ATen/native/TensorIterator.cpp

TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a,
    const Tensor& b, bool check_mem_overlap) {
  auto iter = TensorIterator();
  iter.set_check_mem_overlap(check_mem_overlap);
  iter.add_output(out);
  iter.add_input(a);
  iter.add_input(b);
  iter.allow_cpu_scalars_ = true;
  iter.build();
  return iter;
}

void TensorIterator::build() {
  // set is_output and is_read_write flags on appropriate tensors
  mark_outputs();
  // Check that the outputs have no internal overlap
  // and do not share memory with inputs.
  check_mem_overlaps();
  // compute the broadcasted shape
  compute_shape();
  // compute each tensor's stride after broadcasting
  compute_strides();
  // re-order dimensions to improve coalescing
  reorder_dimensions();
  // compute the result dtype and device
  compute_types();
  // allocate the output tensor if it's not provided
  allocate_outputs();
  // coalesce adjacent dimensions when possible
  coalesce_dimensions();

  for (auto& op : operands_) {
    TORCH_INTERNAL_ASSERT(op.tensor.defined());
    op.data = op.tensor.data_ptr();
  }
}

void TensorIterator::allocate_outputs() {
  for (int i = 0; i < num_outputs_; i++) {
    auto& op = operands_[i];
    if (!op.tensor.defined()) {
      TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
      int element_size = elementSize(op.dtype);
      op.stride_bytes = compatible_stride(element_size);

      auto tensor_shape = invert_perm(shape_);
      auto tensor_stride = invert_perm(op.stride_bytes);
      for (int dim = 0; dim < ndim(); dim++) {
        tensor_stride[dim] /= element_size;
      }
      op.tensor = at::empty_strided(tensor_shape, tensor_stride, op.options());
    }
  }
}
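The "compute the broadcasted shape" step is what lets a shape-(4,) tensor (the result of `_t1[0]`) be added to a shape-(3, 4) tensor. A minimal sketch of the usual right-aligned broadcasting rule is given below. The function name `broadcast_shape` is hypothetical, and the sketch follows the general rule rather than the actual body of `compute_shape()`.

```python
def broadcast_shape(a, b):
    """Return the broadcast of two shapes, aligning dimensions from the right."""
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        # Missing leading dimensions behave like size 1.
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"shapes {a} and {b} cannot be broadcast")
        # A size-1 dimension stretches to match the other operand.
        result.append(db if da == 1 else da)
    return tuple(reversed(result))

shape = broadcast_shape((4,), (3, 4))
```

For the example in this article the result is `(3, 4)`, which matches the 3 x 4 output printed at the beginning.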

aten/src/ATen/native/BinaryOps.h

using binary_fn_alpha = void(*)(TensorIterator&, Scalar alpha);
DECLARE_DISPATCH(binary_fn_alpha, add_stub);

aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

REGISTER_DISPATCH(add_stub, &add_kernel);

aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

void add_kernel(TensorIterator& iter, Scalar alpha_scalar) {
  if (iter.dtype() == ScalarType::Bool) {
      using scalar_t = bool;
      auto alpha = alpha_scalar.to<scalar_t>();
      cpu_kernel(iter,
        [=](scalar_t a, scalar_t b) -> scalar_t { return a + alpha * b; });
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "add_cpu/sub_cpu", [&]() {
      auto alpha = alpha_scalar.to<scalar_t>();
      auto alpha_vec = Vec256<scalar_t>(alpha);
      cpu_kernel_vec(iter,
        [=](scalar_t a, scalar_t b) -> scalar_t { return a + alpha * b; },
        [=](Vec256<scalar_t> a, Vec256<scalar_t> b) {
          return vec256::fmadd(b, alpha_vec, a);
        });
      });
  }
} 
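`add_kernel` hands `cpu_kernel_vec` two lambdas, one working on single scalars and one working on `Vec256` vectors. How the work is split between them is not shown in the excerpt. The sketch below assumes the common pattern of processing full-width chunks with the vector version and the leftover tail with the scalar version. The helper `kernel_vec` and the chunk width are hypothetical.

```python
def kernel_vec(out, a, b, scalar_op, vec_op, width=8):
    """Apply vec_op to full chunks of `width` elements, scalar_op to the tail."""
    n = len(out)
    main = n - n % width
    for start in range(0, main, width):
        stop = start + width
        out[start:stop] = vec_op(a[start:stop], b[start:stop])
    for i in range(main, n):
        out[i] = scalar_op(a[i], b[i])
    return out

def add(a, b, alpha=1):
    """out = a + alpha * b, element by element."""
    out = [0] * len(a)
    return kernel_vec(
        out, a, b,
        scalar_op=lambda x, y: x + alpha * y,
        # A list comprehension stands in for one SIMD fused multiply-add.
        vec_op=lambda xs, ys: [x + alpha * y for x, y in zip(xs, ys)],
    )

result = add(list(range(10)), [1] * 10, alpha=2)
```

Both paths compute the same `a + alpha * b`, so the result does not depend on where the chunk boundary falls.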

The remaining operations have already been covered above, so they are not repeated here.

_t1 = torch.rand(3, 4)
_t2 = _t1.__getitem__(0)
del _t1
_t3 = torch.rand(3, 4)
res = _t2.__add__(_t3)
del _t2                # <--- here
del _t3

That completes the walk through all the operations and the source code behind them.

References

https://www.52coding.com.cn/2019/05/05/PyTorch5/
https://github.com/Microsoft/vscode-cpptools/issues/891
https://github.com/pytorch/pytorch/wiki/Life-of-a-Tensor

This article is licensed under Attribution-NonCommercial-NoDerivatives 4.0 International.
When reposting, please credit the source: https://oldpan.me/archives/life-of-a-tensor
