点击小眼睛开启蜘蛛网特效

TORCH.FX第二篇——PTQ量化实操

《TORCH.FX第二篇——PTQ量化实操》

好久不见各位,哈哈,又鸽了好久。

本文紧接上一篇《实践torch.fx第一篇——基于Pytorch的模型优化量化神器》继续说,主要讲如何利用FX进行模型量化

为什么这篇文章拖了这么久,有部分原因是因为Pytorch的FX变动有点频繁,我在使用过程中也尝试补充些代码和官方对齐,而且官方的更新比较频繁,很多琐碎的API偶尔会变化。因为怕文章的实时性不够,所以拖了一段时间,所幸比较好的观察了一段时间,发现FX主要API不怎么变,整体流程不会变化,还好还好。

《TORCH.FX第二篇——PTQ量化实操》

目前基于6月24日的FX版本进行讲解,借助FX跑一遍量化的过程,FX推出一大亮点就是支持量化,比起之前Pytorch的Eager Mode Quantization好用了不少,虽然还有很多需要补充的功能,但是已经可以实现一些常见模型的量化任务了。

下一篇文章打算讲的fx2trt,可以将FX量化的模型部署到TensorRT。这个工具也在最近从Pytorch主仓库移动到了这里,合并到了Pytorch/TensorRT当中,后续我也会按照新的仓库来讲解,不过总体上变化不大。

回顾一下

因为距上一篇有一段时间了,首先简单回顾下FX的功能:

  • A practical analysis of the features of program capture and transformation that are important for deep learning programs.
  • A Python-only program capture library that implements these features and can be customized to capture different levels of program detail
  • A simple 6 instruction IR for representing captured programs that focuses on ease of understanding and ease of doing static analysis
  • A code generation system for returning transformed code back to the host language’s ecosystem
  • Case studies in how torch.fx has been used in practice to develop features for performance optimization, program analysis, device lowering, and more

上述就是FX的功能组件介绍,简单来说就是可以trace你的nn.module,然后可以做一些变换,然后还可以生成新的经过变换后的nn.module。上一篇中已经介绍了一些fx的使用场景:

  • 自动化修改网络
  • profile网络
  • debug网络
  • 定制hook等

而这篇文章就是利用FX的transform和analysis以及codegen功能去生成已经量化完的模型。

可以做量化的框架

除了FX,目前可以做量化的框架有不少,我们经常使用的训练框架Pytorch和TensorFlow目前都可以原生量化。而很多推理框架也可以进行量化,比如ONNXruntime和TVM。国内也有很多好用的量化工具,其中个人觉着比较好用的是PPQ,支持多种后端,主要是人家教程出的也不少,方便我们快速上手使用,这点好评。

这里也列一下其他可以做量化的框架(或者说有自己的量化工具):

本文主要介绍Pytorch的FX量化工具,作为Pytorch原生支持的量化工具,在某些方面肯定是有些优势的。不过需要注意的是,FX目前的开发仍然在积极推进中,最起码每天都有一些pull request吧,我每隔一段时间就会重新同步下官方的代码,都快跟不上了。

Pytorch量化方式

Pytorch目前支持两种量化方法:Eager Mode以及FX,FX没出来之前大家都是用Eager Mode进行量化,后续FX出世后,Pytorch官方建议优先使用FX:

New users of quantization are encouraged to try out FX Graph Mode Quantization first, if it does not work, user may try to follow the guideline of using FX Graph Mode Quantization or fall back to eager mode quantization.

列下两者的区别:

《TORCH.FX第二篇——PTQ量化实操》

Eager Mode的缺点很明显:

  • 需要手动设置哪些节点需要量化哪些节点不需要量化,哪些节点需要融合(比如CONV+BN+RELU)哪些不需要
  • 某些比较特殊的op,例如add和concat需要特殊对待
  • 对于没有通过Class包装的op,比如functional.conv2d或者functional.linear,无能为力

其实最重要的就是缺乏自动化,啥都要自己写:

import torch

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # 自己指定开始量化的层
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # 指定结束量化的层
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()
# model must be set to eval mode for static quantization logic to work
model_fp32.eval()
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# 指定融合的层
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)
model_int8 = torch.quantization.convert(model_fp32_prepared)
res = model_int8(input_fp32)

小模型好说,大模型的话,除非模型简单都可以直接量化,否则需要在torch.nn.Module中添加很多torch.quantization.QuantStub()的标记精细化整个模型的量化策略,这个其实和之前在量化番外篇——TensorRT-8的量化细节介绍的QDQ挺像,这篇中的TensorRT处理的QDQ模型就是通过FX导出来的,只不过QDQ是FX自动生成插入的,不像Eager Mode需要自个儿写…可以省去很多工作量。

官方总结的FX量化的优点,可以把FX理解为一个编译器:

  • Simple quantization flow, minimal manual steps
  • Unlocks the possibility of doing higher level optimizations like automatic precision selection

不管是eager还是fx,Pytorch都支持三种量化类型:

  • dynamic quantization(weights quantized with activations read/stored in floating point and quantized for compute.)
  • static quantization (weights quantized, activations quantized, calibration required post training)
  • static quantization aware training (weights quantized, activations quantized, quantization numerics modeled during training)

上述详细介绍可以看官方文档,这里就不赘述了。其实static quantizationstatic quantization aware training基本上就是我们常说的PTQ(训练后量化)和QAT(训练中量化):

  • Post Training Quantization (apply quantization after training, quantization parameters are calculated based on sample calibration data)
  • Quantization Aware Training (simulate quantization during training so that the quantization parameters can be learned together with the model using training data)

FX支持这两种常见量化类型。

TORCH-FX量化

本篇主要介绍FX中的PTQ方法,也就是我们一般常用的后训练量化方法,PTQ方法的优点就是不需要数据进行训练,量化框架只要把所有网络节点搭好,不需要反向传播,正向推理收集量化信息即可。QAT(训练中量化)则麻烦点,后续文章中会介绍。

使用FX做PTQ量化的基本代码结构如下,整体比较简单:

import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx
float_model.eval()  # 因为是PTQ,所以就推理模式就够了
qconfig = get_default_qconfig("fbgemm")  # 指定量化细节配置
qconfig_dict = {"": qconfig}             # 指定量化选项
def calibrate(model, data_loader):       # 校准功能函数
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
prepared_model = prepare_fx(float_model, qconfig_dict)  # 准备量化模型,比如融合CONV+BN+RELU,然后插入量化观察节点
calibrate(prepared_model, data_loader_test)  # 校准数据集进行标准
quantized_model = convert_fx(prepared_model)  # 把校准后的模型转化为量化版本模型

代码很简单,配置好config之后,调用prepare_fx函数准备模型到量化状态(插入了量化观察节点),然后输入数据集进行校准,之后将校准后的带有scalezero-point的模型变换为真正的量化模型。

上述代码prepare_fx(float_model, qconfig_dict)没有指定is_reference参数,那么convert后的pytorch模型就是实打实的量化模型,所有的算子的精度都是INT8然后运行在CPU上,Pytorch支持以下的INT8后端:

  • x86 CPUs with AVX2 support or higher (without AVX2 some operations have inefficient implementations), via fbgemm
  • ARM CPUs (typically found in mobile/embedded devices), via qnnpack
  • (early prototype) support for NVidia GPU via TensorRT through fx2trt (to be open sourced)

如果加上is_reference参数,量化后的模型则会仅仅保存量化信息,但实际跑的还是FP32精度的op(通过quantize->dequantize->fp32-op->quantize->dequantize)模型,一般称之为simulator quantize,也就是说模型可以通过quantize->dequantize这种fake quantize来模拟量化的过程和量化误差,计算的时候使用的FP32算子,但是计算的输入的input和weight都是经过量化反量化操作得来的。

《TORCH.FX第二篇——PTQ量化实操》

如上图所示,加上is_reference参数后,convert后的模型就是带有fake量化节点的模型,根据相应的convertor,可以将fake量化节点QDQ按照TensorRT中的IQuantizeLayerIDequantizeLayer搭建,即通过fx2trt转化为TensorRT-engine,这个之后会说。关于TensorRT的量化细节也可以参考这篇文章量化番外篇——TensorRT-8的量化细节

因为下一章要转TensorRT,所以这一步选择与TensorRT相同的量化策略:

《TORCH.FX第二篇——PTQ量化实操》

设置整体的量化规则:

  • 整体模型量化方式:activation为per-tensor,weight为per-channel
  • int8对称量化 -128-127
  • 量化的模型是Centernet-resnet50,包含卷积、反卷积、add、concat,bn

设置好FX的量化config:

qconfig = ao.quantization.qconfig.QConfig(
    activation=ao.quantization.observer.HistogramObserver.with_args(
        qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
    ),
    weight=ao.quantization.observer.default_per_channel_weight_observer
)

然后单独对模型中的某一类型算子操作torch.nn.ConvTranspose2d进行设置,这个qconfig会优先匹配,优先级比整体qconfig高,具体细节可以参考_propagate_qconfig_helper这个函数。

为啥要单独配置torch.nn.ConvTranspose2d,因为torch.fx中默认对torch.nn.ConvTranspose2dper-tensor的量化,精度会受影响,我这里修改为per-channel量化,同时指定量化维度ch_axis=1

完整的config如下:

prepared = prepare_fx(fx_model, {"": qconfig,
                                "object_type":[  # 这里设置反卷积的量化规则,注意看维度的per-channel量化ch_axis=1
                                (torch.nn.ConvTranspose2d,
                                    ao.quantization.qconfig.QConfig(
                                            activation=ao.quantization.observer.HistogramObserver.with_args(
                                                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8, 
                                            ),
                                            weight=ao.quantization.observer.PerChannelMinMaxObserver.with_args(
                                                ch_axis=1, dtype=torch.qint8, qscheme=torch.per_channel_symmetric)) )
                                ]
                                },
                                example_inputs=(torch.randn(1, 3, 512, 512),),
                                backend_config_dict=get_tensorrt_backend_config_dict()
                                )

设置好之后就可以开始量化了。

整体量化流程

整体一共这几个步骤:

  • fuse模型,也就是通常的优化,比如conv+bn啥的,利用fx对模型进行transform
  • 插入量化观察算子,即observer
  • 输入数据进行校准,收集weights和activation的max和min信息
  • 把经过数据推理得到的量化数据整理合并到每一层中

首先看一下最开始的模型,Centernet-res50典型的backbone+neck+head,其中neck是upsample,主要由反卷积组成,head就是普通的head(最常见的结构,卷积加点激活层,然后最后conv输出需要的特征维度),就不画图了,看结构比较直观:

CenterNet(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
     ...
  (upsampler): UpsampleLayer(
    (deconv_layers): Sequential(
      (0): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
      (4): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace=True)
      (7): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU(inplace=True)
    )
  )
  (head): Head(
    (hm): Sequential(
      (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (reg): Sequential(
      (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (wh): Sequential(
      (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
  )
))

fuse 图优化

这一步是一般图优化,和量化无直接关系,不量化的模型也可以这样搞,这样搞之后对量化也有好处。

从上节的模型结构可以看到一些通用、可以应用的图优化策略:

  • conv+bn+relu
  • convtranspose+bn
  • bn+relu

当然还有更激进的优化策略,不过因为FX可能并不代表最终量化模型的运行框架(因为有可能我们经过FX量化后的模型会迁移到其他可以框架中,比如TensorRT),所以其他一些其他平台相关的优化策略就无法实施了。

FX目前的融合策略有,基本的CONV+BN+RELU、CONV+BN、CONV+RELU等等。也包含了常见的融合方法,比如吸bn等操作:

# pytorch/torch/ao/quantization/fuser_method_mappings.py
DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv1d, nn.ReLU): sequential_wrapper2(nni.ConvReLU1d),
    (nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
    (nn.Conv3d, nn.ReLU): sequential_wrapper2(nni.ConvReLU3d),
    (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
    (nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
    (nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
    (nn.BatchNorm3d, nn.ReLU): sequential_wrapper2(nni.BNReLU3d),
    (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
    (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
    (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
}

FX的融合匹配策略代码如下,将graph的有向无环图reverse后,从output开始,倒着开始匹配,匹配也很简单粗暴,for循环遍历就行:

  for node in reversed(graph.nodes):
      if node.name not in match_map:
          for pattern, value in patterns.items():
              matched_node_pattern: List[Node] = []
              if is_match(modules, node, pattern):
                  apply_match(pattern, node, (node, pattern, value(node)), matched_node_pattern, node_to_subpattern)
                  break

融合之后将新的node拷贝到新的graph,即fused_graph,构建新的融合后的graphmodule

# pytorch/torch/ao/quantization/fx/fuse.py
# 寻找匹配的 pairs
fusion_pairs = _find_matches(
    input_root, input_graph, fusion_pattern_to_fuse_handler_cls)
fused_graph = Graph()  # 这里新建一个graph
env: Dict[Any, Any] = {}  # env记录已经融合后复制到新graph的node

for node in input_graph.nodes:
    maybe_last_node, pattern, matched_node_pattern, obj, node_to_subpattern = \
        fusion_pairs.get(node.name, (None, None, None, None, None))
    # get the corresponding subpattern for the current node
    if node_to_subpattern is not None:
        node_subpattern = node_to_subpattern.get(node, None)
    else:
        node_subpattern = None
    if maybe_last_node is node:
        assert obj is not None
        root_node_getter = fusion_pattern_to_root_node_getter.get(pattern, default_root_node_getter)
        root_node = root_node_getter(matched_node_pattern)  # type: ignore[index]
        extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get(pattern, None)
        extra_inputs = []
        if extra_inputs_getter is not None:
            extra_inputs = extra_inputs_getter(matched_node_pattern)
        # TODO: add validation that root_node is a module and has the same type
        # as the root_module in the configuration
        env[node.name] = obj.fuse(
            load_arg, named_modules, fused_graph, root_node, extra_inputs, matched_node_pattern,  # type: ignore[arg-type]
            fuse_custom_config, fuser_method_mapping, is_qat)
    elif maybe_last_node is None or node_subpattern is MatchAllNode:
      # 这里进行融合后的node构建
        env[node.name] = fused_graph.node_copy(node, load_arg)
    # node matched in patterns and is not root is removed here

看一下图优化后的模型:

GraphModule(
  (backbone): Module(
    (conv1): ConvReLU2d(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (1): ReLU(inplace=True)
    )
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Module(
      (0): Module(
        (conv1): ConvReLU2d(
          (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv2): ConvReLU2d(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (downsample): Module(
          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (relu): ReLU(inplace=True)
      )
      ...
      (2): Module(
        (conv1): ConvReLU2d(
          (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv2): ConvReLU2d(
          (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))
        (relu): ReLU(inplace=True)
      )
    )
  )
  (upsampler): Module(
    (deconv_layers): Module(
      (0): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (9): ReLU(inplace=True)
    )
  )
  (head): Module(
    (hm): Module(
      (0): ConvReLU2d(
        (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (wh): Module(
      (0): ConvReLU2d(
        (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (reg): Module(
      (0): ConvReLU2d(
        (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)

可以看到ConvReLU2dconv+relu或者conv+bn+relu的产物,而ConvTranspose2d后面的BN也被吸进ConvTranspose2d里头了。

插入量化观察算子

模型融合后,就可以进行模型量化了,首先我们需要在模型中插入量化观察算子,具体对应代码中的insert_observers_for_model操作。

不过在开始执行量化的时候,FX会检测之前传入的qconfig是否合法,也就是我们之前传递的反卷积的qconfig是否正确(activation是per-tensor量化,weight是per-channel量化)。因为我们的模型有反卷积操作,因此这里修改了官方的代码,注释掉了torch.ao.quantization.PerChannelMinMaxObserver,就可以使用了(看到pr有更好的解法 https://github.com/pytorch/pytorch/pull/79233):

def assert_valid_qconfig(qconfig: Optional[QConfig],
                         mod: torch.nn.Module) -> None:
    """
    Verifies that this `qconfig` is valid.
    """
    if qconfig is None:
        return
    is_conv_transpose_mod = (
        isinstance(mod, torch.nn.ConvTranspose1d) or
        isinstance(mod, torch.nn.ConvTranspose2d) or
        isinstance(mod, torch.nn.ConvTranspose3d))
    if is_conv_transpose_mod:
        if qconfig.weight is None:
            # for now, we assume that any qconfig for ConvTranspose without a weight is valid
            return
        example_observer = qconfig.weight()
        is_per_channel = (
            # isinstance(example_observer, torch.ao.quantization.PerChannelMinMaxObserver) or  把这句去掉
            isinstance(example_observer, torch.ao.quantization.MovingAveragePerChannelMinMaxObserver)
        )
        assert not is_per_channel, \
            'Per channel weight observer is not supported yet for ConvTranspose{n}d.'  # 实测可以支持

把这个解决后,我们重点看insert_observers_for_model这个函数,负责插入量化观察节点。因为权重不需要推理数据观察,所以只需要插入激活值的observer节点即可。

此时模型的具体op实现还是原先FP32的实现,但是在合适的位置已经插入了观察节点,我们可以运行推理来进行PTQ收集activations和weights的量化信息。

# 插入观察节点后的模型forward示例
def forward(self, input):
    input_1 = input
    activation_post_process_0 = self.activation_post_process_0(input_1);  input_1 = None
    backbone_conv1 = self.backbone.conv1(activation_post_process_0);  activation_post_process_0 = None
    activation_post_process_1 = self.activation_post_process_1(backbone_conv1);  backbone_conv1 = None
    ...
    head_angle_2 = getattr(self.head.angle, "2")(activation_post_process_83);  activation_post_process_83 = None
    activation_post_process_84 = self.activation_post_process_84(head_angle_2);  head_angle_2 = None
    return (activation_post_process_78, activation_post_process_80, activation_post_process_82, activation_post_process_84)

看下模型的部分结构如下,可以发现多出了HistogramObserver,都是activation_post_process_xx,用于观察激活值的分布信息。

(upsampler): Module(
    (deconv_layers): Module(
    (0): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): ReLU(inplace=True)
    )
)
(activation_post_process_71): HistogramObserver()
(activation_post_process_72): HistogramObserver()
(activation_post_process_73): HistogramObserver()
(activation_post_process_74): HistogramObserver()
(activation_post_process_75): HistogramObserver()
(activation_post_process_76): HistogramObserver()
(activation_post_process_77): HistogramObserver()
(head): Module(
    (hm): Module(
    (0): ConvReLU2d(
        (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
    )
    (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (wh): Module(
    (0): ConvReLU2d(
        (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
    )
    (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (reg): Module(
    (0): ConvReLU2d(
        (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
    )
    (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
)
(activation_post_process_78): HistogramObserver()
(activation_post_process_79): HistogramObserver()
(activation_post_process_80): HistogramObserver()
(activation_post_process_81): HistogramObserver()
(activation_post_process_82): HistogramObserver()
(activation_post_process_83): HistogramObserver()
)

收集过程中:

  • 激活层使用的是HistogramObserver
  • 权重层使用的是PerChannelMinMaxObserver

接下来就可以喂入数据进行推理校准了,和我们平常的方式一样,准备好图像数据然后可以组batch输入进去,此时的input会输入到我们模型的forward当中:

def forward(self, input):
    input_1 = input
    activation_post_process_0 = self.activation_post_process_0(input_1);  input_1 = None
    backbone_conv1 = self.backbone.conv1(activation_post_process_0);  activation_post_process_0 = None
    activation_post_process_1 = self.activation_post_process_1(backbone_conv1);  backbone_conv1 = None
    backbone_maxpool = self.backbone.maxpool(activation_post_process_1);  activation_post_process_1 = None
    ... 

第一行中,activation_post_process_0 = self.activation_post_process_0(input_1);,实际进入的是HistogramObserver这个观察者对象,其中的forward函数主要是收集min和max信息,最终返回的还是原始输入:

# HistogramObserver::forward
def forward(self, x_orig: torch.Tensor) -> torch.Tensor:
    if x_orig.numel() == 0:
        return x_orig
    x = x_orig.detach()
    min_val = self.min_val
    max_val = self.max_val
    same_values = min_val.item() == max_val.item()
    is_uninitialized = min_val == float("inf") and max_val == float("-inf")
    if is_uninitialized or same_values:
        min_val, max_val = torch.aminmax(x)
        self.min_val.resize_(min_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.resize_(max_val.shape)
        self.max_val.copy_(max_val)
        assert (
            min_val.numel() == 1 and max_val.numel() == 1
        ), "histogram min/max values must be scalar."
        torch.histc(
            x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram
        )
    else:
        new_min, new_max = torch.aminmax(x)
        combined_min = torch.min(new_min, min_val)
        combined_max = torch.max(new_max, max_val)
        # combine the existing histogram and new histogram into 1 histogram
        # We do this by first upsampling the histogram to a dense grid
        # and then downsampling the histogram efficiently
        (
            combined_min,
            combined_max,
            downsample_rate,
            start_idx,
        ) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
        assert (
            combined_min.numel() == 1 and combined_max.numel() == 1
        ), "histogram min/max values must be scalar."
        combined_histogram = torch.histc(
            x, self.bins, min=int(combined_min), max=int(combined_max)
        )
        if combined_min == min_val and combined_max == max_val:
            combined_histogram += self.histogram
        else:
            combined_histogram = self._combine_histograms(
                combined_histogram,
                self.histogram,
                self.upsample_rate,
                downsample_rate,
                start_idx,
                self.bins,
            )

        self.histogram.detach_().resize_(combined_histogram.shape)
        self.histogram.copy_(combined_histogram)
        self.min_val.detach_().resize_(combined_min.shape)
        self.min_val.copy_(combined_min)
        self.max_val.detach_().resize_(combined_max.shape)
        self.max_val.copy_(combined_max)
    return x_orig

推理过程中,仅仅涉及到激活层信息的收集,因为PTQ就是前向推理收集激活层信息,不涉及到权重的更新。但是QAT中模型权重会更新,不过这个后话了。

转化量化模型 convert

收集好信息后,我们需要将收集好的min-max转化为实际可用的scaleoffset

转换代码也很简单,调用FX提供的convert_fx,需要加入is_reference=True参数,这里表明我们转换后的量化模型仅仅是包含量化参数,但实际上运行的还是FP32的精度,这种模型是为了之后转换trt做准备。

quantized_fx = convert_fx(model, 
                is_reference=True,  # 选择reference模式
                )   
"""
细节看这里
We will convert an observed model (a module with observer calls) to a reference
quantized model, the rule is simple:
1. for each observer module call in the graph, we'll convert it to calls to
    quantize and dequantize functions based on the observer instance
2. for weighted operations like linear/conv, we need to convert them to reference
    quantized module, this requires us to know whether the dtype configured for the
    weight is supported in the backend, this is done in prepare step and the result
    is stored in observed_node_names, we can decide whether we need to swap the
    module based on this set
"""

那怎么处理呢?我们有很多activation_post_process_xx层,这些层是可以转化为quantize and dequantize层的,具体的函数调用看下面这段代码,其中调用了with graph.inserting_before(node)node.replace_all_uses_with等Graph Manipulation方法去对模型进行修改:

    def replace_observer_with_quantize_dequantize_node(
            model: torch.nn.Module,
            graph: Graph,
            node: Node,
            modules: Dict[str, torch.nn.Module],
            node_name_to_scope: Dict[str, Tuple[str, type]],
            qconfig_map: Dict[str, QConfigAny]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
        """
        assert modules is not None
        assert isinstance(node.target, str)
        module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, qconfig_map)
        observer_module = modules[node.target]
        maybe_quantize_node_info = get_quantize_node_info(observer_module)
        # Skip replacing observers to quant/dequant nodes if the qconfigs of all
        # consumers and producers of this observer are None
        skip_replacement = all([
            has_none_qconfig(n, qconfig_map) for n in
            list(node.args) + list(node.users.keys())])
        ...
        else:
            # otherwise, we can convert the observer moduel call to quantize/dequantize node
            node_type, quantize_op, qparams = maybe_quantize_node_info
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    # TODO: we can add the information of whether a value needs to
                    # be registered as an attribute in qparams dict itself
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)
                # 构建quantized_node和dequantized_node
                quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)

那权重层怎么处理?converter_fx在遍历整个模型的时候会区分该层的node是什么类型,如果是is_activation_post_process就会进入上面的replace_observer_with_quantize_dequantize_node函数,如果判断为权重则会进入convert_weighted_module函数:

    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            ...
        elif node.op == "output":
            ...
        elif node.op == "call_module":
            if is_activation_post_process(modules[node.target]):
                observed_node = node.args[0]
                if observed_node in statically_quantized_custom_module_nodes:
                    replace_observer_with_dequantize_node(node, model.graph)
                else:
                    replace_observer_with_quantize_dequantize_node(
                        model, model.graph, node, modules, node_name_to_scope,
                        qconfig_map)
            elif is_observed_standalone_module(modules[node.target]):
                convert_standalone_module(
                    node, modules, model, is_reference, backend_config_dict)
            elif type(modules[node.target]) in set(
                    root_module_classes).union(qat_module_classes).union(fused_module_classes):
                # extra check for fused module classes to make sure they are fused module classes
                # of target modules
                if type(modules[node.target]) in fused_module_classes and \
                   type(modules[node.target][0]) not in root_module_classes:
                    continue
                convert_weighted_module(
                    node, modules, observed_node_names, qconfig_map, backend_config_dict)
            elif type(modules[node.target]) in custom_module_classes:
                convert_custom_module(
                    node, model.graph, modules, custom_module_class_mapping,
                    statically_quantized_custom_module_nodes)

convert_weighted_module函数中主要就是处理weight的量化信息,首先根据设定好的config来进行处理,比如下面代码中的weight_post_process其实就是PerChannelMinMaxObserver对象,执行的时候会收集该层权重的min-max信息,收集好之后通过get_qparam_dict计算出scale和offset并存入wq_or_wq_dict中:

# pytorch/torch/ao/quantization/fx/convert.py
    ...
    else:
        # weight_post_process is None means the original module is not a QAT module
        # we need to get weight_post_process from qconfig in this case
        if weight_post_process is None:
            weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
        # run weight observer
        # TODO: This is currently a hack for QAT to get the right shapes for scale and zero point.
        # In the future, we should require the user to calibrate the model after calling prepare
        # Issue: https://github.com/pytorch/pytorch/issues/73941
        weight_post_process(float_module.weight)  # type: ignore[operator]
        wq_or_wq_dict = get_qparam_dict(weight_post_process)

    # We use the same reference module for all modes of quantization: static, dynamic, weight_only
    # root_module_to_quantized_reference_module: module mapping from root (floating point) module class
    # to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d
    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config_dict)
    ref_qmodule_cls = root_module_to_quantized_reference_module.get(type(float_module), None)
    assert ref_qmodule_cls is not None, f"No reference quantized module class configured for {type(float_module)}"
    ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)  # type: ignore[attr-defined]
    if fused_module is not None:
        fused_module[0] = ref_qmodule  # type: ignore[operator]
    else:
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, ref_qmodule)

得到权重层的wq_or_wq_dict信息后,通过get_root_module_to_quantized_reference_module获取FP32-OP对应的量化版本OP,对应关系如下:

<class 'torch.nn.modules.conv.Conv1d'>: <class 'torch.nn.quantized._reference.modules.conv.Conv1d'>
<class 'torch.nn.modules.conv.ConvTranspose1d'>: <class 'torch.nn.quantized._reference.modules.conv.ConvTranspose1d'>
<class 'torch.nn.modules.conv.Conv2d'>: <class 'torch.nn.quantized._reference.modules.conv.Conv2d'>
<class 'torch.nn.modules.conv.ConvTranspose2d'>: <class 'torch.nn.quantized._reference.modules.conv.ConvTranspose2d'>
...

如果是conv2d,则ref_qmodule_clstorch.nn.quantized._reference.modules.conv.Conv2d,通过from_float(float_module, wq_or_wq_dict)传入FP32版本的conv2d-op,通过fp32版本的参数和之前收集好的wq_or_wq_dict构建量化版本的conv2d,直接替换模型中的FP32版本的op,此时模型中conv2d -> quantized-reference-conv2d,卷积和反卷积都变成了reference版本。

最终的reference量化模型

经过以上步骤,经过convert_fx后的模型,怎么说,其实就是simulator quantization,也就是模拟量化,我们校准得到的scale和offset用于模拟模型的量化误差,实际模型执行的时候是这样:

def forward(self, input):
    input_1 = input
    # 首先得到量化参数scale和zero-point
    backbone_conv1_input_scale_0 = self.backbone_conv1_input_scale_0
    backbone_conv1_input_zero_point_0 = self.backbone_conv1_input_zero_point_0
    # 然后量化输入
    quantize_per_tensor = torch.quantize_per_tensor(input_1, backbone_conv1_input_scale_0, backbone_conv1_input_zero_point_0, torch.qint8);  
    input_1 = backbone_conv1_input_scale_0 = backbone_conv1_input_zero_point_0 = None
    # 然后反量化输入
    dequantize = quantize_per_tensor.dequantize();  quantize_per_tensor = None
    backbone_conv1 = self.backbone.conv1(dequantize);  dequantize = None
    ...
    dequantize_80 = quantize_per_tensor_83.dequantize();  quantize_per_tensor_83 = None
    head_angle_2 = getattr(self.head.angle, "2")(dequantize_80);  dequantize_80 = None
    head_angle_2_output_scale_0 = self.head_angle_2_output_scale_0
    head_angle_2_output_zero_point_0 = self.head_angle_2_output_zero_point_0
    quantize_per_tensor_84 = torch.quantize_per_tensor(head_angle_2, head_angle_2_output_scale_0, head_angle_2_output_zero_point_0, torch.qint8);  head_angle_2 = head_angle_2_output_scale_0 = head_angle_2_output_zero_point_0 = None
    dequantize_81 = quantize_per_tensor_78.dequantize();  quantize_per_tensor_78 = None
    dequantize_82 = quantize_per_tensor_80.dequantize();  quantize_per_tensor_80 = None
    dequantize_83 = quantize_per_tensor_82.dequantize();  quantize_per_tensor_82 = None
    dequantize_84 = quantize_per_tensor_84.dequantize();  quantize_per_tensor_84 = None
    return {'hm': dequantize_81, 'wh': dequantize_82, 'reg': dequantize_83, 'angle': dequantize_84}

看一下converter后reference模型结构,可以看到该融合的都融合了,所有conv带有参数的计算层都替换了为Quantizedxxxx(Reference)版本,其他比如maxpooling和add、concat的不需要变动,到时候在转trt的时候,在trt内部会进行处理:

GraphModule(
  (backbone): Module(
    (conv1): ConvReLU2d(
      (0): QuantizedConv2d(Reference)(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (1): ReLU(inplace=True)
    )
    )
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Module(
      (0): Module(
        (conv1): ConvReLU2d(
          (0): QuantizedConv2d(Reference)(64, 64, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv2): ConvReLU2d(
          (0): QuantizedConv2d(Reference)(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv3): QuantizedConv2d(Reference)(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (downsample): Module(
          (0): QuantizedConv2d(Reference)(64, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (relu): ReLU(inplace=True)
      )
    )
      ...
  (upsampler): Module(
    (deconv_layers): Module(
      (0): QuantizedConv2d(Reference)(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): QuantizedConvTranspose2d(Reference)(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ReLU(inplace=True)
      ...
  )
  (head): Module(
    (hm): Module(
      (0): ConvReLU2d(
        (0): QuantizedConv2d(Reference)(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (2): QuantizedConv2d(Reference)(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
  )
    ...
))

同时也可以看此时模型的IR信息:

opcode         name                                                        target                                                                  args                                                                                                                                                                      kwargs
-------------  ----------------------------------------------------------  ----------------------------------------------------------------------  ------------------------------------------------------------------------------------------------------------------------------------------------------------------------  --------
placeholder    input_1                                                     input                                                                   ()                                                                                                                                                                        {}
get_attr       backbone_base_base_layer_0_input_scale_0                    backbone_base_base_layer_0_input_scale_0                                ()                                                                                                                                                                        {}
get_attr       backbone_base_base_layer_0_input_zero_point_0               backbone_base_base_layer_0_input_zero_point_0                           ()                                                                                                                                                                        {}
call_function  quantize_per_tensor                                         <built-in method quantize_per_tensor of type object at 0x7f4f0d8491a0>  (input_1, backbone_base_base_layer_0_input_scale_0, backbone_base_base_layer_0_input_zero_point_0, torch.qint8)                                                           {}
call_method    dequantize                                                  dequantize                                                              (quantize_per_tensor,)                                                                                                                                                    {}
call_module    backbone_base_base_layer_0                                  backbone.base.base_layer.0                                              (dequantize,)                                                                                                                                                             {}
get_attr       backbone_base_base_layer_0_output_scale_0                   backbone_base_base_layer_0_output_scale_0                               ()                                                                                                                                                                        {}
get_attr       backbone_base_base_layer_0_output_zero_point_0              backbone_base_base_layer_0_output_zero_point_0                          ()                                                                                                                                                                        {}
call_function  quantize_per_tensor_1                                       <built-in method quantize_per_tensor of type object at 0x7f4f0d8491a0>  (backbone_base_base_layer_0, backbone_base_base_layer_0_output_scale_0, backbone_base_base_layer_0_output_zero_point_0, torch.qint8)                                      {}
call_method    dequantize_1                                                dequantize                                                              (quantize_per_tensor_1,)                                                                                                                                                  {}
call_module    backbone_base_level0_0                                      backbone.base.level0.0                                                  (dequantize_1,)                                                                                                                                                           {}
get_attr       backbone_base_level0_0_output_scale_0                       backbone_base_level0_0_output_scale_0                                   ()                                                                                                                                                                        {}
get_attr       backbone_base_level0_0_output_zero_point_0                  backbone_base_level0_0_output_zero_point_0                              ()                                                                                                                                                                        {}
call_function  quantize_per_tensor_2                                       <built-in method quantize_per_tensor of type object at 0x7f4f0d8491a0>  (backbone_base_level0_0, backbone_base_level0_0_output_scale_0, backbone_base_level0_0_output_zero_point_0, torch.qint8)                                                  {}
call_method    dequantize_2                                                dequantize                                                              (quantize_per_tensor_2,)                                                                                                                                                  {}
call_module    backbone_base_level1_0                                      backbone.base.level1.0                                                  (dequantize_2,)                                                                                                                                                           {}
get_attr       backbone_base_level1_0_output_scale_0                       backbone_base_level1_0_output_scale_0                                   ()                                                                                                                                                                        {}
get_attr       backbone_base_level1_0_output_zero_point_0                  backbone_base_level1_0_output_zero_point_0                              ()                                                                                                                                                                        {}
call_function  quantize_per_tensor_3                                       <built-in method quantize_per_tensor of type object at 0x7f4f0d8491a0>  (backbone_base_level1_0, backbone_base_level1_0_output_scale_0, backbone_base_level1_0_output_zero_point_0, torch.qint8)                                                  {}
call_method    dequantize_3                                                dequantize                                                              (quantize_per_tensor_3,)                                                                                                                                                  {}
call_module    backbone_base_level2_downsample                             backbone.base.level2.downsample                                         (dequantize_3,)                                                                                                                                                           {}
get_attr       backbone_base_level2_downsample_output_scale_0              backbone_base_level2_downsample_output_scale_0                          ()                                                                                                                                                                        {}
get_attr       backbone_base_level2_downsample_output_zero_point_0         backbone_base_level2_downsample_output_zero_point_0                     ()                                                                                                                                                                        {}
call_function  quantize_per_tensor_4                                       <built-in method quantize_per_tensor of type object at 0x7f4f0d8491a0>  (backbone_base_level2_downsample, backbone_base_level2_downsample_output_scale_0, backbone_base_level2_downsample_output_zero_point_0, torch.qint8)                       {}                                                                                                                                            

至此,我们就得到了量化后的模型,这个模型的类型是GraphModule,和nn.Module类似,有对应的forward函数。我们可以直接在Pytorch中执行这个模型测试精度,不过需要注意,这里仅仅是测试模拟的量化模型精度,也是测试校准后得到的scale和offset有没有问题,在转化为TensorRT后精度可能会略有差异,毕竟实际推理框架内部实现的一些算子细节我们是不知道的。

type(quantized_fx)
<class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>

再提一句,其实目前FX的量化对于TensorRT的转换是比较友好的,FX只需要把量化后模型的QDQ层相应地转化为trt的QDQ层即可,这时TensorRT的network中会包含TensorRT定义的QDQ层,TensorRT内部会对QDQ层进行自动优化,最终生成的engine中QDQ中的参数已经被吸进其它层中,也算是图优化过程的一部分。

运行模拟量化模型

我这边简单在COCO数据集上测试了下量化前后的Centernet模型精度,直接测试的mAP,精度误差相差在1%以内,一般来说检测模型在1%以内都算正常。

再强调下,我这里的模型在量化后默认是reference模式,也就是模拟量化的方式(因为之后要转为TensorRT),此时的量化模型运行的精度还是FP32,只不过模型中的算子会在计算时进行quantizedequantize的操作。

为啥要这样搞,这样搞可以方便地不需要硬件(也就是可以实际运行INT8指令集的硬件)便可以模拟量化误差,方便定位问题,如果模拟量化过程中就已经有问题了,那么在硬件上运行肯定也有问题。但反之则不然,如果在硬件上运行发现精度不够,但是模拟量化的精度够,那就是INT8算子实现的bug问题了。

conv2d举例子,Pytorch模拟量化的算子在pytorch/torch/nn/quantized/_reference/modules/目录下:

class Conv2d(_ConvNd, nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros',
                 device=None,
                 dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None):
        nn.Conv2d.__init__(
            self, in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.conv2d ---

        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.conv2d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv2d
        """
        weight_quant_dequant = self.get_weight()  # 对权重进行量化以及反量化操作
        result = F.conv2d(
            x, weight_quant_dequant, self.bias, self.stride,
            self.padding, self.dilation, self.groups)
        return result

    def _get_name(self):
        return "QuantizedConv2d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvNd.from_float(cls, float_conv, weight_qparams)

forward输入的input是上一层quantize + dequantize后的input,权重也是quantize + dequantize的权重,而执行的conv2d是FP32实现的,体现了一个模拟的过程。我们也可以补充一个forward_fp32成员方法,使用原始的FP32权重就可以,来实现非量化的操作,用于作对比。

DEBUG 精度

利用reference模型,我们可以自己写个简单的小工具,来跑一下模拟量化模型的每一层精度怎么样如何。参照官方教程中的ShapeProp类,我们可以模仿着写一个:

import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
    """
    Shape propagation. This class takes a `GraphModule`.
    Then, its `propagate` method executes the `GraphModule`
    node-by-node with the given arguments. As each operation
    executes, the ShapeProp class stores away the shape and
    element type for the output values of each operation on
    the `shape` and `dtype` attributes of the operation's
    `Node`.
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        # 主要修改以下部分
        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            # This is the only code specific to shape propagation.
            # you can delete this `if` branch and this becomes
            # a generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

        return load_arg(self.graph.result)

将for循环中的推理部分修改为,其中forward_fp32是上节提到补充的FP32实现方法,用于作对比:

# op_sim即当前的node-op
result_fp32_layer = op_sim.forward_fp32(*load_arg(node.args), **load_arg(node.kwargs))
result_int8_layer = op_sim(*load_arg(node.args), **load_arg(node.kwargs))
result_fp32_model = op_sim.forward_fp32(*load_arg_fp32(node.args), **load_arg_fp32(node.kwargs))
activation_dif_accmulated = torch_cosine_similarity(result_int8_layer, result_fp32_model)
activation_dif_layer = torch_cosine_similarity(result_int8_layer, result_fp32_layer)
weight_dif = torch_cosine_similarity(op_sim.weight, op_sim.get_weight())

对比三个地方:

  • 当前激活层FP32-INT8误差
  • 当前激活层FP32-INT8累计误差
  • 当前层权重误差

以下是COCO数据集在Centernet下的精度对比信息,一般来说,余弦相似度大于0.99就问题不大:

Quantize similarity : 
dequantize [activation_dif_layer:0.9945, activation_dif_accmulated:0.9945]
backbone_conv1 [activation_dif_layer:1.0000, activation_dif_accmulated:0.9975, weight_dif:1.0000]
dequantize_1 [activation_dif_layer:0.9999, activation_dif_accmulated:0.9978]
backbone_maxpool [activation_dif_layer:0.9999, activation_dif_accmulated:0.9978, weight_dif:1.0000]
dequantize_2 [activation_dif_layer:1.0000, activation_dif_accmulated:0.9989]
backbone_layer1_0_conv1 [activation_dif_layer:0.9999, activation_dif_accmulated:0.9983, weight_dif:1.0000]
dequantize_3 [activation_dif_layer:0.9999, activation_dif_accmulated:0.9987]
backbone_layer1_0_conv2 [activation_dif_layer:1.0000, activation_dif_accmulated:0.9991, weight_dif:0.9999]
dequantize_4 [activation_dif_layer:0.9999, activation_dif_accmulated:0.9987]
...

模型可视化

TORCH-FX提供了使用graphviz画FX模型的可视化工具——FxGraphDrawer,直接调用以下接口就可以画当前的FX模型了:

g = FxGraphDrawer(quantized_fx, "centernet_fx_quantize")
g.get_main_dot_graph().write_svg("centernet_fx_quantize.svg")

我们来展示下插入量化观察节点的模型:

《TORCH.FX第二篇——PTQ量化实操》

这个是经过converter融合后,带有QDQ的模型:

《TORCH.FX第二篇——PTQ量化实操》

后记

Pytorch.fx是个有潜力的工具,量化功能做的也不错,但是实际使用中仍然有很多局限性,很多功能不完善,有一些bug需要自己去趟。

我自己使用今年2月份的FX可以成功量化模型以及部署TensorRT,但是隔了几个月再更新就发现变了很多,需要自己花点精力再去同步下。个人感觉FX目前适合尝鲜或者动手能力强一点的人去用,适合折腾。

下一篇会继续写如何将FX量化后的模型转化为TensorRT,当然还是会有坑,不过,下篇文章见吧~

参考资料

  点赞
本篇文章采用 署名-非商业性使用-禁止演绎 4.0 国际 进行许可
转载请务必注明来源: https://oldpan.me/archives/torch-fx-second-quantize-with-fx

   关注Oldpan博客微信公众号,你最需要的及时推送给你。


  1. y
    youyou说道:

    大佬,写一篇基于FX的QAT量化不?

  2. m
    moonriver说道:

    大佬,能不能写一篇基于FX的QAT量化?

  3. l
    lk说道:

    很强很强,学习了!!!

发表评论

邮箱地址不会被公开。 必填项已用*标注

评论审核已启用。您的评论可能需要一段时间后才能被显示。