前言
风格迁移是神经网络深度学习中比较重要且有趣的一个项目。如果不知道什么是风格迁移的请参考这篇文章:https://oldpan.me/archives/pytorch-neural-transfer。
本文参考论文:Stable and Controllable Neural Texture Synthesis and Style Transfer Using Histogram Losses
正文
gram matrix
风格迁移(Style Transfer)中我们使用了很多损失函数,最主要的损失函数是在内容层的L2损失以及在风格层的Gram(格拉姆矩阵)损失,Gram损失即利用原图和目标的gram矩阵进行比较得到的损失。Gram矩阵即是简单的一个数据(比如一张图片)中内部元素相乘的矩阵乘法,获取该数据的内在特征,原因很简单,一个数据的内在特征两两相乘后,特殊的特征(元素值比较大)会更大,而元素值比较小的特征在两两相乘后也会变小,所以这个矩阵得到的效果即是,放大数据的特征,得到该数据的纹理细节,从而方便比较:
上面是Gram Matrix(格拉姆矩阵),但是gram是不稳定的,在实际过程中需要人工进行调参才可以得到不错的结果:
如上图,a图为输入图像,b图为通过输入图像a经过gram矩阵仿制出来的,很明显这个gram矩阵很不稳定,导致图片纹理模糊不清楚,而c图则是在经过调参后得到的不错的效果图,但是仍然可以从其中看到一些模糊和细节丢失的痕迹。
为什么会gram矩阵会出现这些问题,原因在于gram矩阵在读取对象本身的特征同时对这个对象本身的分布并不“感冒”。
举个例子,上面的两幅图中,左边的图的分布比较均匀,可以得到该分布的均值是0.707、而方差是0。而右图的均值是1/2,方差也是1/2,这两张图我们可以当做风格迁移中某一个特征图中的一个层,代表了不同对象的特征信息。为什么要说这两张图,因为在对对这两个不同的图进行计算后,发现,这两张图的gram矩阵的值是一样的!
这两张图的gram matrix信息竟然一样,在本文参考的论文中有一些公式论证,这里直接说结论:我们可以在保持gram matrix矩阵值不变的情况下,改变这两张图的分布的方差,这也就是为什么gram matrix矩阵不稳定的原因。
Histogram Loss
这时就需要Histogram Loss来实现更好的texture transfer-风格迁移。
为什么用Histogram,之前我们说过gram loss不稳定是因为其对所提取对象的分部信息“不感冒”,所以我们利用Histogram来进行修改,因为直方图代表的信息就是分布。
这篇文章主要说直方图匹配,另外还有一篇文章是说直方图损失,可以与这篇文章进行相互补充:传送门。
利用直方图提取对象分布信息再结合gram来实现风格的迁移。听起来挺酷,但是实现起来就需要稍微换一个方向。
我们利用这个公式:
其中是一个风格的激活层,而则是经过直方图匹配后的激活层,则是权重参数,我们定义这个为histogram损失,在风格迁移中就可以结合gram损失一块使用。
即 ==>
直方图匹配和直方图均衡这两个概念应该都比较熟悉,在数字图像处理中是比较常见的算法,opencv就有直方图均衡的算法。
直方图匹配过程
直方图匹配就是根据特定的图片的亮度信息和分布信息去调整另一张图片的信息,例如下图利用中间图片的分布信息去匹配我们的Source图片。
这里给出通过python实现的直方图匹配算法,采用的深度学习框架为Pytorch,输入为tensor型变量。
以下代码来源于Sylvain Gugger的Blog,有兴趣的可以看看文章中的参考部分。
其中remap_hist为直方图匹配函数,x写为Tensor.view(-1,1)形式,hist_ref 是我们的参考Tensor。
def select_idx(tensor, idx): ch = tensor.size(0) return tensor.view(-1)[idx.view(-1)].view(ch,-1) def remap_hist(x,hist_ref): ch, n = x.size() sorted_x, sort_idx = x.data.sort(1) ymin, ymax = x.data.min(1)[0].unsqueeze(1), x.data.max(1)[0].unsqueeze(1) hist = hist_ref * n/hist_ref.sum(1).unsqueeze(1)#Normalization between the different lengths of masks. cum_ref = hist.cumsum(1) cum_prev = torch.cat([torch.zeros(ch,1).cuda(), cum_ref[:,:-1]],1) step = (ymax-ymin)/n_bins rng = torch.arange(1,n+1).unsqueeze(0).cuda() idx = (cum_ref.unsqueeze(1) - rng.unsqueeze(2) < 0).sum(2).long() ratio = (rng - select_idx(cum_prev,idx)) / (1e-8 + select_idx(hist,idx)) ratio = ratio.squeeze().clamp(0,1) new_x = ymin + (ratio + idx.float()) * step new_x[:,-1] = ymax _, remap = sort_idx.sort() new_x = select_idx(new_x,idx) return new_x
你这个直方图匹配运行起来不会很慢吗?
这个代码我放错了,这个是patchmatch的代码,跑起来的确很慢...我去修改下,之后实现的histogram匹配没有用python写而是使用cuda语言利用GPU加速,速度1s不到。