大名鼎鼎的深度學(xué)習(xí)之父Yann LeCun曾評(píng)價(jià)GAN是“20年來機(jī)器學(xué)習(xí)領(lǐng)域最酷的想法”。的確,GAN向世人展示了從無到有、無中生有的神奇過程,并且GAN已經(jīng)在工業(yè)界有著廣泛的應(yīng)用,是一項(xiàng)令人非常激動(dòng)的AI技術(shù)。今天我將和大家一起去了解GAN及其內(nèi)部工作原理,洞開GAN的大門。
?
本文盡量用淺顯易懂的語(yǔ)言來進(jìn)行表述,少用繁瑣的數(shù)學(xué)公式,并對(duì)幾個(gè)典型的GAN模型進(jìn)行講解。
一、GAN(GenerativeAdversarial Networks)
GAN全名叫Generative Adversarial Networks,即生成對(duì)抗網(wǎng)絡(luò),是一種典型的無監(jiān)督學(xué)習(xí)方法。在GAN出現(xiàn)之前,一般是用AE(AutoEncoder)的方法來做圖像生成的,但是得到的圖像比較模糊,效果始終都不理想。直到2014年,Goodfellow大神在NIPS2014會(huì)議上首次提出了GAN,使得GAN第一次進(jìn)入了人們的眼簾并大放異彩,到目前為止GAN的變種已經(jīng)超過400種,并且CVPR2018收錄的論文中有三分之一的主題和GAN有關(guān),可見GAN仍然是當(dāng)今一大熱門研究方向。
GAN的應(yīng)用場(chǎng)景非常廣泛,主要有以下幾個(gè)方面:
2.圖像翻譯。從真實(shí)場(chǎng)景的圖像到漫畫風(fēng)格的圖像、風(fēng)景畫與油畫間的風(fēng)格互換等等。
3.圖像修復(fù)。比如圖像去噪、去除圖像中的馬賽克(嘿嘿…)。
4.圖像超分辨率重建。衛(wèi)星、遙感以及醫(yī)學(xué)圖像中用的比較多,大大提升后續(xù)的處理精度。
(一) GAN原理簡(jiǎn)述
?
GAN的原理表現(xiàn)為對(duì)抗哲學(xué),舉個(gè)例子:警察和小偷的故事,二者滿足兩個(gè)對(duì)抗條件:
1.小偷不停的更新偷盜技術(shù)以避免被抓。
2.警察不停的發(fā)現(xiàn)新的方法與工具來抓小偷。
小偷想要不被抓就要去學(xué)習(xí)國(guó)外的先進(jìn)偷盜技術(shù),而警察想要抓到小偷就要盡可能的去掌握小偷的偷盜習(xí)性。兩者在博弈的過程中不斷的總結(jié)經(jīng)驗(yàn)、吸取教訓(xùn),從而都得到穩(wěn)步的提升,這就是對(duì)抗哲學(xué)的精髓所在。要注意這個(gè)過程一定是一個(gè)交替的過程,也就是說兩者是交替提升的。想象一下,如果一開始警察就很強(qiáng)大,把所有小偷全部抓光了,那么在沒有了小偷之后警察也不會(huì)再去學(xué)習(xí)新的知識(shí)了,偵查能力就得不到提升。反之亦然,如果小偷剛開始就很強(qiáng)大,警察根本抓不到小偷,那么小偷也沒有動(dòng)力學(xué)習(xí)新的偷盜技術(shù)了,小偷的偷盜能力也得不到提升,這就好比在訓(xùn)練神經(jīng)網(wǎng)絡(luò)時(shí)出現(xiàn)了梯度消失一樣。所以一定是一個(gè)動(dòng)態(tài)博弈的過程,這也是GAN最顯著的特性之一。
?
在講完了警察與小偷的故事之后,我們引入今天的主人公——GAN。
?
(二) 模型架構(gòu)圖
?
從上圖能夠看出GAN的整個(gè)網(wǎng)絡(luò)架構(gòu)是非常簡(jiǎn)單明了的,GAN由一個(gè)生成器(Generator)和一個(gè)判別器(Discriminator)組成, 兩者的結(jié)構(gòu)都是多層感知機(jī)(MLP),具體有多少層、每層多少個(gè)神經(jīng)元可以根據(jù)實(shí)際情況自行設(shè)計(jì),比較靈活。在這里,生成器充當(dāng)著“小偷”的角色,判別器就扮演“警察”的角色。為了方便講解,后面把生成器簡(jiǎn)稱為G,判別器簡(jiǎn)稱為D。
G:接收一個(gè)隨機(jī)噪聲向量z(比如z服從高斯分布),G的目標(biāo)就是通過這個(gè)噪聲來生成一個(gè)像真實(shí)樣本的假樣本。
?
D:判別一個(gè)樣本是真實(shí)樣本還是G自己造的假樣本。它接收一個(gè)樣本數(shù)據(jù)作為輸入,所以這個(gè)樣本可以是G生成的假樣本也可以是真實(shí)樣本
。它輸出一個(gè)標(biāo)量,標(biāo)量的數(shù)值代表了輸入樣本到底是真實(shí)樣本還是G生成的假樣本的概率。如果接近1,則代表是真實(shí)樣本,接近于0則代表是生成器生成的假樣本,所以此時(shí)D最后一層的激活函數(shù)一定為sigmoid。
?
網(wǎng)絡(luò)的最終目標(biāo)是在D很強(qiáng)大的同時(shí),G生成的假樣本送給D后其輸出值變?yōu)?.5,說明G已經(jīng)完全騙過了D,即D已經(jīng)區(qū)分不出來輸入的樣本到底是還是
,從而得到一個(gè)生成效果很好的G。
?
損失函數(shù)的設(shè)計(jì):
從上面的式子可以看出,損失函數(shù)是兩個(gè)分布各自期望的和,其中是真實(shí)數(shù)據(jù)的概率分布,
是生成器所生成的假樣本的概率分布。對(duì)于D,它的目的是讓
中的樣本的輸出結(jié)果盡可能的大,即
變大,而讓
生成的樣本x的輸出結(jié)果盡可能的小,即
變大,導(dǎo)致變大。對(duì)于G,它的目的是用噪聲z來生成一個(gè)假樣本x并讓D給出一個(gè)較大的值,即讓
變小,導(dǎo)致
變小。綜上,我們得出:
(三) GAN的訓(xùn)練流程
假設(shè)batch_size=m,則在每一個(gè)epoch中:
先訓(xùn)練判別器k(比如3)次:
1. 從噪聲分布z(比如高斯分布)中隨機(jī)采樣出m個(gè)噪聲向量:。
2. 從真實(shí)樣本x中隨機(jī)采樣出m個(gè)樣本:。
3. 用梯度下降法使損失函數(shù): 與1之間的二分類交叉熵減?。ㄒ?yàn)樽詈笈袆e器最后一層的激活函數(shù)為sigmoid,所以要與0或者1做二分類交叉熵,這也是為什么損失函數(shù)要取log的原因)。
4. 用梯度下降法使損失函數(shù):與0之間的二分類交叉熵減小。
5. 所以判別器的總損失函數(shù)即讓d_loss越小越好。注意在訓(xùn)練判別器的時(shí)候生成器中的所有參數(shù)要固定住,即不參加訓(xùn)練。
?
再訓(xùn)練生成器1次:
1. 從噪聲分布中隨機(jī)采樣出m個(gè)噪聲向量:。
2. 用梯度下降法使損失函數(shù):與1之間的二分類交叉熵減小。
3. 所以生成器的損失函數(shù)即讓g_loss越小越好。注意在訓(xùn)生成器的時(shí)候判別器中的所有參數(shù)要固定住,即不參加訓(xùn)練。
?
直到所有epoch執(zhí)行完畢,訓(xùn)練結(jié)束。
從訓(xùn)練方法中可以看出,生成器和判別器是交替進(jìn)行訓(xùn)練的,呈現(xiàn)出一種動(dòng)態(tài)博弈的思想,非常有意思。不過在訓(xùn)練的時(shí)候還有一些注意事項(xiàng):
?
1.在訓(xùn)練G的時(shí)候D中的參數(shù)不參加訓(xùn)練,即不需要梯度反傳。同理,訓(xùn)練D的時(shí)候G中的參數(shù)不參加訓(xùn)練。
?
2.為了讓D保持在一個(gè)相對(duì)較高的評(píng)判水平,從而更好的訓(xùn)練G。在每一個(gè)epoch內(nèi),先對(duì)D進(jìn)行k(比如k=3)次訓(xùn)練,然后訓(xùn)練G一次,加快網(wǎng)絡(luò)的收斂速度。
?
3.在原始論文中,作者在訓(xùn)練G的時(shí)候給出的公式是,然而這個(gè)公式有一些隱患,因?yàn)樵谟?xùn)練的初始階段,G生成的樣本和真實(shí)樣本間的差異一般會(huì)很大,此時(shí)D能很輕松的分辨兩種樣本,導(dǎo)致
一直趨近于0,此時(shí)梯度消失,G也就得不到訓(xùn)練,所以這里的策略是
,上面訓(xùn)練過程的闡述中已經(jīng)對(duì)該處的損失函數(shù)做了更正。
?
(四) 損失函數(shù)相關(guān)數(shù)學(xué)推導(dǎo)
?
我們先將G中的參數(shù)固定住,此時(shí)的噪聲向量通過G后所生成的樣本是一一對(duì)應(yīng)的,則有如下映射:??
由此將由兩個(gè)數(shù)學(xué)期望的和組成的展開:
由于和
是固定的常量,另它們等于a,b。令
,得到
,由于是唯一極值點(diǎn),則必為最值點(diǎn),也能夠證明在
時(shí),其二階導(dǎo)小于0,那么該最值點(diǎn)為全局最大值點(diǎn)。
?
所以,當(dāng)G固定住的時(shí)候,不斷的訓(xùn)練D中的參數(shù),理論上可以讓D達(dá)到最大值:
。
此時(shí)將帶入進(jìn)
中,得到:
對(duì)于兩個(gè)概率分布和
,它們之間的KL散度就是數(shù)據(jù)的原始分布與近似分布的概率的對(duì)數(shù)差的期望值,其公式為:
所以此時(shí)得到:
再將兩個(gè)KL散度的和合并成JS散度,得到:
從上式可以看出,如果G要讓最小,必須要讓
和
間的JS散度最小,而JS散度的最小值為0,此時(shí)兩個(gè)分布完全重合,即
理論上的最小值為
,此時(shí)存在唯一解:
使得損失函數(shù)達(dá)到全局最小值,即生成器完美的實(shí)現(xiàn)了生成真實(shí)數(shù)據(jù)的過程,完全掌握了真實(shí)數(shù)據(jù)的概率分布。
(五) 總結(jié)
?
1.GAN的開山之作。
2.GAN的本質(zhì)其實(shí)是利用神經(jīng)網(wǎng)絡(luò)強(qiáng)大的非線性擬合能力來學(xué)習(xí)從一個(gè)任意先驗(yàn)的噪聲分布到真實(shí)數(shù)據(jù)分布的非線性映射,從而讓生成器具有能夠產(chǎn)生逼真樣本的能力。
3.早期GAN的訓(xùn)練非常不穩(wěn)定導(dǎo)致訓(xùn)練難度大,還容易出現(xiàn)梯度爆炸、mode collapse等問題。mode collapse的意思就是生成的樣本大量集中于部分真實(shí)樣本,那么就是很嚴(yán)重的mode collapse。以生成動(dòng)漫頭像圖片為例,從下圖中能夠明顯的看出,紅框標(biāo)記的圖像重復(fù)出現(xiàn)了很多次,即存在一定的mode collapse。
二、DCGAN(Deep Convolutional Generative Adversarial Networks)
在GAN被提出之后,GAN的熱度曲線呈指數(shù)式增長(zhǎng),期間在原始GAN的結(jié)構(gòu)基礎(chǔ)上進(jìn)行改進(jìn)的GAN變種層出不窮,其中最具代表性的當(dāng)屬DCGAN了,我們來看看它對(duì)原始GAN有什么創(chuàng)新:
1.將兩個(gè)多層感知機(jī)替換為兩個(gè)卷積神經(jīng)網(wǎng)絡(luò)。即將CNN融合進(jìn)GAN中,極大的加速了GAN在圖像領(lǐng)域中應(yīng)用的步伐,此后許多新提出的GAN都一直在沿用DCGAN的網(wǎng)絡(luò)架構(gòu)。 2.創(chuàng)新性的將反卷積(也叫轉(zhuǎn)置卷積)操作應(yīng)用于生成器中。 3.通過大量實(shí)驗(yàn),總結(jié)出一套構(gòu)建網(wǎng)絡(luò)時(shí)很有用的trick。
(一) 反卷積
?
常見的上采樣方式有三種:雙線性插值,反卷積(也叫轉(zhuǎn)置卷積)和反池化。鑒于篇幅所限,除反卷積以外的兩種上采樣方法就不在這里介紹了。
?
常規(guī)的卷積操作一般會(huì)導(dǎo)致圖像的尺寸越來越小,同時(shí)圖像深度在逐漸增加。而反卷積則使圖像尺寸越來越大,而深度在逐漸減小。所以反卷積是卷積操作的逆運(yùn)算,也就是說反卷積的正向傳播是卷積的反向傳播,其反向傳播是卷積的正向傳播,本文力求用形象的過程來展現(xiàn)反卷積的工作原理(注:下文所闡述的反卷積工作方式為tensorflow機(jī)器學(xué)習(xí)框架反卷積的底層實(shí)現(xiàn)方法,其他框架的底層實(shí)現(xiàn)方法可能略有不同)。
?
若輸入為3*3大小的單通道圖像:
考慮卷積核大小kernel_size=3*3,stride=2,padding=same的反卷積操作,且卷積核為:
如果stride=2,那么就在輸入圖像的每行和每列之間插入(stride-1)行(列)的零元素,另外還需要在補(bǔ)零后的矩陣的左邊和上邊添加額外的(stride-1)行(列)的零元素:
如果卷積核的大小kernel_size=3,且padding=same的情況下,我們知道在正常的卷積模式下是要上、下、左、右各添加(kernel_size-1)/2個(gè)行(列)元素,他們的初始值都為0,以此來保證輸出圖像與輸入圖像的大小是相同的,所以這里也采取相同的padding操作。這里簡(jiǎn)單說明一下:如果kernel_size=4,那么(kernel_size-1)/2=1.5,無法整除,那么此時(shí)左方和上方添加一行(列)零元素,右方和下方添加兩行(列)零元素,總之要保證添加的總行(列)數(shù)要和kernel_size-1是相等的,這也是tensorflow機(jī)器學(xué)習(xí)框架在卷積操作中padding=same時(shí)的填補(bǔ)方法。所以現(xiàn)在輸入圖像變成了這樣:
此時(shí)輸入圖像的尺寸由3*3變成了8*8,我們用kernel_size=3,stride=1,padding=valid的方式對(duì)這張圖進(jìn)行常規(guī)的卷積操作,則輸出尺寸變?yōu)椋篐=(8+0-3)/1+1=6,W=(8+0-3)/1+1=6。注意這步操作中的kernel_size是和反卷積核的kernel_size是保持一致的,stride固定為1,而且不進(jìn)行padding操作,因?yàn)榍懊嬉呀?jīng)padding過了,得到:
? ?我們用tensorflow做個(gè)小實(shí)驗(yàn),來驗(yàn)證上面算法的正確性。
?
輸出:
輸出結(jié)果和我們自己推導(dǎo)的完全一致!可見,反卷積也僅僅是卷積操作而已,與正常卷積使用相同大小的卷積核,只不過反卷積需要通過特定的規(guī)則對(duì)輸入tensor通過padding 0元素的方式處理一下。這樣我們最終得到的輸出圖像尺寸要比原圖像大,即實(shí)現(xiàn)了上采樣的功能。怎么樣,是不是非常簡(jiǎn)單。
?
反卷積的應(yīng)用領(lǐng)域非常廣泛,不僅僅在GAN中,還在圖像分割以及feature map的可視化領(lǐng)域有著廣泛的應(yīng)用。好了,簡(jiǎn)單講完反卷積后,讓我們回到DCGAN。
?
(二) 網(wǎng)絡(luò)實(shí)現(xiàn)上的一些tirck
?
1. 在生成器與判別器中,將所有池化層替換為步長(zhǎng)大于1的卷積操作,即拋棄所有池化層,目的是讓網(wǎng)絡(luò)去學(xué)習(xí)屬于它自己的上(下)采樣方式。想了一下,確實(shí)是非常有效的trick,因?yàn)樵趫D像分割領(lǐng)域中,maxpooling操作會(huì)破會(huì)圖像的邊緣與細(xì)節(jié),導(dǎo)致分割結(jié)果很粗糙,所以一般都通過別的辦法來替代maxpooling,以保證分割結(jié)果的細(xì)節(jié)完好。
?
2. 移除全局平均池化層,全局平均池化在圖像分類網(wǎng)絡(luò)中有著舉足輕重的地位,作者在做實(shí)驗(yàn)的過程中發(fā)現(xiàn)在判別器中用全局平均池化再接全連接層雖然能夠增加模型的穩(wěn)定性,但同時(shí)嚴(yán)重減緩了模型的收斂速度,所以決定移除。
?
3. 除了生成器的最后一層和判別器的輸入層,其余層都做batch normalization操作。是一個(gè)非常有助于網(wǎng)絡(luò)快速收斂的trick。作者發(fā)現(xiàn)如果全部層都用batch normalization,容易發(fā)生mode collapse現(xiàn)象,并使得模型變得不穩(wěn)定。
?
4. 生成器最后一層的激活函數(shù)采用tanh,其余層為relu激活函數(shù)。而判別器中則全部采用leaky relu激活函數(shù)。
?
(三) DCGAN中生成器的網(wǎng)絡(luò)結(jié)構(gòu)
網(wǎng)絡(luò)的整體架構(gòu)和原始GAN是差不多的,不同的僅僅是生成器和判別器的內(nèi)部結(jié)構(gòu),由MLP換成了CNN。從圖中來看,主要是由一個(gè)激活函數(shù)為relu的全連接層,三個(gè)激活函數(shù)為relu的反卷積層,以及最后的激活函數(shù)為tanh的反卷積層,將一個(gè)長(zhǎng)度為100滿足正態(tài)分布(或者均勻分布)的向量z變成一個(gè)大小為64*64的3通道圖像,這也是生成器生成的最終圖像。判別器在結(jié)構(gòu)上與生成器是完全對(duì)稱的,類似于常規(guī)的分類網(wǎng)絡(luò),這里不再贅述。
注意:由于生成器最后一層的激活函數(shù)為tanh,因此輸出值的范圍在[-1, 1]上,所以真實(shí)圖片樣本也必須要進(jìn)行縮放范圍一致的歸一化操作,即,令
,將輸入樣本x上的像素值都?xì)w一化到[-1, 1]上,再將這個(gè)歸一化后的圖片送入判別器中,以此來保證每一個(gè)輸入進(jìn)判別器的樣本分布區(qū)間的一致性。當(dāng)然也可以采用別的歸一化方法,只要能讓
就好。
?
(四) 用DCGAN在MNIST數(shù)據(jù)集上訓(xùn)練手寫數(shù)字生成
?
開源代碼倉(cāng)庫(kù)地址: https://github.com/carpedm20/DCGAN-tensorflow
在訓(xùn)練了30個(gè)epoch后,我把每個(gè)epoch生成器生成的100張圖片存下來并縮小做成動(dòng)態(tài)圖:
?
可以看出并沒有出現(xiàn)mode collapse現(xiàn)象,生成樣本具有一定的多樣性,效果還不錯(cuò)。其實(shí)主要還是數(shù)據(jù)集比較簡(jiǎn)單,圖片比較小,復(fù)雜紋理信息不多,比較容易生成。
?
生成的這些數(shù)據(jù)就可以用作手寫數(shù)字識(shí)別的訓(xùn)練數(shù)據(jù)。但是這些數(shù)據(jù)是沒有標(biāo)簽的,然而手寫數(shù)字識(shí)別為監(jiān)督學(xué)習(xí),難道還要對(duì)它們進(jìn)行人工標(biāo)注?這個(gè)問題我們留到下一個(gè)小節(jié)來解決。
?
(五) 總結(jié)
?
1.訓(xùn)練方法和訓(xùn)練原始GAN的方法保持一致。
2.將兩個(gè)MLP替換成為兩個(gè)CNN,生成的圖像較原始GAN來說質(zhì)量更高,更逼真。
3.通過大量實(shí)驗(yàn)總結(jié)出一套非常有用的trick,使得DCGAN在訓(xùn)練時(shí)的穩(wěn)定性相比原始GAN有顯著改善,要知道原始GAN是非常難訓(xùn)練的。
4.后面要講的模型中G和D的架構(gòu)均和DCGAN保持一致,便不再贅述。
三、InfoGAN(InformationMaximizing Generative Adversarial Nets)
?
DCGAN已經(jīng)能夠生成足夠逼真的圖像了,但是它直接將噪聲向量z作為G的輸入,沒有為z添加任何限制,導(dǎo)致我們根本不知道G主要用到了z的哪個(gè)維度來生成圖片,即已經(jīng)將z進(jìn)行高度耦合處理,所以z的維度信息對(duì)于真實(shí)數(shù)據(jù)來說不具有語(yǔ)義特征,也就是說是不可解釋的。
拿上面的圖為例,我們發(fā)現(xiàn)第三個(gè)“7”中間出現(xiàn)了一個(gè)橫線,但是為什么會(huì)出現(xiàn)這個(gè)橫線,誰(shuí)也不知道,為了讓GAN具有可解釋性,比較有代表性的GAN變體——InfoGAN就出現(xiàn)啦,為了解決語(yǔ)義問題,InfoGAN的作者對(duì)損失函數(shù)進(jìn)行了一些小的改進(jìn),一定程度上讓網(wǎng)絡(luò)學(xué)習(xí)到了可解釋的特征表示,即作者文中所說的interpretable reptesentation。
?
(一)原理闡述
既然要讓輸入的噪聲向量z帶有一定的語(yǔ)意信息,那就人為的為它添加上一些限制,于是作者把G的輸入看成兩部分:一部分就是噪聲z,可以將它看成是不可壓縮的噪聲向量。另一部分是若干個(gè)離散的和連續(xù)的latent variables(潛變量)所拼接而成的向量c,用于代表生成數(shù)據(jù)的不同語(yǔ)意信息。
?
以MNIST數(shù)據(jù)集為例,可以用一個(gè)離散的隨機(jī)變量(0-9,用于表示生成數(shù)字的具體數(shù)值)和兩個(gè)連續(xù)的隨機(jī)變量(假設(shè)用于表示筆劃的粗細(xì)與傾斜程度)。所以此時(shí)的c由一個(gè)離散的向量(長(zhǎng)度為10)、兩個(gè)連續(xù)的向量(長(zhǎng)度為1)拼接而成,即c長(zhǎng)度為12。
上圖是作者用InfoGAN在MNIST數(shù)據(jù)集上的部分結(jié)果,通過保持離散變量不變、逐漸增大某一個(gè)連續(xù)的潛變量(論文中是從-2到2),可以看出從左到右數(shù)字的筆劃逐漸增粗,具有很強(qiáng)的可解釋性。所以上一小節(jié)遺留的問題就迎刃而解了,理想情況下,我們可以通過這些潛變量來生成無數(shù)個(gè)滿足我們需求的手寫數(shù)字了!也就不需要再為生成的數(shù)據(jù)人工打標(biāo)簽了。
?
所以此時(shí)對(duì)于G的輸入來說不再是單純的噪聲z了,而是z和一個(gè)長(zhǎng)度為12的向量c,但是僅僅有這個(gè)設(shè)定還不夠,因?yàn)樯善鞯膶W(xué)習(xí)具有很高的自由度,它很容易找到一個(gè)解,使得:
此時(shí)在生成器看來,z和c是兩個(gè)完全獨(dú)立的向量,有沒有c都一樣可以生成數(shù)據(jù),這樣生成器就完全繞過了c,導(dǎo)致它起不到應(yīng)有的作用。
?
為了解決這個(gè)問題,作者通過優(yōu)化GAN的損失函數(shù)來讓和c強(qiáng)制產(chǎn)生聯(lián)系,使得兩者完成建模。作者從信息論中得到啟發(fā),提出基于互信息(mutual information)的正則化項(xiàng)。在信息論中,互信息
用來衡量“已知隨機(jī)變量Y的情況下,可以獲得多少有關(guān)隨機(jī)變量X的信息”,其計(jì)算公式為:
上式中,H表示計(jì)算熵值,所以I(X;Y)是兩個(gè)熵值的差。H(X|Y)衡量的是“給定隨機(jī)變量的情況下,隨機(jī)變量X的不確定性”。從公式中可以看出,若X和Y是獨(dú)立的,此時(shí)H(X)=H(X|Y),得到I(X;Y)=0,為最小值。若X和Y有非常強(qiáng)的關(guān)聯(lián)時(shí),即已知Y時(shí),X沒有不確定性,則H(X|Y)=0????,I(X;Y)達(dá)到最大值。所以為了讓G(z,c)和c之間產(chǎn)生盡量明確的語(yǔ)義信息,必須要讓它們二者的互信息足夠的大,所以我們對(duì)GAN的損失函數(shù)添加一個(gè)正則項(xiàng),就可以改寫為:
注意屬于G的損失函數(shù)的一部分,所以這里為負(fù)號(hào),即讓該項(xiàng)越大越好,使得G的損失函數(shù)變小。其中
為平衡兩個(gè)損失函數(shù)的權(quán)重。但是,在計(jì)算
的過程中,需要知道后驗(yàn)概率分布
,而這個(gè)分布在實(shí)際中是很難獲取的,因此作者在解決這個(gè)問題時(shí)采用了變分推理的思想,引入變分分布
來逼近
,進(jìn)而通過輪流迭代的方法用
去逼近
的下界,得到最終的網(wǎng)路損失函數(shù):
(二)網(wǎng)絡(luò)結(jié)構(gòu)
?
從上圖可以清晰的看出,雖然在設(shè)計(jì)InfoGAN時(shí)的數(shù)學(xué)推導(dǎo)比較復(fù)雜,但是網(wǎng)絡(luò)架構(gòu)還是非常簡(jiǎn)單明了的。G和D的網(wǎng)絡(luò)結(jié)構(gòu)和DCGAN保持一致,均由CNN構(gòu)成。在此基礎(chǔ)上,改動(dòng)的地方主要有:
?
1.G的輸入不僅僅是噪聲向量z了,而是z和具有語(yǔ)意信息的淺變量c進(jìn)行拼接后的向量輸入給G。
2.D的輸出在原先的基礎(chǔ)上添加了一個(gè)新的輸出分支Q,Q和D共享全部分卷積層,然后各自通過不同的全連接層輸出不同的內(nèi)容:Q的輸出對(duì)應(yīng)于的c的概率分布,D則仍然判別真?zhèn)巍?/span>
?
(三) InfoGAN的訓(xùn)練流程
?
假設(shè)batch_size=m,數(shù)據(jù)集為MNIST,則根據(jù)作者的方法,不可壓縮噪聲向量的長(zhǎng)度為62,離散潛變量的個(gè)數(shù)為1,取值范圍為[0, 9],代表0-9共10個(gè)數(shù)字,連續(xù)淺變量的個(gè)數(shù)為2,代表了生成數(shù)字的傾斜程度和筆劃粗細(xì),最好服從[-2, 2]上的均勻分布,因?yàn)檫@樣能夠顯式的通過改變其在[-2,2]上的數(shù)值觀察到生成數(shù)據(jù)相應(yīng)的變化,便于實(shí)驗(yàn),所以此時(shí)輸入變量的長(zhǎng)度為62+10+2=74。
?
則在每一個(gè)epoch中:
?
先訓(xùn)練判別器k(比如3)次:
?
1. 從噪聲分布(比如高斯分布)中隨機(jī)采樣出m個(gè)噪聲向量:。
2.從真實(shí)樣本x中隨機(jī)采樣出m個(gè)樣本:
3. 用梯度下降法使損失函數(shù)real_loss:與1之間的二分類交叉熵減?。ㄒ?yàn)樽詈笈袆e器最后一層的激活函數(shù)為sigmoid,所以要與0或者1做二分類交叉熵,這也是為什么損失函數(shù)要取log的原因)。
4.用梯度下降法使損失函數(shù)fake_loss:與0之間的二分類交叉熵減小。
5. 所以判別器的總損失函數(shù)d_loss:即讓d_loss減小。注意在訓(xùn)練判別器的時(shí)候分類器中的所有參數(shù)要固定住,即不參加訓(xùn)練。
?
再訓(xùn)練生成器1次:
?
1. 從噪聲分布中隨機(jī)采樣出m個(gè)噪聲向量:。
2. 從離散隨機(jī)分布中隨機(jī)采樣m個(gè)長(zhǎng)度為10、one-hot編碼格式的向量:。
3. 從兩個(gè)連續(xù)隨機(jī)分布中各隨機(jī)采樣m個(gè)長(zhǎng)度為1的向量:,
4. 將上面的所有向量進(jìn)行concat操作,得到長(zhǎng)度為74的向量,共m個(gè),并記錄每個(gè)向量所在的位置,便于計(jì)算損失函數(shù)。
5. 此時(shí)g_loss由三部分組成:一個(gè)是與1之間的二分類交叉熵、一個(gè)是Q分支輸出的離散淺變量的預(yù)測(cè)值和相應(yīng)的輸入部分的交叉熵以及Q分支輸出的連續(xù)淺變量的預(yù)測(cè)值和輸入部分的互信息,并為這三部分乘上適當(dāng)?shù)钠胶庖蜃?,其中互信息?xiàng)的系數(shù)是負(fù)的。
6. 用梯度下降法使越小越好。注意在訓(xùn)生成器的時(shí)候判別器中的所有參數(shù)要固定住,即不參加訓(xùn)練。
直到所有epoch執(zhí)行完畢,訓(xùn)練結(jié)束。
?
(四)總結(jié)
?
1.G的輸入不再是一個(gè)單一的噪聲向量,而是噪聲向量與潛變量的拼接。
2.對(duì)于潛變量來說,G和D組成的大網(wǎng)絡(luò)就好比是一個(gè)AutoEncoder,不同之處只是將信息編碼在了圖像中,而非向量,最后通過D解碼還原回。
3.D的輸出由原先的單一分支變?yōu)閮蓚€(gè)不同的分支。
4.從信息熵的角度對(duì)噪聲向量和潛變量的關(guān)系完成建模,并通過數(shù)學(xué)推導(dǎo)以及實(shí)驗(yàn)的方式證明了該方法確實(shí)有效。
5.通過潛變量,使得G生成的數(shù)據(jù)具有一定的可解釋性。
四、WGAN(Wasserstein GAN)
(一) Wassersteindistance
從前面的章節(jié)我們知道,DCGAN的損失函數(shù)本質(zhì)上是讓與
間的JS散度盡可能的小,但是很有可能出現(xiàn)
與
兩個(gè)分布根本就沒有重疊的地方,對(duì)于任意兩個(gè)沒有交疊、距離足夠遠(yuǎn)的分布,它們之間的JS散度恒定為log2,導(dǎo)致梯度消失,此時(shí)
不可能在訓(xùn)練的過程中向
的方向移動(dòng),D也就得不到訓(xùn)練。而WGAN就著手于從損失函數(shù)上進(jìn)行優(yōu)化,使得訓(xùn)練更加穩(wěn)定。
WGAN的作者用大量的數(shù)學(xué)推導(dǎo)來證明了基于二分類交叉熵的損失函數(shù)的缺陷與不合理性,并提出了一種新的損失函數(shù),取名為Wasserstein distance,這個(gè)損失函數(shù)在任何位置都有著相對(duì)平滑的梯度,由于篇幅所限,我盡量直觀的向大家闡述,我們先來看一下網(wǎng)絡(luò)結(jié)構(gòu)。
?
(二)網(wǎng)絡(luò)結(jié)構(gòu)
乍一看怎么和DCGAN相差無幾呢,是的,作者在網(wǎng)絡(luò)結(jié)構(gòu)上的變動(dòng)僅僅是去掉了DCGAN中D最后一層的sigmoid激活函數(shù),使得網(wǎng)絡(luò)最后一層的輸出變成線性的了。
?
(三) 原理闡述
?
究竟什么是Wasserstein distance呢?Wasserstein distance用來衡量來兩個(gè)分布間的距離,而且即使兩個(gè)分布間沒有交疊,也會(huì)根據(jù)分布相距的遠(yuǎn)近程度給出一個(gè)相應(yīng)的數(shù)值,即損失函數(shù)的值會(huì)隨著兩個(gè)分布間的距離的遠(yuǎn)近程度而動(dòng)態(tài)的發(fā)生改變,在這篇論文中,作者初步給出Wasserstein distance的表達(dá)式:
從直觀上理解,損失函數(shù)是兩個(gè)期望的差值,并讓這個(gè)差值盡可能的大,即使盡可能的大,同時(shí)使
盡可能的小,但僅從這個(gè)表達(dá)式是不足以讓訓(xùn)練變得收斂的,我們來看下圖:
雖然能夠準(zhǔn)確對(duì)生成樣本和真實(shí)樣本完美的區(qū)分,但是在沒有了sigmoid函數(shù)限制值域的情況下會(huì)讓D在真實(shí)樣本上的輸出值趨于無窮大,而在生成樣本上的輸出值趨于無窮小,導(dǎo)致
永遠(yuǎn)不會(huì)收斂,為了避免出現(xiàn)這個(gè)問題,作者在損失函數(shù)中添加了一個(gè)額外的限制條件:
限制是指:在樣本空間中,要求判別器函數(shù)D(x)梯度值不大于一個(gè)有限的常數(shù)k,通過權(quán)重限制的方式保證了權(quán)重參數(shù)的有界性,間接限制了其梯度信息。
?
目的就是讓D的輸出曲線盡可能的平滑,不讓它趨向與無窮大或者無窮小,那么怎么限制呢?在作者2017年發(fā)布的WGAN中只是對(duì)D的權(quán)重進(jìn)行簡(jiǎn)單的clipping操作:
?
人為的規(guī)定一個(gè)閾值c,并將D中的網(wǎng)絡(luò)參數(shù)數(shù)值全部限制在上[-c,c],對(duì)于D中的任意一個(gè)參數(shù)w,如果 w>c, 則令w=c。如果w<-c,則令w=-c,即始終保持,該操作稱為weight clipping,使得D的輸出曲線比較平滑。是的,就是這么簡(jiǎn)單!實(shí)驗(yàn)證明該算法雖然簡(jiǎn)單粗暴,但確實(shí)使得訓(xùn)練過程變得更加穩(wěn)定。另一方面,c的取值范圍很難確定,是一個(gè)依賴于經(jīng)驗(yàn)的數(shù)值。如果取的過小,網(wǎng)絡(luò)參數(shù)都被限制在了一個(gè)比較小的范圍,導(dǎo)致D的擬合能力受限。如果取的過大,又可能會(huì)讓D的輸出值趨近于無窮,網(wǎng)絡(luò)又無法收斂,所以它的取值極度依賴實(shí)驗(yàn),不過一般地,將c取為0.01是一個(gè)比較個(gè)合理的值。綜上,WGAN改動(dòng)的地方主要有以下三點(diǎn):
1.D最后一層去掉sigmoid激活函數(shù),所以它現(xiàn)在的輸出值不再代表二分類的概率了。 2.G和D的loss不再取log,即不再用與0或者1的二分類交叉熵作為損失函數(shù)了。 3.每次更新D的參數(shù)后,將其所有參數(shù)的絕對(duì)值截?cái)嗟讲怀^一個(gè)固定常數(shù)c(經(jīng)驗(yàn)數(shù)值,可以取為0.01),即weight clipping操作,其實(shí)本質(zhì)上就是對(duì)D的參數(shù)添加了一個(gè)簡(jiǎn)單粗暴的正則項(xiàng)。
(四) WGAN的訓(xùn)練流程
假設(shè)batch_size=m,則在每一個(gè)epoch中:
先訓(xùn)練判別器k(比如5)次:
1. 從噪聲分布(比如高斯分布)中隨機(jī)采樣出m個(gè)噪聲向量:.
2. 從真實(shí)樣本x中隨機(jī)采樣出m個(gè)真實(shí)樣本:
3. 用梯度下降法使損失函數(shù)越小越好(取負(fù)號(hào)的原因是一般的深度學(xué)習(xí)框架只能讓損失函數(shù)越來越小,所以這里加個(gè)負(fù)號(hào)就和原先最大化的邏輯保持一致了)。
4. 用梯度下降法使損失函數(shù)越小越好,并保存
生成的假樣本的結(jié)果,記
。
5. 所以判別器的總損失函數(shù),即讓d_loss越小越好。注意在訓(xùn)練判別器的時(shí)候分類器中的所有參數(shù)要固定住,即不參加訓(xùn)練。
6. 檢查D中所有可訓(xùn)練參數(shù)的值,將它們限制在一個(gè)人為規(guī)定的常數(shù)|c|內(nèi),即,令
(可以將c取為0.01)。
再訓(xùn)練生成器1次:
1. 從噪聲分布中隨機(jī)采樣出m個(gè)噪聲向量:用梯度下降法使損失函數(shù)
越小越好。注意在訓(xùn)生成器的時(shí)候判別器中的所有參數(shù)要固定住,即不參加訓(xùn)練。
直到所有epoch執(zhí)行完畢,訓(xùn)練結(jié)束。
?(五) 總結(jié)
?
1.修改一直在沿用的原始GAN損失函數(shù),提出一種新的損失函數(shù),使得GAN的訓(xùn)練變得比以前穩(wěn)定。
2.提出針對(duì)判別器的weight clipping操作,并經(jīng)過大量實(shí)驗(yàn)證明確實(shí)能夠讓訓(xùn)練變得穩(wěn)定、加快模型收斂,而且代碼實(shí)現(xiàn)上也非常簡(jiǎn)單,對(duì)DCGAN代碼的改動(dòng)不超過20行就能讓它變成WGAN。
3.模型是否能夠收斂高度依賴于超參數(shù)c的取值,而該參數(shù)的選取通常依賴于實(shí)驗(yàn)。如果選取得當(dāng),能夠提高網(wǎng)絡(luò)訓(xùn)練的穩(wěn)定性,如果選取不當(dāng),模型反而無法收斂。
五、WGAN_GP(WassersteinGAN with Gradient Penality)
(一) gradient penality
在提出了WGAN后,作者繼續(xù)在WGAN上進(jìn)行優(yōu)化,又給出了一種新的損失函數(shù),拋棄weight clipping,也就不再需要經(jīng)驗(yàn)常數(shù)c了,取而代之的是gradient penality(梯度懲罰),因此取名為WGAN_GP,也叫Improved_WGAN。
?
Gradient penality是指:對(duì)D的每一個(gè)輸入樣本x,使得。意思是對(duì)于任意一個(gè)輸入樣本x,用D的輸出結(jié)果D(x)對(duì)求梯度后的值的L2范數(shù)不大于1。
?
上面的解釋可能有些拗口,我們從一維空間中的函數(shù)f(x)來進(jìn)行闡述:一維函數(shù)f(x),對(duì)于任意輸入x,該函數(shù)滿足:,即任意一點(diǎn)的斜率的平方不大于1,進(jìn)而可以推出:
,可想而知f(x)的函數(shù)曲線是比較平滑的,所以稱為梯度懲罰。那為什么是L2范數(shù)呢而不是L1范數(shù)呢?原因很簡(jiǎn)單,L1范數(shù)會(huì)破壞一個(gè)函數(shù)的可微性呀,所以L2范數(shù)是非常合理的!
?
注意前面我說的是針對(duì)于D的每一個(gè)輸入樣本,都讓它滿足,實(shí)際上這是不現(xiàn)實(shí)的,所以作者又想了一個(gè)辦法來解決這個(gè)問題:假設(shè)從真實(shí)數(shù)據(jù)中采樣出來的一個(gè)點(diǎn)稱為x(這個(gè)點(diǎn)是高維空間中的點(diǎn)),G利用采樣得到的噪聲向量所生成的假數(shù)據(jù)稱為
在這兩點(diǎn)之間的某一個(gè)位置采樣一個(gè)點(diǎn)記為,即對(duì)于每一個(gè)
,盡量讓
。那么最常見的滿足上述要求的采樣方法就是線性采樣方法了,即在x與
所形成的超平面上任意選取一個(gè)點(diǎn)
,換句話說就是在生成樣本和真實(shí)樣本間做一個(gè)線性插值,所以存在
。
在新的損失函數(shù)閃亮登場(chǎng)之前,我們還有一個(gè)小小的優(yōu)化!因?yàn)樽髡咦詈蟀l(fā)現(xiàn),其實(shí)讓是最好的方案,而不是把1作為上、下限,別問我為什么,作者也不知道!因?yàn)槭峭ㄟ^大量的實(shí)驗(yàn)總結(jié)出來的。
?
那么WGAN_GP的核心就在線性插值這了,為了不讓這部分變得太抽象,我們用pytorch來實(shí)現(xiàn)一下插值這部分。
?
所以,新的損失函數(shù)可以寫為: ?
說了這么多,其實(shí)用數(shù)學(xué)公式表達(dá)出來還是非常簡(jiǎn)單的,式子中前兩項(xiàng)仍然是WGAN的損失函數(shù),只是新添加了一個(gè)正則項(xiàng),便是在真實(shí)數(shù)據(jù)和生成數(shù)據(jù)之間通過線性插值得到的點(diǎn),即盡量讓D對(duì)它的梯度的L2范數(shù)越接近于1,使
越大越好,通過該正則項(xiàng),能讓損失函數(shù)上的每一點(diǎn)都有較為平滑的梯度,訓(xùn)練也就更加穩(wěn)定,大大降低了訓(xùn)練GAN的難度。超參數(shù)
用于對(duì)這兩部分的損失函數(shù)進(jìn)行平衡,作者通過實(shí)驗(yàn)發(fā)現(xiàn)=10是一個(gè)比較合理的數(shù)值。
(二) WGAN_GP的訓(xùn)練流程
假設(shè)batch_size=m,則在每一個(gè)epoch中:
先訓(xùn)練判別器k(比如5)次:
1. 從噪聲分布(比如高斯分布)中隨機(jī)采樣出m個(gè)噪聲向量:。
2. 從真實(shí)樣本x中隨機(jī)采樣出m個(gè)真實(shí)樣本:。
3. 用梯度下降法使損失函數(shù)?越小越好(取負(fù)號(hào)的原因是一般的深度學(xué)習(xí)框架只能讓損失函數(shù)越來越小,所以這里加個(gè)負(fù)號(hào)就和原先最大化的邏輯保持一致了)。
4. 用梯度下降法使損失函數(shù)越小越好,并保存
生成的假樣本的結(jié)果,記為
。
5.在這m個(gè)假樣本與已經(jīng)得到的m個(gè)真實(shí)樣本進(jìn)行線性插值,得到m個(gè)插值樣本:。將m個(gè)插值樣本送入D中得到的結(jié)果對(duì)輸入求梯度,使
越小越好。
6.所以判別器的總損失函數(shù)d_loss =read_los-s + fake_loss + gp,即讓d_loss越小越好。注意在訓(xùn)練判別器的時(shí)候分類器中的所有參數(shù)要固定住,即不參加訓(xùn)練。為平衡兩個(gè)損失函數(shù)的權(quán)重,取為10是比較合理的數(shù)值。
再訓(xùn)練生成器1次:
從噪聲分布中隨機(jī)采樣出m個(gè)噪聲向量:用梯度下降法使損失函數(shù)g_loss:
越小越好。注意在訓(xùn)練生成器的時(shí)候判別器中的所有參數(shù)要固定住,即不參加訓(xùn)練。
直到所有epoch執(zhí)行完畢,訓(xùn)練結(jié)束。
?
(三) WGAN_GP小試牛刀
在寫這篇文章的時(shí)候,正好看到TinyMind舉辦了一個(gè)關(guān)于用GAN生成書法字體的比賽https://www.tinymind.cn/competitions/45 – ranking,當(dāng)時(shí)距離比賽結(jié)束僅剩三天時(shí)間,但是為了讓文章更充實(shí)一些,還是馬不停蹄的把數(shù)據(jù)集下載到本地,不說了,GAN就完了!
?
比賽目的是用GAN來生成圖片大小為128*128的書法字體圖片,評(píng)判標(biāo)準(zhǔn)是上傳10000張自己生成的書法字進(jìn)行系統(tǒng)評(píng)分,當(dāng)然質(zhì)量、多樣性越高越好。訓(xùn)練集中共有100種字,每種字又有400張不同的字體圖片,所以一共是40000張圖片,每張圖片的高、寬都在200到400之間,并且為灰度圖像,那么我們就來用WGAN_GP來完成這個(gè)小比賽!,參考開源代碼地址:https://github.com/igul222/improved_wgan_training,實(shí)現(xiàn)框架為tensorflow。
?
先來看看數(shù)據(jù)集長(zhǎng)什么樣吧。
?
這里我將每種字隨機(jī)抽出1個(gè)并resize到64*64進(jìn)行排列展示,所以正好100個(gè)不同的字,發(fā)現(xiàn)有一些根本不認(rèn)識(shí)!不過認(rèn)不認(rèn)識(shí)沒關(guān)系,對(duì)于網(wǎng)絡(luò)來說它需要的僅僅是數(shù)據(jù)而已。另外一點(diǎn)就是這里面有一些臟數(shù)據(jù),比如大字下面還有一些小字,這肯定不是我們期望的樣本,但是我在這里并沒有過濾掉這些臟數(shù)據(jù),一是工作量太大,不能自動(dòng)完成,需要人工檢查。二是先嘗試著訓(xùn)練一下,不行的話再想辦法剔除,事實(shí)證明對(duì)結(jié)果影響不大。
?
原repo的代碼只能生成64*64的圖片,所以需要對(duì)其網(wǎng)絡(luò)結(jié)構(gòu)進(jìn)行相應(yīng)的改進(jìn),使其能夠產(chǎn)生128*128的圖片,改進(jìn)的方案也非常簡(jiǎn)單:
?
1)將G的第一個(gè)全連接層的輸出神經(jīng)元個(gè)數(shù)擴(kuò)大為原先的兩倍,所以這時(shí)reshape后tensor深度變?yōu)樵鹊膬杀叮撕缶矸e核的個(gè)數(shù)每層都除以2。
?
2)將生成器最后一層的激活函數(shù)改為relu,接一個(gè)batch normalization,并在其后面再添加一個(gè)deconv層,激活函數(shù)為tanh。
?
3)將判別器的最后一層的全連接層改為卷積層,接一個(gè)batch normalization,激活函數(shù)為leaky relu,并重復(fù)一次,即再降采樣一次,reshape后再接一個(gè)單神經(jīng)元的全連接層就可以了,注意沒有激活函數(shù)。
?
4)因?yàn)槭切碌臄?shù)據(jù)了,所以數(shù)據(jù)讀取以及組織數(shù)據(jù)的代碼需要自己寫。損失函數(shù)、訓(xùn)練代碼不用動(dòng)??赡苄枰趯?shí)驗(yàn)中對(duì)學(xué)習(xí)率進(jìn)行調(diào)整。
?
在訓(xùn)練了40個(gè)epoch后,我把每個(gè)epoch生成器生成的100張圖片存下來并縮小做成動(dòng)態(tài)圖:
(gif 圖片太大,截取部分靜態(tài)圖)
?
可以看出生成的數(shù)據(jù)已經(jīng)趨于穩(wěn)定,變動(dòng)不大。由于時(shí)間有限再加上工作繁忙,沒有足夠的時(shí)間對(duì)網(wǎng)絡(luò)進(jìn)行優(yōu)化,排名沒進(jìn)前10,因?yàn)榍?0名才有獎(jiǎng)勵(lì)呀,重在參與嘛!將10000張生成的圖片上傳后,官網(wǎng)展示了部分圖片:
?
個(gè)人感覺效果一般吧,不過在參加了這個(gè)小比賽后讓我學(xué)到了很多知識(shí),也認(rèn)識(shí)到了自身的不足。
?
- 直接把數(shù)據(jù)的標(biāo)簽信息仍掉了,所有數(shù)據(jù)同等對(duì)待一起訓(xùn)的,導(dǎo)致最終數(shù)據(jù)的多樣性可能不夠高,拉低評(píng)分。
- 既然是比賽,采取合理的小技巧來達(dá)到更高的評(píng)分也是可以的。我們知道越大的圖片越不好生成,而64*64的圖片相對(duì)來說比較容易生成,也易于訓(xùn)練??梢灾簧?4*64的圖片,提交成績(jī)的時(shí)候再通過一些好的插值方法(比如雙三次插值)resize到128*128!賽后我知道確實(shí)有人是這樣做的。不過這種方法所生成的書法字肯定沒有直接生成128*128的圖片質(zhì)量高。
- 其實(shí)有很多比WGAN_GP更先進(jìn)、生成效果更好的網(wǎng)絡(luò),畢竟這篇文章的發(fā)表時(shí)間是在2017年,但是新手上路嘛,以穩(wěn)為主,就選擇了一個(gè)比較經(jīng)典的模型。
-
為以后的手寫字識(shí)別提供了不少思路,通過GAN來增加訓(xùn)練數(shù)據(jù)量是非??尚械姆椒?。
(四) 總結(jié)
?
1.為WGAN的損失函數(shù)提出了一種新的正則方法——gradient penality,從而更好的解決了訓(xùn)練GAN的過程中梯度消失的問題。
2.比標(biāo)準(zhǔn)WGAN擁有更快的收斂速度,并能生成更高質(zhì)量的樣本。
3.將resnet中的殘差塊成功應(yīng)用于生成器和判別器中,使網(wǎng)絡(luò)可以變得更深、同時(shí)能夠生成質(zhì)量更高的樣本,并且訓(xùn)練過程也更加穩(wěn)定。
4.不需要過多的調(diào)參,成功訓(xùn)練多種針對(duì)圖片的GAN結(jié)構(gòu)。
?
六 、總結(jié)
?
1.本文沿著GANDCGANInfoGANWGANWGAN_GP的路線來介紹GAN,其初衷是能讓大家對(duì)GAN有一個(gè)感性的了解,所以大量的數(shù)學(xué)公式推導(dǎo)沒有列出來。當(dāng)然,還有很多優(yōu)秀的GAN本文沒有涉及到,畢竟以入門為主嘛!相信在讀完本文后能夠讓大家更好的理解當(dāng)下比較新穎并且有意思的GAN。
?
2.其實(shí)GAN在最終的實(shí)現(xiàn)上都非常簡(jiǎn)單,比較難的地方是涉及模型損失函數(shù)的優(yōu)化以及相關(guān)數(shù)學(xué)推導(dǎo)、還有就是在現(xiàn)有網(wǎng)絡(luò)上的創(chuàng)新,從而提出一個(gè)新穎并且生成質(zhì)量高的GAN模型。
?
3.雖然GAN在圖像生成上取得了耀眼的成績(jī),但并沒有在NLP領(lǐng)域取得顯著成果。其中一個(gè)主要原因是圖像數(shù)據(jù)都是實(shí)數(shù)空間上的連續(xù)數(shù)據(jù),而NLP中大多都是離散數(shù)據(jù),例如分詞后的詞組。而對(duì)于連續(xù)型數(shù)據(jù),就可以略微改變合成的數(shù)據(jù),比如一個(gè)浮點(diǎn)類型的像素值為0.64,將這個(gè)值改為0.65是沒有問題的。但是對(duì)于離散型數(shù)據(jù),如果輸出了一個(gè)單詞”hello”,但接下來不能將其改為”hello+0.01”,因?yàn)楦緵]有這個(gè)單詞!所以NLP中應(yīng)用GAN是比較困難的。但并不代表沒有人研究這個(gè)方向,有一些學(xué)者已經(jīng)能夠?qū)AN應(yīng)用于NLP中了,大多數(shù)要與強(qiáng)化學(xué)習(xí)結(jié)合,感興趣的小伙伴可以讀一讀TextGAN、SeqGAN這兩篇文章。
?
4.由于平時(shí)對(duì)GAN的接觸比較少,再加上專業(yè)水平有限,文章中出錯(cuò)之處在所難免,還望多多包涵。
?
七、參考文獻(xiàn)
[1]IanJ. Goodfellow, Jean Pouget-Abadie and Mehdi Mirza, “Gererative AdversarialNetworks,” ArXiv preprint arXiv:1406.2661, 2014.
[2]AlecRadford, Luke Metz and Soumith Chintala, “Unsupervised Representation Learningwith Deep Convolutional Generative Adversarial Networks,” ArXiv preprintaxXiv:1511.06434, 2016.
[3]Xi Chen, Yan Duan and Rein Houthooft, “InfoGAN:Interpretable Representation Learning by Information Maximizing GenerativeAdversarial Nets,” ArXiv Preprint arXiv:1606.03657, 2016.
[4]Martin Arjovsky, Soumith Chintala and Léon Bottou, “WassersteinGAN,” ArXiv preprint arXiv:1606.03657, 2016.
[5]Ishaan Gulrajani, Faruk Ahmed and Martin Arjovsky, “ImprovedTraining of Wasserstein GANs”, ArXiv preprint arXiv:1704.00028, 2017.
?
BOUT
關(guān)于作者
馬振宇:達(dá)觀數(shù)據(jù)算法工程師,負(fù)責(zé)達(dá)觀數(shù)據(jù)OCR方向的相關(guān)算法研發(fā),優(yōu)化工作。