700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > 李沐精读论文:GAN《Generative Adversarial Nets》by Ian J. Goodfellow

李沐精读论文:GAN《Generative Adversarial Nets》by Ian J. Goodfellow

时间:2023-09-21 21:10:42

相关推荐

李沐精读论文:GAN《Generative Adversarial Nets》by Ian J. Goodfellow

论文:/paper//file/5ca3e9b122f61f8f06494c97b1afccf3-Paper.pdf

视频:GAN论文逐段精读【论文精读】_哔哩哔哩_bilibili

课程:CS231n PPT笔记- 生成模型Generative Modeling

​李宏毅机器学习——对抗生成网络(GAN)_iwill323的博客-CSDN博客

拓展网站:This Person Does Not Exist

/r/MachineLearning/top/?t=month

https://crypko.ai/

博文:本文主要参考下面博文并摘取了文字和图片李沐论文精读系列一: ResNet、Transformer、GAN、BERT_神洛华的博客-

要想较为详细了解GAN,推荐博文:生成对抗网络,从DCGAN到StyleGAN、pixel2pixel,人脸生成和图像翻译。_神洛华的博客-CSDN博客_人像油画生成 对抗网络

目录

1.简介

2 导论

3 相关工作

4 目标函数及其求解

目标函数

1.生成器G

2.判别器D

3.两个模型同时训练

模型训练过程演示

迭代求解过程

5 理论结果:全局最优解 pg​=pdata​

收敛证明

6 GAN的优势与缺陷

优势

问题

7代码实现

8 影响

9. 关于损失函数的讨论

二元分类

discriminator

generator

1.简介

GANs(Generative Adversarial Networks,生成对抗网络)是从对抗训练中估计一个生成模型,其由两个基础神经网络组成,即生成器神经网络G(Generator Neural Network) 和判别器神经网络D(Discriminator Neural Network)

生成器G从给定噪声中(一般是指均匀分布或者正态分布)采样来合成数据,判别器D用于判别样本是真实样本还是G生成的样本。G的目标就是尽量生成真实的图片去欺骗判别网络D,使D犯错;而D的目标就是尽量把G生成的图片和真实的图片分别开来。二者互相博弈,共同进化,最理想的状态下,G可以生成足以“以假乱真”的图片G(z);对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z)) = 0.5,此时噪声分布接近真实数据分布。

发展:

2 导论

深度学习是用来发现一些丰富的、有层次的模型,这些模型能够对AI里的各种数据做一个概率分布的表示。深度学习网络只是一种手段而已。

 深度学习不仅是学习网络,更是对数据分布的一种表示。这和统计学习方法里面的观点不谋而合,后者认为机器学习模型从概率论的角度讲,就是一个概率分布Pθ​(X) (这里以概率密度函数来代表概率分布)

 机器学习的任务就是求最优参数θt​ ,使得概率分布 Pθ​(X) 最大(即已发生的事实,其对应的概率理应最大)。

 argmax 函数代表的是取参数使得数据的概率密度最大。求解最优参数θt​的过程,我们称之为模型的训练过程( Training )

深度学习在判别模型上取得了很好的效果,但是在生成模型上比较差。难点在于最大化似然函数时,要对概率分布做很多近似,近似带来了很大的计算困难。

本文的核心观点就是, 不用再去近似似然函数了,可以用更好的办法(GAN)来计算模型。

GAN是一个框架,里面的模型都是MLP。生成器G这个MLP的输入是随机噪声,通常是高斯分布,然后将其映射到任何一个我们想去拟合的分布;判别器D也是MLP,所以可以通过误差的反向传递来训练,而不需要像使用马尔可夫链这样的算法对一个分布进行复杂的采样。这样模型就比较简单,计算上有优势。

3 相关工作

之前的生成模型总是想构造一个分布函数出来,同时这些函数提供了一些参数可以学习。这些参数通过最大化对数似然函数来求解。这样做的缺点是,采样一个分布时,求解参数算起来很难,特别是高维数据。因为这样计算很困难,所以最近有一些Generative Machines,不再去构造分布函数,而是学习一个模型来近似这个分布。

 前者真的是在数学上学习出一个分布,明明白白知道数据是什么分布 ,里面的均值方差等等到底是什么东西。而GAN就是通过一个模型来近似分布的结果,而不需要构造分布函数。这样计算起来简单,缺点是不知道最终的分布到底是什么样子。

对f的期望求导,等价于对f自己求导。这也是为什么通过误差反向传递来对GAN求解。

光看上面描述看不懂,可以简单参考:CS231n PPT笔记- 生成模型Generative Modeling_iwill323的博客-CSDN博客_cs231n

生成模型可以解决密度估计问题,有两种方式:

显式密度模型会显式地给出一个和输入数据的分布pmodel(x)隐式密度模型训练一个模型,从输入数据中采样,并直接输出样本,而不用显式地给出分布的表达式。

4 目标函数及其求解

目标函数

GAN最简单的框架就是模型都是MLP。

1.生成器G

生成器是要在数据x上学习一个分布pg​(x),其输入是定义在一个先验噪声z上面,z的分布为pz​(z),比如高斯分布。生成模型G的任务就是用MLP把噪声z映射成数据x。比如图片生成,假设不同的生成图片是100个变量控制的,而MLP理论上可以拟合任何一个函数,那么我们就构造一个100维的向量,MLP强行把z映射成x,从而生成像样的图片。z可以先验的设定为一个100维向量,其均值为0,方差为1,呈高斯分布。这么做优点是算起来简单,缺点是MLP并不是真的了解背后的z是如何控制输出的,只是学出来随机选一个比较好的z来近似x,所以最终效果也就一般。

2.判别器D

判别器输出一个标量(概率),判断其输入是G生成的数据,还是真实的数据。对于D,真实数据label=1,假的数据label=0

3.两个模型同时训练

最终目标函数公式如下所示,E代表期望,公式中同时有minmax,所以是对抗训练。

G表示生成网络,D 表示判别网络,θd​是判别器参数,θg​是生成器的参数,训练目标是让目标函数在θg​上取得最小值,同时在 θd​上取得最大值。

第一项:pdata 表示真实数据的分布。D(x)是判别器网络对真实数据(训练数据)x的判别结果,输出一个 0-1 的概率(0表示假,1表示真)。E表示我们考虑的是整个训练集中所有样本的一个期望,而不是具体某个样本的概率。第二项:p(z)表示噪声的分布。使用 G(z) 可以生成一个样本,D​​(G​​(z))代表了判别器网路对生成的伪数据的判别结果。θd的目标:整个表达式越大越好。希望logD(x) 越大越好,即判别器对于真实样本的判别为真的期望越大越好;希望 log(1−D(G(z)))越大越好,也就是希望判别器对假的样本判别为真的概率越小越好。因此如果能最大化这一结果,就意味着判别器能够很好的区别真实数据和伪造数据。θg的目标:整个式子越小越好。G的目标是希望生成的图片越接近真实越好,使得D(G(z))接近1,也就是最小化log(1−D(G(z)))。结果就是训练一个G,使判别器尽量犯错,无法区分出数据来源,意味着生成器在生成与真实样本非常相似的数据

3模型训练过程演示

如上图所示,假设x和z都是一维向量,且z是均匀分布。虚线点为真实数据分布,蓝色虚线是判别器D判别结果的分布,绿色实线为生成器G的分布。

a. 生成器从均匀分布学成绿色实线表示的高斯分布,这时候判别器还很差;

b. 判别器学成图b所示的分布,可以把真实数据和生成数据区别开来;

c. 随着训练进行,生成器波峰靠向真实数据波峰,使得判别器难以分辨了;辨别器为了更准,其分布也往真实数据靠拢;

d. 最终训练的结果,生成器拟合真实分布,判别器难以分辨,输出概率都为0.5,即D(x) = 1/2

迭代求解过程

下面是具体的算法过程:

完整的训练过程:在每一个训练迭代期都先训练判别器网络,然后训练生成器网络。

1)对于判别器网络,先从噪声先验分布z中采样得到一个小批量样本,接着从训练数据x中采样获得小批量的真实样本,将噪声样本传给生成器网络,并在生成器的输出端获得生成的图像。此时我们有了一个小批量伪造图像和小批量真实图像,在判别器生进行一次梯度计算,利用梯度信息更新判别器参数,按照以上步骤迭代k次来训练判别器。2)训练生成器,采样获得一个小批量噪声样本,将它传入生成器,对生成器进行反向传播,优化目标函数。

训练完之后,将噪声图像传给生成网络,就能生成伪造图像。

Optimizing D to completion in the inner loop of training is computationally prohibitive, and on finite datasets would result in overfitting. Instead, we alternate between k steps of optimizing D and one step of optimizing G. This results in D being maintained near its optimal solution, so long asGchanges slowly enough.

k是一个超参数,不能太小也不能太大。要保证判别器D可以足够更新,但也不能更新太好。

如果D更新的不够好,那么G训练时在一个判别很差的模型里面更新参数,继续糊弄D意义不大;如果D训练的很完美,那么 log(1−D(G(z)))趋近于0,求导结果也趋近于0,生成器难以训练整体来说GAN的收敛是很不稳定的,所以之后有很多工作对其进行改进。

另一个问题:

早期G非常弱,所以很容易把D训练的很好,这样就造成刚刚说的G训练不动了。

In practice, equation 1 may not provide sufficient gradient forGto learn well. Early in learning, whenGis poor,Dcan reject samples with high confidence because they are clearly different from the training data. In this case, log(1- D(G(z))) saturates.

所以作者建议G的目标函数改为最大化logD(G(z)),这样可以得到同样的G和D的不动点,同时又能在早期更好的下降。

Rather than trainingGto minimize log(1- D(G(z))) we can trainGto maximize logD(G(z)). This objective function results in the same fixed point of the dynamics ofGandDbut provides much stronger gradients early in learning

下面从优化目标曲线的形状角度来解释:

上图的蓝色曲线为 log(1−D(G(z)))。当生成器效果不好(D(G(z)接近0)时,梯度非常平缓,模型训练很慢;当生成器效果好(D(G(z)接近1)时,梯度很陡峭,模型更新地会过快。这就与我们期望的相反了,我们希望在生成器效果不好的时候梯度更陡峭,这样能学到更多,在即将收敛的时候应该放缓更新步伐。

max logD(G(z))图像如下图绿色曲线所示,它就有很好的特性,即初始时梯度大,最后梯度小,符合训练的需要,实际训练中基本都用这个式子。

李沐:带来的问题是,D(G(z))→0的时候,log0是负无穷大,会带来数值上的问题。

关于目标函数的部分让人看的云里雾里,其实说白了就是二分类问题误差函数的构造,比如,我们会去min -log(y^),会去min -log(1-y^),而不会去min log(1-y^),见最后一部分的讨论

5 理论结果:全局最优解 pg​=pdata​

具体证明部分可以参考帖子GAN论文阅读——原始GAN(基本概念及理论推导)_StarCoo的博客-CSDN博客_gan 原始论文

1.先训练D。固定生成器G,最优的辨别器应该是

*表示最优解pg​(x) 和 pdata​(x)分别表示x在生成器拟合的分布里和真实数据的分布里,它的概率分别是多少。当pg​(x)=pdata​(x)时,结果为1/2,表示两个分布完全相同,最优的判别器也无法将其分辨出来。这个公式的意义是,从两个分布里面分别采样数据,用目标函数 min​ max​ V(D,G)训练一个二分类器,如果分类器输出的值都是0.5,则可以认为这两个分布是完全重合的。在统计学上,这个叫two sample test,用于判断两块数据是否来自同一分布。

注:two sample test是一个很实用的技术,比如在一个训练集上训练了一个模型,然后部署到另一个环境,需要看看测试数据的分布和训练数据是不是一样的,就可以像这样训练一个二分类器,看能否区分数据来源,避免部署的环境和我们训练的模型不匹配。

证明:

第一行是密度函数求积分,换元g(z)=x得到第二行(这个换元没看懂)

假设上面看懂了,那么后面的就简单了。在数据给定,G 给定的前提下,Pdata(x)与PG(x)都可以看作是常数,我们可以分别用a,b来表示他们,这样我们就可以得到如下的式子:

证毕。

2.然后训练G。把D的最优解代回目标函数,目标函数只和G相关,写作C(G):

现在只需要最小化这一项就行。

可以证明,当且仅当pg​=pdata​时有最优解 C(G)=−log4。

上式两项可以写成KL散度,KL散度用来衡量这两个分布的差异,

它表示了假如我们采取某种编码方式使编码Q分布所需的比特数最少,那么编码P分布所需的额外的比特数。假如P和Q分布完全相同,则其KL divergence 为零。

KL散度有很多有用的性质,最重要的是,它是非负的。KL散度为0,当且仅当P和Q在离散型变量的情况下是相同的分布,或者在连续型变量的情况下是“几乎处处”相同的。

使用KL散度,简化上面的式子:

又因为JS散度定义为:

所以进一步化简成:

要求 minC(G),当且仅当最后一项等于0的时候成立(JS散度≥0),此时pg​=pdata,表示两个分布完全相同,带入到D*(x)表达式,结果为1/2,最优的判别器也无法将其分辨出来。

注:JS散度跟KL散度的区别是前者是对称的,pg​和 pdata​可以互换,而后者不对称。

  综上所述,目标函数 min ​max​V(D,G)有全局最优解,这个解当且仅当 pg​=pdata​时成立,也就是生成器学到的分布等于真实数据的分布,可以取得最优生成器。

The global minimum of the virtual training criterion C(G) is achieved if and only if pg = pdata. At that point, C(G) achieves the value - log 4.

收敛证明

这部分证明了:给定足够的训练数据和正确的环境,在算法1中每一步允许D达到最优解的时候,对G进行下面的迭代:

训练过程将收敛到pg​=pdata​,此时生成器G是最优生成器。

其实我们每次只是k个steps训练D,离上述前提条件还很远,结论是否真的适用,就不那么好说了

6 GAN的优势与缺陷

参考GAN论文阅读——原始GAN(基本概念及理论推导)_StarCoo的博客-CSDN博客_gan 原始论文

优势

与其他生成式模型相比较,生成式对抗网络有以下四个优势深度 | OpenAI Ian Goodfellow的Quora问答:

比其它模型生成效果更好(图像更锐利、清晰)。GAN能训练任何一种生成器网络(理论上-实践中,用 REINFORCE 来训练带有离散输出的生成网络非常困难)。大部分其他的框架需要该生成器网络有一些特定的函数形式,比如输出层是高斯的。重要的是所有其他的框架需要生成器网络遍布非零质量(non-zero mass)。不需要设计遵循任何种类的因式分解的模型,任何生成器网络和任何判别器都会有用。无需利用马尔科夫链反复采样,无需在学习过程中进行推断(Inference),回避了近似计算棘手的概率的难题。

问题

GAN目前存在的主要问题:

难以收敛(non-convergence)

目前面临的基本问题是:所有的理论都认为 GAN 应该在纳什均衡(Nash equilibrium)上有卓越的表现,但梯度下降只有在凸函数的情况下才能保证实现纳什均衡。当博弈双方都由神经网络表示时,在没有实际达到均衡的情况下,让它们永远保持对自己策略的调整是可能的深度深度 | OpenAI Ian Goodfellow的Quora问答

难以训练:崩溃问题(collapse problem)

GAN模型被定义为极小极大问题,没有损失函数,在训练过程中很难区分是否正在取得进展。GAN的学习过程可能发生崩溃问题,生成器开始退化,总是生成同样的样本点,无法继续学习。当生成模型崩溃时,判别模型也会对相似的样本点指向相似的方向,训练无法继续。

无需预先建模,模型过于自由不可控。

与其他生成式模型相比,GAN不需要构造分布函数,而是使用一种分布直接进行采样,从而真正达到理论上可以完全逼近真实数据,这也是GAN最大的优势。然而,这种不需要预先建模的方法缺点是太过自由了,对于较大的图片,较多的 pixel的情形,基于简单 GAN 的方式就不太可控了(超高维)。

所以可以看到,最终作者生成的图片分辨率都很低。在GAN 中,每次学习参数的更新过程,被设为D更新k回,G才更新1回,也是出于类似的考虑。

7代码实现

代码:李宏毅机器学习作业6-使用GAN生成动漫人物脸

CS231n对抗生成网络代码

pytorch版参考/tutorials/beginner/dcgan_faces_tutorial.html

8 影响

1.开创了GAN这个领域

2.GAN本身是无监督学习,不需要标注数据。GAN的思想应用于无监督学习。

3.其训练方式是用有监督学习的损失函数来做无监督学习(有监督的标签来源于数据是真实的还是生成的),所以训练上会高效很多。这也是之后bert之类自监督学习模型的灵感来源。再如Domain Adversarial Training李宏毅机器学习作业11——Transfer Learning,Domain Adversarial Training

9. 关于损失函数的讨论

下面是个人一些思考,简单来说,就是论文中对GAN的优化目标函数的设计看似复杂,其实就是一个二元分类损失函数

二元分类

GAN的训练过程是一个minmax训练,但是几乎没有人会真的使用梯度上升的方法,所以实作和理论有出入。下面先看二元分类问题的损失函数,希望该Loss function越小越好。y^是模型预测结果,y是标签

当y=1时,L(y^,y)=−log y^。如果y^越接近1,L(y^,y)≈0,表示预测效果越好;如果y^越接近0,L(y^,y)≈+∞,表示预测效果越差。

当y=0时,L(y^,y)=−log (1−y^)。如果y^越接近0,L(y^,y)≈0,表示预测效果越好;如果y^越接近1,L(y^,y)≈+∞,表示预测效果越差。

discriminator

下面是discriminator损失函数

套用二元分类的损失函数,让y^=D(y)。当数据采集自Pdata时,标签y=1,损失函数为−log y^;当数据采集自PG时,标签y=0,损失函数为−log (1−y^)。将二者相加,其实就是V(G,D)的相反数,也就是说,训练discriminator可以直接使用二元交叉熵损失(BCELoss),其中真实图片的label为1,生成的图片的label为0

r_label = torch.ones((bs)).to(self.device)f_label = torch.zeros((bs)).to(self.device)r_loss = self.loss(r_logit, r_label)f_loss = self.loss(f_logit, f_label)loss_D = (r_loss + f_loss) / 2

generator

下面是generator损失函数

套用二元分类的损失函数,让y^=D(G(z)),让标签y=1,则损失函数为−log y^,所以也可以直接使用二元交叉熵损失(BCELoss),只要指定label为1

loss_G = self.loss(f_logit, r_label)

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。