Generative Adversarial Networks (GAN): A Gentle Introduction (UPDATED)
Su Wang
Department of Statistics and Data Science
University of Texas at Austin
Abstract

In this tutorial, I present an intuitive introduction to the Generative Adversarial Network (GAN) [1], invented by Ian Goodfellow of Google Brain; I give an overview of the general idea of the model and describe the algorithm for training it as per the original work. I further briefly introduce applications of GAN in Natural Language Processing to show its flexibility and strong potential as a neural network architecture. To complement the discussion, I also present simple TensorFlow code1 for the original GAN [1] and an important variant, the Wasserstein GAN [26,27], to help the reader get a quick start on practical applications.
1 Overview
Generative Adversarial Networks (GAN) [1] is a deep learning framework in which two models, a generative model G and a discriminative model D, are trained simultaneously. The objective of G is to capture the distribution of some target data (e.g. the distribution of pixel intensities in images). D aids the training of G by examining the data generated by G against “real” data, thereby helping G learn the distribution that underpins the real data. GAN is fleshed out in Goodfellow et al. (2014) [1] as a pair of simple neural networks, though in practice the models can in principle be any generative-discriminative pair2.
In the original work [1,2] and elsewhere [3], GAN has been given the analogy of counterfeiting money: G plays the role of a counterfeiter in training, while D plays the bank, which strives to identify fake bills and in the process (hopefully unintentionally) helps G hone its bill-making skills. More concretely, let x ∼ p_data be the characterizing features of real bills, and G(z)4 the features G creates from noise z drawn from some noise distribution p_z. Further, let J be some quantitative metric which measures the extent to which a bill is real. D’s job is then to lower J(G(z)) (the score of a fake bill) while raising J(x) (the score of a real bill) for more successful identification. G, on the other hand, aims to increase J(G(z)) (i.e. to improve the quality of fake bills) by learning from “observing” how D makes its decisions. As the game of “busting fake bills” versus “making better fake bills” proceeds, the model distribution p_G draws closer to p_data, eventually reaching an equilibrium [2] where D can no longer classify better than chance (i.e. D(x) = D(G(z)) = 1/2). At this point we say G has arrived at an optimal point3 in counterfeiting.
GAN has gained massive traction in computer vision [4,5,6], feature representation [7], and more recently in Natural Language Processing (NLP) tasks: Document Modeling [8], Dialogue Generation [9], Sentiment Analysis [10], and Domain Adaptation [11]. In Section 5 I briefly exemplify successful applications of GAN in NLP [8,9].
1 https://github.com/suwangcompling/GAN-tutorials.
2 These can be any generative and discriminative models, in a much wider sense than the term “G-D pair” is used in the literature [12].
3 I will show that the optimal point is reached iff p_G = p_data (Section 2).
2 The Minmax Game

Formally, D and G play a two-player minmax game over the value function

J(G, D) = E_{x∼p_data}[log D(x)] + E_{z∼p_z}[log(1 − D(G(z)))]    (1)

min_G max_D J(G, D)    (2)

Algorithmically, the training takes the form of an alternation between maximization and minimization, described in Algorithm 1. In practice, Eq. 1 often does not bring the model to equilibrium, because log(1 − D(G(z))) saturates rapidly in the early stage of training: D easily rejects G(z), since G at that point generates fake data of such poor quality that they conspicuously differ from the real data. Therefore, rather than evaluating how bad the fake data are (G’s objective in the second term of Eq. 1), we instead evaluate how good they are by setting G’s goal to maximizing J_G(G) = log D(G(z))5. In so doing we end up with two objective functions:

max_D J_D(D, G),    max_G J_G(G)    (3)
To illustrate the training process graphically, we observe (a) how p_G changes over time and, consequently, (b) how the discrimination boundary of D changes accordingly (Figure 1; Figure 1 in [1]).

Figure 1: Training Process of GAN (green solid: p_G; black dotted: p_data; blue dashed: D’s discrimination boundary; arrows: generation of fake data)

In (a) through (d), z are sampled uniformly from the noise prior p_z, and p_G draws closer to p_data. In the process, the discrimination boundary changes accordingly and finally morphs into a flat line (i.e. D(x) = D(G(z)) = 1/2), which indicates that D is no longer able to tell fake and real data apart. Specifically, Figure 1 shows a scenario where the model is near convergence: (a) p_G is similar to p_data, and D is a partially accurate classifier6; (b) D is updated: based on the relative densities of p_G and p_data, D converges in the inner loop of Algorithm 1 to D*(x) = p_data(x) / (p_data(x) + p_G(x)); (c) G is updated: p_G is drawn closer to p_data under the guidance of D’s gradient; (d) final convergence: assuming sufficient model capacity, the adversarial pair reaches the equilibrium p_G = p_data, where D(x) = 1/2.
4 The generator is essentially a function that maps a noise datum into the space of a real datum, i.e. G : z ↦ x. Note that z and x therefore do not have to be equal in dimensionality.
5 The first term of Eq. 1 is independent of G.
6 In the late stage of training, G starts to generate high-quality fakes, to the effect that D’s classification performance suffers (though not all the way down to chance level, i.e. D(x) = D(G(z)) = 1/2).
Algorithm 1 Minmax Game
1: for specified # of training iterations do
       ▷ TRAINING DISCRIMINATOR
2:    for k steps* do
3:       Draw a minibatch of m noise samples {z^(1), ..., z^(m)} ∼ p_z.
4:       Draw a minibatch of m data samples {x^(1), ..., x^(m)} ∼ p_data.
5:       Update D’s parameters by gradient ascent:
             ∇_{θ_d} (1/m) Σ_{i=1}^{m} [log D(x^(i)) + log(1 − D(G(z^(i))))]
6:    end for
       ▷ TRAINING GENERATOR
7:    Draw a minibatch of m noise samples {z^(1), ..., z^(m)} ∼ p_z.
8:    Update G’s parameters by gradient descent:
             ∇_{θ_g} (1/m) Σ_{i=1}^{m} log(1 − D(G(z^(i))))
9: end for
*k is a tunable hyperparameter, usually set to 1 to lower the training cost of each iteration.
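In code, Algorithm 1 is simply a pair of nested loops. Below is a minimal TF1-style sketch; the update ops d_train_op and g_train_op (built in Section 3), the placeholders z_ph and x_ph, and the minibatch helpers sample_noise and sample_data are assumed to be defined elsewhere, and all names are illustrative rather than the notebook’s exact API:

import tensorflow as tf

num_iterations, k, m = 10000, 1, 128  # outer iterations, D steps, minibatch size
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for it in range(num_iterations):
        for _ in range(k):
            # inner loop: train D (ascend J_D, i.e. descend d_loss)
            sess.run(d_train_op, feed_dict={z_ph: sample_noise(m),
                                            x_ph: sample_data(m)})
        # outer step: train G (its optimizer's var_list excludes D's parameters)
        sess.run(g_train_op, feed_dict={z_ph: sample_noise(m)})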
Analysis. We now look at (b)-(d) in Figure 1 analytically to understand why the training scheme works. We begin by addressing (b): how does D converge to D*(x) in the inner loop?

Statement 1. For a fixed G, the optimal discriminator D is

D*_G(x) = p_data(x) / (p_data(x) + p_G(x))    (4)
Proof. Given a fixed G, the training objective for D is to maximize J(G, D), where

J(G, D) = E_{x∼p_data}[log D(x)] + E_{z∼p_z}[log(1 − D(G(z)))]
        = ∫_x p_data(x) log(D(x)) dx + ∫_z p_z(z) log(1 − D(G(z))) dz
        = ∫_x [p_data(x) log(D(x)) + p_G(x) log(1 − D(x))] dx    (5)

The last equality employs a change of variable in the second term of the penultimate formula: as G : z ↦ x, (i) integrating over p_z(z) is equivalent to integrating over p_G(x), and (ii) integrating over G(z) is equivalent to integrating over x. Further, we have7

D* = argmax_D J(G, D)
   = argmax_D ∫_x [p_data(x) log(D(x)) + p_G(x) log(1 − D(x))] dx    (6)

where the integrand can be abstracted in the form f(D) = a log(D) + b log(1 − D). By setting ∂f/∂D = 0 and solving for D, we obtain D* = a / (a + b), satisfying Eq. 4 and concluding the proof.
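Spelling out the stationary-point step, with a = p_data(x) and b = p_G(x):

\frac{\partial f}{\partial D} = \frac{a}{D} - \frac{b}{1 - D} = 0
\;\Longrightarrow\; a(1 - D) = bD
\;\Longrightarrow\; D^* = \frac{a}{a + b}

Since \frac{\partial^2 f}{\partial D^2} = -\frac{a}{D^2} - \frac{b}{(1 - D)^2} < 0, the stationary point is indeed a maximum.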
Next we look at (c,d) in the figure and ask: how does updating G get us to p_G = p_data?

Statement 2. Let C(G) = max_D J(G, D), i.e. C(G) is G’s minimization objective (cf. Eq. 2). The global minimum of C(G) is reached iff p_G = p_data, at which point C(G) = −log 4.
7 The last equality in Eq. 6 follows from the result of Eq. 5.
Proof. First we prove p_G = p_data ⇒ C(G) = −log 4. We know that

C(G) = max_D J(G, D)
     = E_{x∼p_data}[log D*_G(x)] + E_{z∼p_z}[log(1 − D*_G(G(z)))]
     = E_{x∼p_data}[log D*_G(x)] + E_{x∼p_G}[log(1 − D*_G(x))]    (by change of variable, cf. Eq. 5)
     = E_{x∼p_data}[log (p_data(x) / (p_data(x) + p_G(x)))] + E_{x∼p_G}[log (p_G(x) / (p_data(x) + p_G(x)))]    (7)

We also know that if p_data = p_G, then D*_G(x) = 1/2 (by Eq. 4), which we plug into Eq. 7 to obtain C(G) = log(1/2) + log(1/2) = −log 4. That is, p_G = p_data ⇒ C(G) = −log 4. We now show that C(G) = −log 4 ⇒ p_G = p_data.

C(G) = E_{x∼p_data}[log (p_data(x) / (p_data(x) + p_G(x)))] + E_{x∼p_G}[log (p_G(x) / (p_data(x) + p_G(x)))]
     = −log 4 + log 4 + E_{x∼p_data}[log (p_data(x) / (p_data(x) + p_G(x)))] + E_{x∼p_G}[log (p_G(x) / (p_data(x) + p_G(x)))]
     = −log 4 + E_{x∼p_data}[log (2 · p_data(x) / (p_data(x) + p_G(x)))] + E_{x∼p_G}[log (2 · p_G(x) / (p_data(x) + p_G(x)))]
     = −log 4 + E_{x∼p_data}[log (p_data(x) / ((p_data(x) + p_G(x)) / 2))] + E_{x∼p_G}[log (p_G(x) / ((p_data(x) + p_G(x)) / 2))]
     = −log 4 + KL(p_data ‖ (p_data + p_G)/2) + KL(p_G ‖ (p_data + p_G)/2)
     = −log 4 + 2 · JS(p_data ‖ p_G)    (by the def. of Jensen-Shannon Divergence)    (8)

Now given C(G) = −log 4, we must have JS(p_data ‖ p_G) = 0, which holds only when p_G = p_data. Thus we have shown C(G) = −log 4 ⇒ p_G = p_data.
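A quick numeric sanity check of Eq. 8 on discrete distributions (plain NumPy; not part of the tutorial’s notebooks):

import numpy as np

def kl(p, q):
    return np.sum(p * np.log(p / q))

def js(p, q):
    m = (p + q) / 2.
    return .5 * kl(p, m) + .5 * kl(q, m)

p_data = np.array([.2, .3, .5])
p_g = np.array([.5, .3, .2])
print(-np.log(4) + 2 * js(p_data, p_g))     # > -log 4 since p_g != p_data
print(-np.log(4) + 2 * js(p_data, p_data))  # = -log 4 ~ -1.386, the global minimum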
Statements 1 and 2 show that the optima of the alternating updates in Algorithm 1 lead us to the equilibrium of the Minmax Game. The original work [1] also proves the convergence of Algorithm 1; however, since the derivation involves advanced knowledge of optimization, I refer the interested reader to the proof there.
3 Basic Implementation
This section presents a demo implementation with MNIST image reconstruction8. Specifically, the discriminator D is a convolutional net, and the generator G a deconvolutional net [24]. A simple GAN graph is as follows (the precise setup of the networks is omitted here in favor of the full Jupyter notebook):
import tensorflow as tf

# Discriminator loss
# J_D = E[log(D(x))] + E[log(1-D(G(z)))]
# Dx, Dg: the discriminator's logits on real and generated images.
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dx,
                                                                     labels = tf.ones_like(Dx)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg,
                                                                     labels = tf.zeros_like(Dg)))
d_loss = d_loss_real + d_loss_fake
# the above pushes D(x) to label=1 and D(G(z)) to label=0 (spot the fakes).

# Generator loss
# J_G = E[log(D(G(z)))] (the non-saturating objective; cf. Section 2)
# NB: only update G's params in training.
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg,
                                                                labels = tf.ones_like(Dg)))
# the above maximizes E[log(D(G(z)))] by pushing D(G(z)) to label=1 (fool the
# discriminator).

8 For full code see https://github.com/suwangcompling/GAN-tutorials/blob/master/Basic GAN (MNIST demo).ipynb.

Figure 2: Left: real image; Center: generated image before training; Right: generated image after training. The convolutional net has two conv-avgpool layers and two fully-connected layers; the deconvolutional net has four layers and does exponential upscaling.
It should be noted that GAN in its vanilla form is notoriously difficult to train and sensitive to hyperparameter and architectural choices. In the next section we briefly discuss the Wasserstein GAN [26,27], a massive improvement on the original GAN that is more amenable to practical applications.
4 Wasserstein GAN

The Wasserstein GAN (WGAN) [26] trains the adversarial pair by minimizing the Earth Mover (EM) distance, a metric chosen to address these issues: it is “... continuous everywhere and differentiable (almost) everywhere.” (Theorem 1.2, [26]). Let p_data, p_g be the distributions of the real data and the generator respectively; the EM distance can be understood as the minimal effort required to move probability mass from p_data to transform it into p_g. It is defined as follows:
W(p_data, p_g) = inf_{γ ∈ Π(p_data, p_g)} E_{(x,y)∼γ}[‖x − y‖]    (9)

where Π(p_data, p_g) is the set of all joint distributions γ whose marginals are p_data and p_g. Because the EM distance is intractable to compute exactly, Arjovsky et al. propose to approximate it by applying the Kantorovich-Rubinstein duality, which says W is equivalent to the following:

W(p_data, p_g) = sup_{‖f‖_L ≤ 1} E_{x∼p_data}[f(x)] − E_{x∼p_g}[f(x)]    (10)
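For intuition, the EM distance between two one-dimensional empirical distributions can be computed exactly; a quick check with SciPy (not part of the tutorial’s notebooks):

import numpy as np
from scipy.stats import wasserstein_distance

a = np.random.normal(0., 1., 10000)  # samples from N(0, 1)
b = np.random.normal(3., 1., 10000)  # samples from N(3, 1)
print(wasserstein_distance(a, b))    # ~3.0: the mass travels ~3 units on average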
where the supremum is taken over all 1-Lipschitz functions, and f is a general notation for a function which, in the context of GAN, denotes the discriminator D (here also known as the critic). Sidestepping the sophisticated mathematical ideas involved, one only needs to understand that the K-Lipschitz condition guarantees that the EM distance is continuous and differentiable everywhere under only minor assumptions (cf. Theorem 1, [26]). With this approximation, the gradient is now easily calculated:

∇_θ W(p_data, p_g) = ∇_θ {E_{x∼p_data}[D_w(x)] − E_{z∼p_z}[D_w(G_θ(z))]}
                   = −E_{z∼p_z}[∇_θ D_w(G_θ(z))]
where θ are the parameters of the generator G and z ∼ p_z is the random noise. Finally, Arjovsky et al. also propose a weight-clipping technique which guarantees that the family of functions the discriminator can take on, i.e. D_w, w ∈ W, is K-Lipschitz: the weights are constrained to lie within [−c, c], where c is some (usually small) parameter (e.g. 0.01 is a popular default). The weight clipping takes place after each weight update.
Despite the complexity of WGAN’s mathematical argument, implementation-wise9 it is quite simple: we replace the loss calculation in the code of the previous section with the following (both losses written to be minimized; Dx and Dg are now the critic’s unbounded outputs):

# Approximating the Earth Mover (EM) distance
d_loss = tf.reduce_mean(Dg) - tf.reduce_mean(Dx)  # critic: maximize E[D(x)] - E[D(G(z))]
g_loss = -tf.reduce_mean(Dg)                      # generator: maximize E[D(G(z))]
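The clipping step then follows each critic update. A minimal sketch, with d_vars again an illustrative name for the critic’s variable list:

c = 0.01  # clipping parameter
clip_d = [w.assign(tf.clip_by_value(w, -c, c)) for w in d_vars]
# run after every critic update, e.g. sess.run(clip_d)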
While theoretically sound and generally superior to the vanilla GAN, WGAN as presented above still exhibits undesirable behavior in practice. In particular, it is prone to (i) exploding or vanishing gradients, and (ii) failing to match higher-order moments when pulling p_data and p_g close. As a remedy which is very robust in practice, Gulrajani et al. [27] suggest a gradient-penalty method to replace the weight clipping (to which they attribute the behaviors above). The following penalty term is added to the loss function:

λ E_{x̂∼p_x̂}[(‖∇_x̂ D(x̂)‖_2 − 1)^2]    (11)

where x̂ = ε · x + (1 − ε) · G(z), with ε ∼ U[0, 1], is an interpolation of the true data and the generated data. We will not explore the details of this method; interested readers are encouraged to read the original paper, which is clearly written [27]. The gradient penalty is implemented as follows (added to the WGAN implementation above, with the weight clipping removed)10:
# Gradient penalty
# Gulrajani et al. (2017), Algorithm 1.
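# A minimal sketch, assuming flattened image batches real_images, fake_images of
# shape [batch_size, 784] and a discriminator() builder as in the basic demo
# (names are illustrative; variable scoping/reuse is elided for brevity).
epsilon = tf.random_uniform([batch_size, 1], 0., 1.)
x_hat = epsilon * real_images + (1. - epsilon) * fake_images  # interpolation
grads = tf.gradients(discriminator(x_hat), [x_hat])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1))     # per-sample ||grad||_2
gradient_penalty = 10.0 * tf.reduce_mean((slopes - 1.) ** 2)  # lambda = 10.0
d_loss += gradient_penalty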
9 https://github.com/suwangcompling/GAN-tutorials/blob/master/Wasserstein GAN (original weight-clipping, MNIST demo).ipynb.
10 https://github.com/suwangcompling/GAN-tutorials/blob/master/Wasserstein GAN (gradient-penalty, MNIST demo).ipynb.
Figure 3: Adversarial Document Model: a variant of Energy-based GAN
Finally, two training tips: (i) the scaling factor λ works well empirically across a wide range of tasks when set to 10.0, as per the recommendation in the paper; (ii) the results are generally better if we replace batch norm with layer norm in the discriminator/critic.
5 Applications in NLP
As briefly mentioned in Section 1, recent years have seen applications of GAN in the space of NLP. In
this section, I take Document Modeling [8] and Dialogue Generation [9] as examples to demonstrate
the potential of GAN as a flexible and versatile learning framework.
Document modeling. Glover (2016) [8] proposes a variant of the Energy-based GAN [13] (ADM)11 where the generator is a regular MLP, while the discriminator is instead a Denoising Autoencoder (DAE) [14] (serving as the energy function13 [15]). The model is pitched against DocNADE [17], a well-tuned system based on the Restricted Boltzmann Machine (RBM) [16], and against a baseline (a stand-alone discriminative DAE classifier). The architecture is illustrated graphically in Figure 3.
In ADM, a document is modeled as a binary bag-of-words vector x ∈ {0, 1}^V, where V is the size of the vocabulary. D (the DAE discriminator) takes two inputs: (i) a “real” document vector x, and (ii) a fake document vector G(z) generated by G (an MLP). An input vector is first passed through a corruption process12 C to obtain a corrupted vector x_c, which is then fed to a regular encode-decode component [14] (Enc, Dec):

h = f(W^e x_c + b^e)    (12)
y = W^d h + b^d    (13)

where (W^e, b^e), (W^d, b^d) are the parameters of the encoder and the decoder, respectively. The quality of the input is evaluated by its reconstruction error, measured with the Mean Squared Error (MSE):

MSE = (1/V) Σ_{i=1}^{V} (x_i − y_i)^2    (14)
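As a toy illustration of Eqs. 12–14, the following NumPy sketch runs one DAE pass (tanh and random masking are illustrative choices here, not necessarily those of [8]):

import numpy as np

V, H = 2000, 100                          # vocabulary size, hidden size
rng = np.random.default_rng(0)
x = (rng.random(V) < .05).astype(float)   # binary bag-of-words document
x_c = x * (rng.random(V) > .2)            # corruption C: random masking
W_e, b_e = rng.normal(0, .01, (H, V)), np.zeros(H)
W_d, b_d = rng.normal(0, .01, (V, H)), np.zeros(V)
h = np.tanh(W_e @ x_c + b_e)              # encoder (Eq. 12), with f = tanh
y = W_d @ h + b_d                         # decoder (Eq. 13)
print(np.mean((x - y) ** 2))              # reconstruction error (Eq. 14)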
11 The model is presented in [8] under the name Adversarial Document Model.
12 Vincent et al. (2010) [14] show that formulating an autoencoder reconstruction task as a denoising task (with corrupted input) helps the autoencoder generalize better.
Figure 4: ADM results
The process amounts to the same effect as the Minmax Game (Eq. 3), in that D increases the score it assigns to the real data (i.e. x) while lowering the score of the fake data (i.e. G(z)) generated by G, and G strives to increase the score of the fake data.
In his experiments, Glover [8] formulates a document retrieval task on the 20 Newsgroups dataset [18], where the model takes a query document d as input and produces the set of output documents that are closest14 to d. The results are shown in Figure 4 as a precision-recall curve. While producing overall weaker performance than the state-of-the-art DocNADE, ADM demonstrates its power by defeating the strong DAE baseline by a respectable margin, showing its promise15 under more sophisticated formulations.
Dialogue generation. Li et al. (2017) [9] present an interesting GAN-based dialogue generation system. They formulate a reinforcement learning [20] based Turing Test where the goal is for the generator G, under the guidance of a discriminator D, to learn to generate responses to input sentences that are indistinguishable from responses given by humans. D guides G by giving it feedback in the form of reinforcement rewards: a positive reward for realistic responses, and a negative reward for unrealistic ones.

Concretely, G takes the form of a sequence-to-sequence Recurrent Neural Network (RNN) [21] which generates a response y based on a dialogue history x. D is an autoencoder-based binary classifier [22] that takes a (dialogue history, response) pair {x, y} and produces a label indicating whether the input was generated by a human (i.e. real data) or by G (i.e. fake data). It further returns a reward score to guide G: Q+({x, y}) and Q−({x, y}) for positive and negative rewards, respectively.
Deferring the details of the model to the original work [9], I only show some sample outputs from the two variants16 evaluated therein to demonstrate the quality of the generated responses (Figure 5). Here, the input is a dialogue probe, and the following lines are the respective responses the models give to the probe. The two GANs produce apparently superior responses by human judgment. In addition to generating good responses, [9] also show the reliability of their model (see Table 3 in the original work).
13 One may consider an energy function as a family of loss functions [15].
14 By cosine similarity.
15 In image generation tasks, while GAN sometimes falls short of the VAE [19] on classification tasks [2], it often generates more realistic images by human judgment. This leads us to reasonably believe in its potential for document-based tasks.
16 REGS improves on REINFORCE by ameliorating its tendency toward mode collapse, i.e. for G to generate the same fake data over and over again [2].
Figure 5: Dialogue Generation (REINFORCE and REGS are GANs)
6 Conclusion
In this tutorial, I started by giving an in-depth description of GAN (Section 1) and step-by-step derivations of the related proofs (Section 2). In Sections 3 and 4, I presented simple implementations of the original GAN and the Wasserstein GAN. Finally, I demonstrated the flexibility of GAN as a novel neural net architecture with examples of its applications in NLP (Section 5).
References
[1] Goodfellow, I.J., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A. & Bengio,
Y. (2014) Generative Adversarial Nets. NIPS 2014.
[2] Goodfellow, I.J. (2016) NIPS 2016 Tutorial: Generative Adversarial Networks. NIPS 2016. CoRR,
arXiv:1701.00160.
[3] Jang, E. (2015) Generative Adversarial Nets in Tensorflow. Eric Jang’s Blog: blog.evjang.com/2016/06/generative-adversarial-nets-in.html.
[4] Lotter, W., Kreiman, G. & Cox, D. (2015) Unsupervised Learning of Visual Structure Using Predictive Generative Networks. CoRR, arXiv:1511.06380.
[5] Ledig, C., Theis, L., Huszar, F., Caballero, J., Aitken, A. P., Tejani, A., Totz, J., Wang, Z., & Shi,
W. (2016) Photo-realistic single Image Super-resolution Using a Generative Adversarial Network. CoRR,
arXiv:1609.04802.
[6] Isola, P., Zhu, J.-Y., Zhou, T., & Efros, A. A. (2016). Image-to-image Translation with Conditional
Adversarial Networks. CoRR, arXiv:1611.07004.
[7] Donahue, J., Krähenbühl, P. & Darrell, T. (2017) Adversarial Feature Learning. CoRR, arXiv:1605.09782.
[8] Glover, J. (2016) Modeling Documents with Generative Adversarial Networks. CoRR, arXiv:1612.09122.
[9] Li, J., Monroe, W., Shi, T., Jean, S., Ritter, A. & Jurafsky, D. (2017) Adversarial Learning for Neural
Dialogue Generation. CoRR, arXiv:1701.06547.
[10] Chen X., Athiwaratkun, B., Sun, Y., Weinberger, K. & Cardie, C. (2016) Adversarial Deep Averaging
Networks for Cross-lingual Sentiment Classification. CoRR, arXiv:1606.01614.
[11] Zhang, Y., Barzilay, R. & Jaakkola, T. (2017) Aspect-augmented Adversarial Networks for Domain
Adaptation. CoRR, arXiv:1701.00188.
[12] Ng, A. & Jordan, M.I. (2002) On Discriminative vs. Generative Classifiers: A Comparison of Logistic
Regression and Naïve Bayes. NIPS 2002.
[13] Zhao, J., Mathieu, M. & LeCun Y. (2016) Energy-based Generative Adversarial Network. CoRR,
arXiv:1609.03126.
[14] Vincent, P., Larochelle, H., Lajoie, I., Bengio, Y. & Manzagol, P-A. (2010) Stacked Denoising Autoencoders:
Learning Useful Representations in a Deep Network with a Local Denoising Criterion. JMLR 2010.
[15] LeCun, Y., Chopra, S., Hadsell, R., Ranzato, M-A. & Huang, F.J. (2006) A Tutorial on Energy-based
Learning. In Bakir, G., Hofman, T., Schölkopf, B., Smola, A. & Taskar, B. (eds.) Predicting Structured Data.
MIT Press.
[16] Hinton, G.E. (2002) Training Products of Experts by Minimizing Contrastive Divergence. Neural Computation, vol. 14, no. 8, pp. 1771–1800.
[17] Larochelle, H. & Lauly, S. (2012) A Neural Autoregressive Topic Model. NIPS 2012.
[18] Lang, K. (1995) Newsweeder: Learning to Filter News. ICML 1995.
[19] Kingma, D.P. & Welling, M. (2014) Auto-encoding Variational Bayes. ICLR 2014.
[20] Williams, R.J. (1992) Simple Statistical Gradient-following Algorithms for Connectionist Reinforcement
Learning. Machine Learning, vol. 8 (3-4): 229–256.
[21] Sutskever, I., Vinyals, O. & Le, Q.V. (2014) Sequence to Sequence Learning with Neural Networks. NIPS 2014.
[22] Li, J., Luong, M-T. & Jurafsky, D. (2015) A Hierarchical Neural Autoencoder for Paragraphs and
Documents. CoRR, arXiv:1506.01057.
[23] Gutmann, M. & Hyvärinen, A. (2010) Noise-contrastive Estimation: A New Estimation Principle for Unnormalized Statistical Models. In Proceedings of AISTATS. Sardinia, Italy.
[24] Zeiler, M.D., Krishnan, D., Taylor, G.W., Fergus, R. (2010) Deconvolutional Networks. In Proceedings of
CVPR.
[25] Radford, A., Metz, L., Chintala, S. (2016) Unsupervised Representation Learning with Deep Convolutional
Generative Adversarial Networks. In Proceedings of ICLR.
[26] Arjovsky, M., Chintala, S. & Bottou, L. (2017) Wasserstein GAN. CoRR, arXiv:1701.07875.
[27] Gulrajani, I., Ahmed, F., Arjovsky, M., Dumoulin, V., Courville, A. (2017) Improved Training of Wasserstein
GANs. In Proceedings of NIPS.