点击小眼睛开启蜘蛛网特效

利用Caffe推理CenterNet(上篇)

《利用Caffe推理CenterNet(上篇)》

本文主要内容是记录一下将CenterNet模型转化为Caffe模型,并且成功推理的过程。虽然Caffe用的不多了吧,但是作为C++端的推理框架还是有不小的用武之地的,所以说本篇也可以称为CenterNet的C++/CUDA推理教程。

因为文章较长,所以这个实现部分共分为两个章节,本章节主要讲解将ResNet-50转换为Caffemodel,并且用C++代码和cuda代码将后处理步骤实现出来,并且成功预测的内容。

下一篇则会讲解如何将CenterNet的后处理过程全部并入Caffe的Layer当中,更优雅地实现整个流程。

Pytorch->Caffe

假设我们已经训练好了一个CenterNet模型。这里我拿ResNet50作为例子。

因为Caffe中没有DCN也就是可形变卷积层,当然也没有注册,正常来将对于CenterNet中的dla34这种包含卷积层的backbone肯定是无法转换的。不过我们可以自行为Caffe添加DCN算子从而实现这个转换过程,不过可能稍微麻烦些。

模型转换是直接将Pytorch的模型转换为Caffe(当然也可以将Pytorch首先转化为ONNX再转化到Caffe,未尝试)。其中,转换脚本使用的是 https://github.com/xxradon/PytorchToCaffe 这个仓库中的。

按照readme中的要求,我们通过conda命令安装pycaffe,然后执行转换脚本即可:

import torch
import torch.onnx
import pytorch_to_caffe

...

model = create_model(opt.arch, opt.heads, opt.head_conv)
model = load_model(model, model_path)
model = model.cpu()
model.eval()

input = torch.ones([1, 3, 512, 512])
pytorch_to_caffe.trans_net(model, input, 'res50')
pytorch_to_caffe.save_prototxt('{}.prototxt'.format('res50'))
pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format('res50'))

导出模型即可,这里虽然指明了输入维度,但是在之后caffe之后是可以通过reshape操作动态改变输入维度的。

可以利用这个转换仓库中的验证.py看一下转出的模型和Pytorch版本的输出是否一致,如果嫌麻烦的话可以简单利用netron查看下转换前和转化后的模型参数看看是否一致。

Caffe读取模型

经过Pytorch转出的prototxt的pooling层有一个不一样的地方,有一个参数ceil_mode表示当前的池化层是向上取整还是向下取整,在Pytorch中默认为向下取整,而在Caffe中默认是向上取整;另外,因为我使用的caffe版本并没有ceil_mode这个参数,如果直接读取模型会报错,因此需要修改一些Caffe源码。

修改的地方一共三个:

  • src/caffe/proto/caffe.proto
  • include/caffe/layers/pooling_layer.hpp
  • src/caffe/layers/pooling_layer.cpp

修改教程可以看这篇:https://blog.csdn.net/qq_38451119/article/details/82252027 这里不赘述了。
需要注意的是添加ceil_mode之后一定记得把初始化代码也加上,也就是在pooling_layer.cppLayerSetUp函数中添加一句:ceil_mode_ = pool_param.ceil_mode();。否则ceil_mode_这个布尔变量是随机值,每次运行模型池化后的大小都不一样…别问我为啥知道。

《利用Caffe推理CenterNet(上篇)》

Caffe进行推断

以下的测试代码通过Caffe-C++版进行验证,这里参照的代码为 https://github.com/hexiangquan/CenterNetCPP ,感谢这位仁兄的无私贡献。不过这个仓库中的后处理代码略微有些bug,在输入维度长宽不一致的时候会出现问题,这里我对其后处理部分进行了修改,从而可以支持任意尺度的输入。

CenterNet的后处理步骤

CenterNet的模型推断过程复杂的主要是后处理部分,先来回顾一下。CenterNet的后处理与Anchor-base的目标检测框架不同,其余基于Anchor的检测框架最终输出的直接是box位置。而CenterNet最终输出的是3个数据数组,分别代表热点图hm数组、wh数组以及偏置reg数组。我们需要对这些输出的数据进行一些额外操作之后才可以转化为我们需要的box。

这3个数组数据分别对应训练过程中设定好的head-heads: {'hm': 2, 'wh': 2, 'reg': 2},这里'hm': 2代表有检测的目录为两类。

顺序步骤为:

  • 将关键点进行sigmoid操作,将关键点信息转化为0-1的格式;
  • 对关键点信息进行最大池化操作,过滤掉一些多余的关键点;
  • 对过滤后的关键点进行排序操作,取topK个关键点;
  • 根据取出的关键点信息和相应的坐标信息结合wh和reg信息得到相应的box信息

注意,只有关键点信息中包含了检测的类别信息,wh和reg数组在每个坐标点中有相应的信息,而这个坐标点的类别由这个坐标所应对的关键点决定。

官方的后处理代码:

# src/lib/detectors/ctdet.py
with torch.no_grad():
    output = self.model(images)[-1]
    # 首先对关键点数组做sigmoid
    hm = output['hm'].sigmoid_()
    wh = output['wh']
    reg = output['reg'] if self.opt.reg_offset else None
    ...
    # 再放入后处理步骤中对输出结果进行调整
    dets = ctdet_decode(hm, wh, reg=reg, cat_spec_wh=self.opt.cat_spec_wh, K=self.opt.K)

其中,CenterNet中的后处理使用最大池化层充当_nms操作,意义在于将多余的热点信息过滤掉。

def _nms(heat, kernel=3):
    pad = (kernel - 1) // 2
    # hmax为最大池化后的特征图,通过padding后与输入的heat特征图大小一致
    hmax = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=pad)
    # keep索引存放hmax于heat相同的部分,也就是判断当前点是否为最大点
    keep = (hmax == heat).float()

    return heat * keep

需要注意上述代码中keep = (hmax == heat).float()是将池化后的特征图与原始特征图进行比较,并选取相同的点保留下来。

为此,我们需要在原始模型结构上添加一些新的处理层从而方便后处理操作,原始的最后三个输出层如下图所示,分别代表hm、wh、reg三个输出,这时的hm输出尚未经过sigmoid的处理:

这个就是原始的通过Pytorch转换后的caffe模型的最后几层部分:

《利用Caffe推理CenterNet(上篇)》

因为Caffe的模型是通过prototxt决定的,我们可以比较简单地在hm输出层后添加两个额外的层,一个sigmoid和一个最大池化层,这两个层加在hm输出层之后:

《利用Caffe推理CenterNet(上篇)》

这两个层对应着官方的后处理操作,对hm特征图首先进行sigmoid,其次最大池化。这两个层不需要参数所以可以直接在prototxt中添加,只在推理阶段使用。

这样的话,我们就会得到三个输出,分别是经过处理的hm层,wh层和reg层。

  // 经过处理的hm层
  Blob<float>* hm_result_blob = net_->output_blobs()[2];
  const float* hm_result = hm_result_blob->cpu_data();
  vector<int> hm_shape = hm_result_blob->shape();
  //wh层
  Blob<float>* result_blob1 = net_->output_blobs()[0];
  const float* result1 = result_blob1->cpu_data();
  //reg层
  Blob<float>* result_blob0 = net_->output_blobs()[1];
  const float* result0 = result_blob0->cpu_data();

  // 这一层是没有max-pooling的heat输出 用于作对比
  boost::shared_ptr<caffe::Blob<float>> layerData = net_->blob_by_name("sigmoid_blob55");  // 获得指定层的输出
  const float* pstart = layerData->cpu_data();

  // 这里相当于 之前所说的keep 操作
  int classes = hm_shape[1];
  int feature_size = hm_shape[2]*hm_shape[3];

  vector<vector<float>> fscore_max;
  for(int j=0; j < classes; j++)   //class
  {
    for(int k=0; k<feature_size; k++)
      if(pstart[j*feature_size+k]
          == hm_result[j*feature_size+k])
      {
        vector<float> inx;
        inx.push_back(j*feature_size+k);         // 位置
        inx.push_back(pstart[j*feature_size+k]); // 分数
        fscore_max.push_back(inx);
      }
  }

  std::sort(fscore_max.begin(), fscore_max.end(),[](const std::vector<float>& a, const std::vector<float>& b){ return a[1] > b[1];});
  // get top 100
  int iters = std::min<int>(fscore_max.size(), 100);
  int only_threshbox=0;
  for(int i=0;i<iters;i++)
  {
    // 这里根据阈值进行筛选
    if(fscore_max[i][1]<thresh)
    {
      break;
    }
    only_threshbox++;
  }

  // 这里进行行了一些修改,使得对于任意尺寸输入不会出错
  vector<vector<float>> boxes;
  for(int i = 0; i < only_threshbox;i++)
  {
    vector<float> box;
    int index = ((int)fscore_max[i][0]) / (hm_shape[2] * hm_shape[3]);
    int center_index = ((int)fscore_max[i][0]) % (hm_shape[2]*hm_shape[3]) - hm_shape[3];
    int cls = index;

    float xs= center_index % hm_shape[3];
    float ys= int(center_index / hm_shape[3] ) % hm_shape[2];

    //reg batch 1
    xs += result0[(int)(((int)ys)*hm_shape[3] + xs)];
    ys += result0[(int)(hm_shape[3]*hm_shape[2]+((int)ys)*hm_shape[3]+xs)];

    float w = result1[(int)(((int)ys)*hm_shape[3]+xs)];
    float h = result1[(int)(hm_shape[2]*hm_shape[3]+((int)ys)*hm_shape[3]+xs)];

    box.push_back((float)cls);
    box.push_back((float)fscore_max[i][1]);

    box.push_back((float)(xs-w/2.0));
    box.push_back((float)(ys-h/2.0));
    box.push_back((float)(xs+w/2.0));
    box.push_back((float)(ys+h/2.0));
    // 输出四个点
    boxes.push_back(box);
  }

最终得到的结果放入vector<float> box当中,这个结果是没有经过NMS的,需要的可以自行添加,CPU版的NMS代码如下:

vector<vector<float> > CenterNet_Detector::apply_nms(vector<vector<float>> &box, float thres)
{
  vector<vector<float> > rlt;
  if (box.empty())
    return vector<vector<float> >();

  std::sort(box.begin(), box.end(),[](const std::vector<float>& a, const std::vector<float>& b){ return a[1] > b[1];});
  std::vector<int> pindex;

  for(int i=0;i<box.size();i++)
  {
    if(std::find(pindex.begin(),pindex.end(),i)!=pindex.end())
    {
      continue;
    }
    vector<float> truth = box[i];
    for(int j=i+1;j<box.size();j++)
    {
      if(std::find(pindex.begin(),pindex.end(),j)!=pindex.end())
      {
        continue;
      }

      vector<float> lbox = box[j];
      float iou = cal_iou(lbox, truth);
      if(iou >= thres)
        pindex.push_back(j);//p[j] = 1

    }
  }
  for(int i=0;i<box.size();i++)
  {
    if(std::find(pindex.begin(),pindex.end(),i)==pindex.end())
    {
      rlt.push_back(box[i]);
    }
  }
  return rlt;
}

将后处理步骤使用CUDA运算

虽然验证了模型的正确,但上述CPU版的后处理比较比较慢,对于1024*564的输入图像,上面keep操作的两个for循环耗时50ms,非常耗时。因此我们需要将该操作移植到CUDA中,利用GPU去运行。

这里借鉴https://github.com/CaoWGG/TensorRT-CenterNet/中的代码,这里的后处理步骤与之前的类似,不过没有topK的操作,其实topK与阈值的选取有一些关系。并且topK有一个比较隐晦的隐患,对于比较小的图像,当下采样后的特征图总size大小比topK小的时候,topK就无法正常执行(在官方的repository中是这样的)。

GPU版的后处理代码如下,修改了一下CaoWGG的代码使其支持任意维度:

#include "PostprocessLayer.h"

__device__ float Logist(float data){ return 1./(1. + exp(-data)); }

__global__ void CTdetforward_kernel(const float *hm, const float *reg,const float *wh ,
    float *output,const int w,const int h,const int classes,const int kernel_size,const float visthresh  ) {

int idx = blockDim.x * blockIdx.x + threadIdx.x;
if (idx >= w * h * classes) return;
int padding = (kernel_size - 1) / 2;
int offset = -padding;
int stride = w * h;
int grid_x = idx % w;
int grid_y = (idx / w) % h;
int cls = idx/w/h ;
int  l, m;
int reg_index = idx - cls*stride;
float c_x, c_y;
float objProb = Logist(hm[idx]);
if (objProb > visthresh) {
    float max = -1;
    int max_index = 0;
    for (l = 0; l < kernel_size; ++l)
        for (m = 0; m < kernel_size; ++m) {
            int cur_x = offset + l + grid_x;
            int cur_y = offset + m + grid_y;
            int cur_index = cur_y * w + cur_x + stride * cls;
            int valid = (cur_x >= 0 && cur_x < w && cur_y >= 0 && cur_y < h);
            float val = (valid != 0) ? Logist(hm[cur_index]) : -1;
            max_index = (val > max) ? cur_index : max_index;
            max = (val > max) ? val : max;
        }

    if(idx == max_index){
        // 原子操作,在指针的第一个地址的数据上加1,用于累加挑选后的框的数量
        int resCount = (int) atomicAdd(output, 1);
        // 这里每次将获取到的box数据放入指定的位置 
        char *data = (char *) output + sizeof(float) + resCount * sizeof(Detection);
        Detection *det = (Detection *) (data);
        c_x = grid_x + reg[reg_index];
        c_y = grid_y + reg[reg_index + stride];
        det->bbox.x1 = (c_x - wh[reg_index] / 2) * 4;
        det->bbox.y1 = (c_y - wh[reg_index + stride] / 2) * 4;
        det->bbox.x2 = (c_x + wh[reg_index] / 2) * 4;
        det->bbox.y2 = (c_y + wh[reg_index + stride] / 2) * 4;
        det->classId = cls;
        det->prob = objProb;
    }
}
}

void CTdetforward_gpu(const float *hm, const float *reg,const float *wh ,float *output,
    const int w,const int h,const int classes,const int kernerl_size, const float visthresh ){
uint num = w * h * classes;
int nbrBlocks = ceil((float)num / (float)BLOCK);
CTdetforward_kernel<<<nbrBlocks,BLOCK>>>(hm,reg,wh,output,w,h,classes,kernerl_size,visthresh);
}

在将后处理代码写好之后,将其编译出来即可:nvcc -o libpostprocess.so -shared -Xcompiler -fPIC PostprocessLayer.cu -arch=sm_61 -std=c++11。这里我的GPU是1080TI,如果是2080TI,需要将计算能力修改为75。

改成cuda版本后,在1080TI上,后处理的耗时0.3ms左右,此时我们的推理代码需需要修改为:

Blob<float>* input_layer = net_->input_blobs()[0];
typedef std::chrono::duration<double, std::ratio<1, 1000>> ms;
auto total_t0 = std::chrono::high_resolution_clock::now();

auto t0 = std::chrono::high_resolution_clock::now();
input_layer->Reshape(1, num_channels_, img.rows, img.cols);
input_geometry_ = cv::Size(input_layer->width(), input_layer->height());
net_->Reshape();

auto t1 = std::chrono::high_resolution_clock::now();
double reshape_time = std::chrono::duration_cast<ms>(t1 - t0).count();
std::cout << "Caffe Reshape time: " << std::fixed << std::setprecision(2)
          << reshape_time << " ms ";

t0 = std::chrono::high_resolution_clock::now();
std::vector<cv::Mat> input_channels;
WrapInputLayer(&input_channels);
cv::Mat tm = Preprocess(img, &input_channels);
t1 = std::chrono::high_resolution_clock::now();
double preprocess_time = std::chrono::duration_cast<ms>(t1 - t0).count();
std::cout << "Preprocess time: " << std::fixed << std::setprecision(2)
          << preprocess_time << " ms" << std::endl;

t0 = std::chrono::high_resolution_clock::now();
net_->Forward();
t1 = std::chrono::high_resolution_clock::now();
double net_time = std::chrono::duration_cast<ms>(t1 - t0).count();
std::cout << "Net processing time: " << std::fixed << std::setprecision(2)
          << net_time << " ms ";

  Blob<float>* hm_result_blob = net_->output_blobs()[2];
const float* hm_result = hm_result_blob->cpu_data();
vector<int> hm_shape = hm_result_blob->shape();

Blob<float>* result_blob1 = net_->output_blobs()[0];
const float* result1 = result_blob1->cpu_data();
  vector<int> wh_shape = result_blob1->shape();

Blob<float>* result_blob0 = net_->output_blobs()[1];
const float* result0 = result_blob0->cpu_data();
vector<int> reg_shape = result_blob0->shape();

boost::shared_ptr<caffe::Blob<float>> layerData1 = net_->blob_by_name("conv_blob55");  
// 得到没有经过任何后处理的原始heat map
const float* hm_raw = layerData1->cpu_data();

boost::shared_ptr<caffe::Blob<float>> layerData = net_->blob_by_name("sigmoid_blob55");  // 获得指定层的输出
const float* pstart = layerData->cpu_data();

/*cuda part*/
std::vector<void*> mCudaBuffers(3);
std::vector<int64_t> mBindBufferSizes(3);  //  保存size
int64_t totalSize = 0;
int64_t outputBufferSize;
void * cudaOutputBuffer;

t0 = std::chrono::high_resolution_clock::now();

cudaEvent_t start1;
cudaEventCreate(&start1);
cudaEvent_t stop1;
cudaEventCreate(&stop1);
cudaEventRecord(start1, NULL);

totalSize = hm_shape[1] * hm_shape[2] * hm_shape[3] * sizeof(float);
mCudaBuffers[0] = safeCudaMalloc(totalSize);
mBindBufferSizes[0] = totalSize;

totalSize = wh_shape[1] * wh_shape[2] * wh_shape[3] * sizeof(float);
mCudaBuffers[1] = safeCudaMalloc(totalSize);
mBindBufferSizes[1] = totalSize;

totalSize = reg_shape[1] * reg_shape[2] * reg_shape[3] * sizeof(float);
mCudaBuffers[2] = safeCudaMalloc(totalSize);
mBindBufferSizes[2] = totalSize;

outputBufferSize = mBindBufferSizes[0] * 6 ;
cudaOutputBuffer = safeCudaMalloc(outputBufferSize);
CUDA_CHECK(cudaMemcpy(mCudaBuffers[0], hm_raw, mBindBufferSizes[0], cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(mCudaBuffers[1], result1, mBindBufferSizes[1], cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(mCudaBuffers[2], result0, mBindBufferSizes[2], cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemset(cudaOutputBuffer, 0, sizeof(float))); 

std::unique_ptr<float[]> outputData(new float[outputBufferSize]);

CTdetforward_gpu(static_cast<const float *>(mCudaBuffers[0]),static_cast<const float *>(mCudaBuffers[2]),
                        static_cast<const float *>(mCudaBuffers[1]),static_cast<float *>(cudaOutputBuffer),
                            hm_shape[3],hm_shape[2] ,2, 3, 0.2);

CUDA_CHECK(cudaDeviceSynchronize());
// std::cout << "cuda excute "  << std::endl;
CUDA_CHECK(cudaMemcpy(outputData.get(), cudaOutputBuffer, outputBufferSize, cudaMemcpyDeviceToHost));
int num_det = static_cast<int>(outputData[0]);
std::vector<Detection> result;
result.resize(num_det);
memcpy(result.data(), &outputData[1], num_det * sizeof(Detection));

// std::cout << num_det << std::endl;

t1 = std::chrono::high_resolution_clock::now();
double postprocess_time = std::chrono::duration_cast<ms>(t1 - t0).count();
std::cout << "postprocess_time: " << std::fixed << std::setprecision(2)
          << postprocess_time << " ms ";

double total_time = std::chrono::duration_cast<ms>(t1 - total_t0).count();

经验证结果是正确的,这里就不展示啦。至此,我们将CenterNet转化为Caffe并且添加了后处理操作,下一篇文章中将会将上述的后处理操作移动至Caffe的层中,更为优雅地实现相关功能。

参考链接

https://github.com/ouyanghuiyu/centernet_mobilenetv2_ncnn/blob/master/cpp/ncnn_centernet.cpp
https://stackoverflow.com/questions/50795108/cuda-kernel-not-called-by-all-blocks
https://blog.csdn.net/zhouliyang1990/article/details/38094709
https://github.com/610265158/mobilenetv3_centernet
https://sourcegraph.com/github.com/Stick-To/CenterNet-tensorflow@master
https://cloud.tencent.com/developer/ask/114524 如何测量NVIDIA CUDA的内核时间?
https://stackoverflow.com/questions/11888772/when-to-call-cudadevicesynchronize
https://stackoverflow.com/questions/13485018/cudastreamsynchronize-vs-cudadevicesynchronize-vs-cudathreadsynchronize

  点赞
本篇文章采用 署名-非商业性使用-禁止演绎 4.0 国际 进行许可
转载请务必注明来源: https://oldpan.me/archives/use-caffe-centernet-inference-up

   关注Oldpan博客微信公众号,你最需要的及时推送给你。


发表评论

邮箱地址不会被公开。 必填项已用*标注

评论审核已启用。您的评论可能需要一段时间后才能被显示。