当前位置:千优科技>行业资讯>详情

华人博士生首次尝试用两个Transformer构建一个GAN

2021-02-21 23:23:38 浏览:415次 来源:互联网 编辑:leo 推荐:人工智能

华人博士生首次尝试用两个Transformer构建一个GAN

最近,CV研究者对transformer产生了极大的兴趣并取得了不少突破。这表明,transformer有可能成为计算机视觉任务(如分类、检测和分割)的强大通用模型。

我们都很好奇:在计算机视觉领域,transformer还能走多远?对于更加困难的视觉任务,比如生成对抗网络(GAN),transformer表现又如何?

在这种好奇心的驱使下,德州大学奥斯汀分校的YifanJiang、ZhangyangWang,IBMResearch的ShiyuChang等研究者进行了第一次试验性研究,构建了一个只使用纯transformer架构、完全没有卷积的GAN,并将其命名为TransGAN。与其它基于transformer的视觉模型相比,仅使用transformer构建GAN似乎更具挑战性,这是因为与分类等任务相比,真实图像生成的门槛更高,而且GAN训练本身具有较高的不稳定性。

从结构上来看,TransGAN包括两个部分:一个是内存友好的基于transformer的生成器,该生成器可以逐步提高特征分辨率,同时降低嵌入维数;另一个是基于transformer的patch级判别器。

研究者还发现,TransGAN显著受益于数据增强(超过标准的GAN)、生成器的多任务协同训练策略和强调自然图像邻域平滑的局部初始化自注意力。这些发现表明,TransGAN可以有效地扩展至更大的模型和具有更高分辨率的图像数据集。

实验结果表明,与当前基于卷积骨干的SOTAGAN相比,表现最佳的TransGAN实现了极具竞争力的性能。具体来说,TransGAN在STL-10上的IS评分为10.10,FID为25.32,实现了新的SOTA。

该研究表明,对于卷积骨干以及许多专用模块的依赖可能不是GAN所必需的,纯transformer有足够的能力生成图像。

在该论文的相关讨论中,有读者调侃道,「attentionisreallybecoming『allyouneed』.」

不过,也有部分研究者表达了自己的担忧:在transformer席卷整个社区的大背景下,势单力薄的小实验室要怎么活下去?

如果transformer真的成为社区「刚需」,如何提升这类架构的计算效率将成为一个棘手的研究问题。

基于纯Transformer的GAN

作为基础块的Transformer编码器

研究者选择将Transformer编码器(Vaswani等人,2017)作为基础块,并尽量进行最小程度的改变。编码器由两个部件组成,第一个部件由一个多头自注意力模块构造而成,第二个部件是具有GELU非线性的前馈MLP(multiple-layerperceptron,多层感知器)。此外,研究者在两个部件之前均应用了层归一化(Ba等人,2016)。两个部件也都使用了残差连接。

内存友好的生成器

NLP中的Transformer将每个词作为输入(Devlin等人,2018)。但是,如果以类似的方法通过堆叠Transformer编码器来逐像素地生成图像,则低分辨率图像(如32×32)也可能导致长序列(1024)以及更高昂的自注意力开销。

所以,为了避免过高的开销,研究者受到了基于CNN的GAN中常见设计理念的启发,在多个阶段迭代地提升分辨率(Denton等人,2015;Karras等人,2017)。他们的策略是逐步增加输入序列,并降低嵌入维数

如下图1左所示,研究者提出了包含多个阶段的内存友好、基于Transformer的生成器:

每个阶段堆叠了数个编码器块(默认为5、2和2)。通过分段式设计,研究者逐步增加特征图分辨率,直到其达到目标分辨率H_T×W_T。具体来说,该生成器以随机噪声作为其输入,并通过一个MLP将随机噪声传递给长度为H×W×C的向量。该向量又变形为分辨率为H×W的特征图(默认H=W=8),每个点都是C维嵌入。然后,该特征图被视为长度为64的C维token序列,并与可学得的位置编码相结合。

与BERT(Devlin等人,2018)类似,该研究提出的Transformer编码器以嵌入token作为输入,并递归地计算每个token之间的匹配。为了合成分辨率更高的图像,研究者在每个阶段之后插入了一个由reshaping和pixelshuffle模块组成的上采样模块。

具体操作上,上采样模块首先将1D序列的token嵌入变形为2D特征图

,然后采用pixelshuffle模块对2D特征图的分辨率进行上采样处理,并下采样嵌入维数,最终得到输出

。然后,2D特征图X’_0再次变形为嵌入token的1D序列,其中token数为4HW,嵌入维数为C/4。所以,在每个阶段,分辨率(H,W)提升到两倍,同时嵌入维数C减少至输入的四分之一。这一权衡(trade-off)策略缓和了内存和计算量需求的激增。

研究者在多个阶段重复上述流程,直到分辨率达到(H_T,W_T)。然后,他们将嵌入维数投影到3,并得到RGB图像。

用于判别器的tokenized输入

与那些需要准确合成每个像素的生成器不同,该研究提出的判别器只需要分辨真假图像即可。这使得研究者可以在语义上将输入图像tokenize为更粗糙的patchlevel(Dosovitskiy等人,2020)。

如上图1右所示,判别器以图像的patch作为输入。研究者将输入图像

分解为8×8个patch,其中每个patch可被视为一个「词」。然后,8×8个patch通过一个线性flatten层转化为token嵌入的1D序列,其中token数N=8×8=64,嵌入维数为C。再之后,研究者在1D序列的开头添加了可学得位置编码和一个[cls]token。在通过Transformer编码器后,分类head只使用[cls]token来输出真假预测。

实验

CIFAR-10上的结果

研究者在CIFAR-10数据集上对比了TransGAN和近来基于卷积的GAN的研究,结果如下表5所示:

如上表5所示,TransGAN优于AutoGAN(Gong等人,2019),在IS评分方面也优于许多竞争者,如SN-GAN(Miyato等人,2018)、improvingMMDGAN(Wang等人,2018a)、MGAN(Hoang等人,2018)。TransGAN仅次于ProgressiveGAN和StyleGANv2。

对比FID结果,研究发现,TransGAN甚至优于ProgressiveGAN,而略低于StyleGANv2(Karras等人,2020b)。在CIFAR-10上生成的可视化示例如下图4所示:

STL-10上的结果

研究者将TransGAN应用于另一个流行的48×48分辨率的基准STL-10。为了适应目标分辨率,该研究将第一阶段的输入特征图从(8×8)=64增加到(12×12)=144,然后将提出的TransGAN-XL与自动搜索的ConvNets和手工制作的ConvNets进行了比较,结果下表6所示:

与CIFAR-10上的结果不同,该研究发现,TransGAN优于所有当前的模型,并在IS和FID得分方面达到新的SOTA性能。

高分辨率生成

由于TransGAN在标准基准CIFAR-10和STL-10上取得不错的性能,研究者将TransGAN用于更具挑战性的数据集CelebA64×64,结果如下表10所示:

TransGAN-XL的FID评分为12.23,这表明TransGAN-XL可适用于高分辨率任务。可视化结果如图4所示。

局限性

虽然TransGAN已经取得了不错的成绩,但与最好的手工设计的GAN相比,它还有很大的改进空间。在论文的最后,作者指出了以下几个具体的改进方向:

作者简介

本文一作YifanJiang是德州大学奥斯汀分校电子与计算机工程系的一年级博士生(此前在德克萨斯A&M大学学习过一年),本科毕业于华中科技大学,研究兴趣集中在计算机视觉、深度学习等方向。目前,YifanJiang主要从事神经架构搜索、视频理解和高级表征学习领域的研究,师从德州大学奥斯汀分校电子与计算机工程系助理教授ZhangyangWang。

在本科期间,YifanJiang曾在字节跳动AILab实习。今年夏天,他将进入GoogleResearch实习。

一作主页:https://yifanjiang.net/

标签:人工智能机器学习技术

版权声明:文章由 www.e1000u.com 整理收集,来源于互联网或者用户投稿,如有侵权,请联系我们,我们会立即处理。如转载请保留本文链接:https://www.e1000u.com/article/10290.html