A Statistical Machine Learning Perspective of Deep Learning: Algorithm, Theory, Scalable Computing Maruan Al-Shedivat, Zhiting Hu, Hao Zhang, and Eric Xing Petuum Inc & Carnegie Mellon University
Elements of AI/Machine Learning
[Figure: the AI/ML stack, from the task down to hardware]
• Task
• Model: Graphical Models, Large-Margin, Deep Learning, Sparse Coding, Nonparametric Bayesian Models, Regularized Bayesian Methods, Spectral/Matrix Methods, Sparse Structured I/O Regression, …
• Algorithm: Stochastic Gradient Descent / Back propagation, Coordinate Descent, L-BFGS, Gibbs Sampling, Metropolis-Hastings, …
• Implementation: Mahout (MapReduce), Mllib (BSP), CNTK, MxNet, Tensorflow (Async), …
• System: Hadoop, Spark, MPI, RPC, GraphLab, …
• Platform and Hardware: Network switches, Infiniband, Network attached storage, Flash storage, Server machines, Desktops/Laptops, NUMA machines, Mobile devices, GPUs, CPUs, FPGA, TPU, ARM-powered devices, RAM, Flash, SSD, Cloud compute (e.g. Amazon EC2), Virtual machines, IoT networks, Data centers
© Petuum,Inc.
ML vs DL
Plan
• Statistical and Algorithmic Foundation and Insight of Deep Learning
• On a Unified Framework of Deep Generative Models
• Computational Mechanisms: Distributed Deep Learning Architectures
Part-I Basics
Outline
• Probabilistic Graphical Models: Basics
• An overview of DL components
  • Historical remarks: early days of neural networks
  • Modern building blocks: units, layers, activation functions, loss functions, etc.
  • Reverse-mode automatic differentiation (aka backpropagation)
• Similarities and differences between GMs and NNs
  • Graphical models vs. computational graphs
  • Sigmoid Belief Networks as graphical models
  • Deep Belief Networks and Boltzmann Machines
• Combining DL methods and GMs
  • Using outputs of NNs as inputs to GMs
  • GMs with potential functions represented by NNs
  • NNs with structured outputs
• Bayesian Learning of NNs
  • Bayesian learning of NN parameters
  • Deep kernel learning
Fundamental questions of probabilistic modeling
• Representation: what is the joint probability distribution on multiple variables, $P(X_1, X_2, X_3, \dots, X_n)$?
  • How many state configurations are there?
  • Do they all need to be represented?
  • Can we incorporate any domain-specific insights into the representation?
• Learning: where do we get the probabilities from?
  • Maximum likelihood estimation? How much data do we need?
  • Are there any other established principles?
• Inference: if not all variables are observable, how do we compute the conditional distribution of latent variables given evidence?
  • Computing $P(H \mid A)$ would require summing over $2^6$ configurations of the unobserved variables
What is a graphical model? • A possible world of cellular signal transduction
GM: structure simplifies representation • A possible world of cellular signal transduction
Probabilistic Graphical Models
• If the $X_i$'s are conditionally independent (as described by a PGM), then the joint can be factored into a product of simpler terms:
$$P(X_1, X_2, X_3, X_4, X_5, X_6, X_7, X_8) = P(X_1)\,P(X_2)\,P(X_3 \mid X_1)\,P(X_4 \mid X_2)\,P(X_5 \mid X_2)\,P(X_6 \mid X_3, X_4)\,P(X_7 \mid X_6)\,P(X_8 \mid X_5, X_6)$$
• Why might we favor a PGM?
  • Easy to incorporate domain knowledge and causal (logical) structures
  • Significant reduction in representation cost (for binary variables, from $2^8$ joint configurations down to 18 parameters)
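To make the cost reduction concrete, the sketch below counts parameters for binary variables, assuming the parent sets of the factorization above (the dictionary `PARENTS` is our own encoding of that factorization). Each conditional $P(X_i \mid \pi(X_i))$ needs one probability per configuration of its parents, while an unstructured table needs one entry per joint configuration.

```python
# Parent sets read off the factorization
# P(X1)P(X2)P(X3|X1)P(X4|X2)P(X5|X2)P(X6|X3,X4)P(X7|X6)P(X8|X5,X6).
PARENTS = {
    1: [], 2: [], 3: [1], 4: [2], 5: [2],
    6: [3, 4], 7: [6], 8: [5, 6],
}

def factored_param_count(parents):
    """A binary CPD P(Xi | pa(Xi)) needs one number per parent configuration."""
    return sum(2 ** len(pa) for pa in parents.values())

def full_joint_state_count(n_vars):
    """A full joint table over n binary variables has 2**n entries."""
    return 2 ** n_vars

print(full_joint_state_count(len(PARENTS)))  # 256
print(factored_param_count(PARENTS))         # 18
```

The count grows with the size of the largest parent set rather than with the total number of variables, which is where the savings come from.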
The two types of GMs
[Figure annotations: $P(H \mid D)$; $\theta = \arg\max_\theta P_\theta(D)$]
• Directed edges assign causal meaning to the relationships (Bayesian Networks or Directed Graphical Models):
$$P(X_1, X_2, X_3, X_4, X_5, X_6, X_7, X_8) = P(X_1)\,P(X_2)\,P(X_3 \mid X_1)\,P(X_4 \mid X_2)\,P(X_5 \mid X_2)\,P(X_6 \mid X_3, X_4)\,P(X_7 \mid X_6)\,P(X_8 \mid X_5, X_6)$$
• Undirected edges represent correlations between the variables (Markov Random Field or Undirected Graphical Models):
$$P(X_1, X_2, X_3, X_4, X_5, X_6, X_7, X_8) = \frac{1}{Z} \exp\big\{\phi(X_1) + \phi(X_2) + \phi(X_1, X_3) + \phi(X_2, X_4) + \phi(X_5, X_2) + \phi(X_3, X_4, X_6) + \phi(X_6, X_7) + \phi(X_5, X_6, X_8)\big\}$$
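A minimal sketch of an undirected model, assuming a hypothetical three-variable chain $X_1 - X_2 - X_3$ with made-up log-potentials rather than the eight-variable graph above. The joint is the normalized exponential of a sum of local potentials; both the partition function $Z$ and a conditional such as $P(X_1 = 1 \mid X_3 = 1)$ are computed here by brute-force enumeration, which is exactly the exponential-cost summation that structured inference algorithms try to avoid.

```python
import itertools
import math

# Hypothetical log-potentials for binary X1, X2, X3 with edges (X1, X2) and (X2, X3).
def unary(x):
    return 0.5 if x == 1 else 0.0

def pairwise(a, b):
    return 1.0 if a == b else -1.0   # favour agreeing neighbours

def score(x):
    """Sum of log-potentials, i.e. the exponent of the Gibbs distribution."""
    x1, x2, x3 = x
    return unary(x1) + unary(x2) + unary(x3) + pairwise(x1, x2) + pairwise(x2, x3)

STATES = list(itertools.product([0, 1], repeat=3))
Z = sum(math.exp(score(x)) for x in STATES)   # partition function by enumeration

def prob(x):
    return math.exp(score(x)) / Z

# Inference by enumeration: P(X1 = 1 | X3 = 1), summing out the unobserved X2.
num = sum(prob(x) for x in STATES if x[0] == 1 and x[2] == 1)
den = sum(prob(x) for x in STATES if x[2] == 1)
p_x1_given_x3 = num / den
```

With $n$ binary variables the sums run over $2^n$ states, so this only works for toy models.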
Perceptron and Neural Nets
• From biological neuron to artificial neuron (perceptron), McCulloch & Pitts (1943)
[Figure: inputs $x_1, x_2$ with weights $w_1, w_2$ feed a linear combiner $\sum$; a hard limiter with threshold $\theta$ produces the output $Y$.]
• From biological neuron networks to artificial neural networks
[Figure: biological neurons (soma, dendrites, axon, synapses) next to an artificial network with an input layer, a middle layer, and an output layer; input signals enter on one side and output signals leave on the other.]
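The artificial neuron in the figure can be sketched as a linear combiner followed by a hard limiter. We assume the unit fires when the weighted sum reaches the threshold ($\geq \theta$) and outputs 0 otherwise; the weights and threshold below are hypothetical values chosen so that the unit computes logical AND.

```python
def perceptron(x, w, theta):
    """Linear combiner followed by a hard limiter.

    Returns 1 if sum_i w_i * x_i >= theta, else 0 (some texts use -1 instead of 0).
    """
    activation = sum(wi * xi for wi, xi in zip(w, x))   # linear combiner
    return 1 if activation >= theta else 0               # hard limiter

# Hypothetical weights and threshold that make the unit compute logical AND.
W, THETA = [1.0, 1.0], 1.5
for x in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    print(x, perceptron(x, W, THETA))
```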
The perceptron learning algorithm
• Recall the nice property of the sigmoid function: $\sigma'(z) = \sigma(z)\,(1 - \sigma(z))$
• Consider the regression problem $f: X \to Y$ for scalar $Y$
• We used to maximize the conditional data likelihood
• Here …
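The perceptron learning algorithm
Notation: $x_d$ = input, $t_d$ = target output, $o_d$ = observed output, $w_i$ = weight $i$
Incremental mode:
Do until converge:
  For each training example $d$ in $D$:
    1. Compute the gradient $\nabla E_d[w]$
    2. $w \leftarrow w - \eta \nabla E_d[w]$
Batch mode:
Do until converge:
  1. Compute the gradient $\nabla E_D[w]$
  2. $w \leftarrow w - \eta \nabla E_D[w]$
where $E_d[w] = \frac{1}{2}(t_d - o_d)^2$ and $E_D[w] = \sum_{d \in D} E_d[w]$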
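The two modes can be sketched for a single sigmoid unit, assuming the squared error $E_d[w] = \frac{1}{2}(t_d - o_d)^2$, a constant input of 1 acting as a bias, and a hypothetical learning rate. "Do until converge" is replaced by a fixed number of epochs for simplicity, and the toy OR data set is our own choice.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def output(w, x):
    """o_d: output of a single sigmoid unit."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)))

def grad_example(w, x, t):
    """Gradient of E_d[w] = 0.5 * (t - o)^2 for one example, using
    sigma'(z) = sigma(z) * (1 - sigma(z))."""
    o = output(w, x)
    return [-(t - o) * o * (1 - o) * xi for xi in x]

def train_incremental(data, w, lr=0.5, epochs=2000):
    """Incremental mode: update the weights after every training example d."""
    for _ in range(epochs):
        for x, t in data:
            g = grad_example(w, x, t)
            w = [wi - lr * gi for wi, gi in zip(w, g)]
    return w

def train_batch(data, w, lr=0.5, epochs=2000):
    """Batch mode: one update per pass, using the gradient of E_D = sum_d E_d."""
    for _ in range(epochs):
        g = [0.0] * len(w)
        for x, t in data:
            g = [gi + gdi for gi, gdi in zip(g, grad_example(w, x, t))]
        w = [wi - lr * gi for wi, gi in zip(w, g)]
    return w

# Toy data (logical OR); the last input is a constant 1 acting as a bias.
DATA = [([0, 0, 1], 0), ([0, 1, 1], 1), ([1, 0, 1], 1), ([1, 1, 1], 1)]
w_inc = train_incremental(DATA, [0.0, 0.0, 0.0])
w_bat = train_batch(DATA, [0.0, 0.0, 0.0])
```

The only difference between the two functions is where the update sits: inside the loop over examples (incremental) or after the whole pass (batch).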
Neural Network Model
[Figure: inputs Age = 34, Gender = 2, Stage = 4 (independent variables) connect through a set of weights to a hidden layer of $\Sigma$ units, which connect through a second set of weights to the output 0.6, the "Probability of being Alive" (dependent variable prediction).]
“Combined logistic models”
[Figure: the same network (inputs Age = 34, Gender = 2, Stage = 4; a hidden layer; output 0.6, the "Probability of being Alive") viewed as a combination of logistic models, shown one subset of weights at a time.]
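Viewing the network as combined logistic models, a forward pass can be sketched as follows: each hidden unit is a logistic regression on the inputs, and the output unit is a logistic regression on the hidden units. The parameters are hypothetical (the figure's weight values are not recoverable), and the inputs are assumed to be pre-scaled.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def logistic_unit(weights, bias, inputs):
    """One 'logistic model': sigmoid of a weighted sum of its inputs."""
    return sigmoid(bias + sum(w * x for w, x in zip(weights, inputs)))

def forward(x, hidden_params, output_params):
    """Hidden logistic models feed their outputs into an output logistic model."""
    h = [logistic_unit(w, b, x) for w, b in hidden_params]
    w_out, b_out = output_params
    return logistic_unit(w_out, b_out, h)

# Hypothetical parameters (NOT the values in the figure); inputs are
# (age, gender, stage) after an assumed rescaling.
HIDDEN = [([0.6, 0.1, -0.2], 0.0), ([-0.3, 0.4, 0.7], -0.5)]
OUTPUT = ([0.5, -0.8], 0.2)
p_alive = forward([0.34, 1.0, 0.4], HIDDEN, OUTPUT)
```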
Not really: there is no target for the hidden units...
[Figure: the same network (inputs Age = 34, Gender = 2, Stage = 4; a hidden layer; output 0.6, the "Probability of being Alive"); only the output has a target value.]
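Backpropagation: Reverse-mode differentiation
• Artificial neural networks are nothing more than complex functional compositions that can be represented by computation graphs:
[Figure: a computation graph from the input variables $x$ through intermediate computations (nodes 1 to 5) to the outputs $f(x)$.]
$$\frac{\partial f_n}{\partial x} = \sum_{i_1 \in \pi(n)} \frac{\partial f_n}{\partial f_{i_1}} \frac{\partial f_{i_1}}{\partial x}$$
where $\pi(n)$ denotes the parents of node $n$ in the graph.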
• By applying the chain rule and using reverse accumulation, we get
$$\frac{\partial f_n}{\partial x} = \sum_{i_1 \in \pi(n)} \frac{\partial f_n}{\partial f_{i_1}} \frac{\partial f_{i_1}}{\partial x} = \sum_{i_1 \in \pi(n)} \frac{\partial f_n}{\partial f_{i_1}} \sum_{i_2 \in \pi(i_1)} \frac{\partial f_{i_1}}{\partial f_{i_2}} \frac{\partial f_{i_2}}{\partial x} = \dots$$
• The algorithm is commonly known as backpropagation
• What if some of the functions are stochastic?
  • Then use stochastic backpropagation! (to be covered in the next part)
• Modern packages can do this automatically (more later)
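Reverse accumulation can be sketched as a tiny scalar autodiff engine. This is an illustrative toy (the `Var` class and `backward` function are our own, not any package's API): each node stores its local derivatives with respect to its parents, and `backward` visits nodes in reverse topological order, so each node's gradient is fully accumulated from its children before being pushed back to its parents, as in the chain-rule expansion above.

```python
import math

class Var:
    """Scalar node in a computation graph supporting reverse-mode differentiation."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents   # sequence of (parent_node, d self / d parent)
        self.grad = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value, [(self, other.value), (other, self.value)])

    def sin(self):
        return Var(math.sin(self.value), [(self, math.cos(self.value))])

def backward(output):
    """Accumulate d output / d node for every node, from the output back to the inputs."""
    order, seen = [], set()

    def visit(node):                      # topological sort (parents before children)
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(output)
    output.grad = 1.0
    for node in reversed(order):          # children are finished before their parents
        for parent, local in node.parents:
            parent.grad += node.grad * local

# f(x) = x * sin(x) + x, so f'(x) = sin(x) + x * cos(x) + 1
x = Var(2.0)
f = x * x.sin() + x
backward(f)
```

One backward sweep yields the derivative of the output with respect to every input, which is why reverse mode suits networks with many parameters and a scalar loss.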
Modern building blocks of deep networks
• Activation functions
  • Linear and ReLU
  • Sigmoid and tanh
  • Etc.
[Figure: a unit with inputs $x_1, x_2, x_3$ and weights $w_1, w_2, w_3$ computing $f(Wx + b)$; plots of output vs. input for the linear and rectified linear (ReLU) activations.]
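The activation functions listed above can be sketched directly; the unit below computes $f(w \cdot x + b)$ for a single output (one row of $W$), with hypothetical inputs and weights.

```python
import math

def linear(z):
    return z

def relu(z):
    return max(0.0, z)

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def tanh(z):
    return math.tanh(z)

def unit(x, w, b, f):
    """A single unit computing f(w . x + b)."""
    return f(sum(wi * xi for wi, xi in zip(w, x)) + b)

# Hypothetical input, weights and bias; the pre-activation here is -0.05.
x, w, b = [1.0, -2.0, 0.5], [0.3, 0.1, -0.4], 0.05
for f in (linear, relu, sigmoid, tanh):
    print(f.__name__, unit(x, w, b, f))
```

Note that tanh is a rescaled sigmoid, $\tanh(z) = 2\sigma(2z) - 1$, while ReLU simply zeroes negative pre-activations.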
• Layers
  • Fully connected
  • Convolutional & pooling
  • Recurrent
  • ResNets
  • Etc.
[Figure: fully connected, convolutional, and recurrent layers (source: colah.github.io); blocks with residual connections.]
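The layer types can be sketched on plain Python lists. These are minimal, single-channel toy versions written for illustration (real libraries operate on batched tensors); the convolution follows the cross-correlation convention that most deep learning libraries use.

```python
import math

def dense(x, W, b):
    """Fully connected layer: every output depends on every input (y = Wx + b)."""
    return [sum(wij * xj for wij, xj in zip(row, x)) + bi for row, bi in zip(W, b)]

def conv1d(x, kernel):
    """'Valid' 1-D convolution implemented as cross-correlation: one small kernel
    slides over the input, so the same weights are shared across positions."""
    k = len(kernel)
    return [sum(kernel[j] * x[i + j] for j in range(k)) for i in range(len(x) - k + 1)]

def max_pool(x, size):
    """Non-overlapping max pooling over windows of the given size."""
    return [max(x[i:i + size]) for i in range(0, len(x) - size + 1, size)]

def relu(v):
    return [max(0.0, a) for a in v]

def residual_block(x, W, b):
    """Residual (skip) connection: output = x + F(x), with F a dense layer plus ReLU.
    W must be square so that F(x) has the same size as x."""
    return [xi + fi for xi, fi in zip(x, relu(dense(x, W, b)))]

def rnn(xs, w_h, w_x, h0=0.0):
    """Scalar recurrent layer h_t = tanh(w_h * h_{t-1} + w_x * x_t):
    the same weights are reused at every time step."""
    h, hs = h0, []
    for x in xs:
        h = math.tanh(w_h * h + w_x * x)
        hs.append(h)
    return hs
```

The residual block makes the identity map easy to represent (set $F(x) = 0$), which is one common explanation for why very deep ResNets remain trainable.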
• Loss functions
  • Cross-entropy loss
  • Mean squared error
  • Etc.
Putting things together:
[Figure: a part of GoogleNet built from fully connected, convolutional, and avg & max pooling layers, activations, concatenation, and a loss.]
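The two losses named above can be sketched as follows; the cross-entropy takes raw scores (logits) and uses the log-sum-exp trick for numerical stability.

```python
import math

def softmax(logits):
    m = max(logits)                      # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def cross_entropy(logits, label):
    """Negative log-probability of the true class under a softmax output."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[label]

def mean_squared_error(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
```

With multiple targets, the training objective is typically a weighted sum of such terms, e.g. `cross_entropy(logits, label) + 0.5 * mean_squared_error(pred, target)` (the 0.5 weight is hypothetical).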
• Arbitrary combinations of the basic building blocks
• Multiple loss functions: multi-target prediction, transfer learning, and more
• Given enough data, deeper architectures tend to keep improving
• Representation learning: the networks learn increasingly more abstract representations of the data that are “disentangled,” i.e., amenable to linear separation.
Graphical models vs. Deep nets
• Graphical models: a representation for encoding meaningful knowledge and the associated uncertainty in a graphical form
• Deep neural networks: learn representations that facilitate computation and performance on the end-metric (intermediate representations are not guaranteed to be meaningful)
• Graphical models: learning and inference are based on a rich toolbox of well-studied (structure-dependent) techniques (e.g., EM, message passing, VI, MCMC, etc.)
• Deep neural networks: learning is predominantly based on gradient descent, with gradients computed by backpropagation; inference is often trivial and done via a “forward pass”
• Graphical models: graphs represent models
• Deep neural networks: graphs represent computation
Graphical models vs. Deep nets
Graphical models: utility of the graph
[Figure: an undirected graphical model over $X_1, \dots, X_5$ with $\log P(X) = \sum_i \log \psi(x_i) + \sum_{i,j} \log \psi(x_i, x_j) - \log Z$, annotated with $\mathbb{E}_{H \sim P(H \mid D)}$]
• A vehicle for synthesizing a global loss function from local structure
  • potential function, feature function, etc.
• A vehicle for designing sound and efficient inference algorithms
  • Sum-product, mean-field, etc.
• A vehicle to inspire approximation and penalization
  • Structured MF, Tree-approximation, etc.
• A vehicle for monitoring theoretical and empirical behavior and accuracy of inference
Graphical models: utility of the loss function
• A major measure of quality of the learning algorithm and the model: $\theta = \arg\max_\theta P_\theta(D)$
Graphical models vs. Deep nets
Deep neural networks: utility of the network
• A vehicle to conceptually synthesize complex decision hypotheses
  • stage-wise projection and aggregation
• A vehicle for organizing computational operations
  • stage-wise update of latent states
• A vehicle for designing processing steps and computing modules
  • Layer-wise parallelization
• No obvious utility in evaluating DL inference algorithms
Deep neural networks: utility of the loss function
• Global loss? Well, it is complex and nonconvex...
Images from Distill.pub
Graphical models vs. Deep nets: DL <=?> ML (e.g., GM)
• Empirical goal: DL: e.g., classification, feature learning; ML: e.g., latent variable inference, transfer learning
• Structure: DL: graphical; ML: graphical
• Objective: DL: something aggregated from local functions; ML: something aggregated from local functions
• Vocabulary: DL: neuron, activation function, …; ML: variable, potential function, …
• Algorithm: DL: a single, unchallenged inference algorithm, backpropagation (BP); ML: a major focus of open research, many algorithms, and more to come
• Evaluation: DL: on a black-box score (end performance); ML: on almost every intermediate quantity
• Implementation: DL: many tricks; ML: more or less standardized
• Experiments: DL: massive, real data (ground truth unknown); ML: modest, often simulated data (ground truth known)
Graphical Models vs. Deep Nets
• So far:
  • Graphical models are representations of probability distributions
  • Neural networks are function approximators (with no probabilistic meaning)
• Some of the neural nets are in fact proper graphical models (i.e., units/neurons represent random variables):
  • Boltzmann machines (Hinton & Sejnowski, 1983)
  • Restricted Boltzmann machines (Smolensky, 1986)
  • Learning and inference in sigmoid belief networks (Neal, 1992)
  • Fast learning in deep belief networks (Hinton, Osindero, Teh, 2006)
  • Deep Boltzmann machines (Salakhutdinov and Hinton, 2009)
• Let’s go through these models one by one
I: Restricted Boltzmann Machines
• RBM is a Markov random field represented with a bipartite graph
• All nodes in one layer/part of the graph are connected to all nodes in the other; there are no intra-layer connections
• Joint distribution:
$$P(v, h) = \frac{1}{Z} \exp\Big\{\sum_{i,j} w_{ij} v_i h_j + \sum_i b_i v_i + \sum_j c_j h_j\Big\}$$
Images from Marcus Frean, MLSS Tutorial 2010
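For a tiny RBM the joint can be tabulated exactly, which helps make the formula concrete. A minimal sketch with hypothetical parameters (2 visible and 2 hidden units); the partition function $Z$ has $2^{n_v + n_h}$ terms, so this approach does not scale.

```python
import itertools
import math

def rbm_exponent(v, h, W, b, c):
    """Exponent of the RBM joint: sum_ij w_ij v_i h_j + sum_i b_i v_i + sum_j c_j h_j."""
    pair = sum(W[i][j] * v[i] * h[j] for i in range(len(v)) for j in range(len(h)))
    return pair + sum(bi * vi for bi, vi in zip(b, v)) + sum(cj * hj for cj, hj in zip(c, h))

def rbm_joint_table(W, b, c):
    """Exact P(v, h) for every binary configuration (tiny RBMs only)."""
    configs = [(v, h)
               for v in itertools.product([0, 1], repeat=len(b))
               for h in itertools.product([0, 1], repeat=len(c))]
    unnorm = {cfg: math.exp(rbm_exponent(cfg[0], cfg[1], W, b, c)) for cfg in configs}
    Z = sum(unnorm.values())
    return {cfg: u / Z for cfg, u in unnorm.items()}

# Hypothetical parameters: 2 visible units (rows of W) and 2 hidden units (columns).
W = [[1.0, -0.5], [0.3, 0.8]]
b = [0.1, -0.2]
c = [0.0, 0.4]
P = rbm_joint_table(W, b, c)
```

A useful sanity check on such a table is that $P(h \mid v)$ computed from it equals $\prod_j \sigma(c_j + \sum_i w_{ij} v_i)^{h_j} (1 - \sigma(c_j + \sum_i w_{ij} v_i))^{1 - h_j}$, the bipartite factorization that makes posterior sampling exact.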
I: Restricted Boltzmann Machines
• Log-likelihood of a single data point (unobservables marginalized out):
$$\log P(v) = \log \sum_h \exp\Big\{\sum_{i,j} w_{ij} v_i h_j + \sum_i b_i v_i + \sum_j c_j h_j\Big\} - \log Z$$
• Gradient of the log-likelihood w.r.t. the model parameters:
$$\frac{\partial}{\partial w_{ij}} \log P(v) = \sum_h P(h \mid v)\, v_i h_j - \sum_{v,h} P(v, h)\, v_i h_j$$
• where the first term averages over the posterior and the second over the joint; $v_i h_j$ is the derivative of the exponent with respect to $w_{ij}$.
Images from Marcus Frean, MLSS Tutorial 2010
I: Restricted Boltzmann Machines
• Gradient of the log-likelihood w.r.t. the parameters (alternative form):
$$\frac{\partial}{\partial w_{ij}} \log P(v) = \mathbb{E}_{P(h \mid v)}[v_i h_j] - \mathbb{E}_{P(v, h)}[v_i h_j]$$
• Both expectations can be approximated via sampling
• Sampling from the posterior is exact (the RBM factorizes over $h$ given $v$)
• Sampling from the joint is done via MCMC (e.g., Gibbs sampling)
• In the neural networks literature:
  • computing the first term is called the clamped / wake / positive phase (the network is “awake” since it conditions on the visible variables)
  • computing the second term is called the unclamped / sleep / free / negative phase (the network is “asleep” since it samples the visible variables from the joint; metaphorically, it is “dreaming” the visible inputs)
• Learning is done by optimizing the log-likelihood of the model for the given data via stochastic gradient descent (SGD)
• Estimation of the second term (the negative phase) heavily relies on the mixing properties of the Markov chain
• This often causes slow convergence and requires extra computation
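For a tiny RBM both phases can be computed exactly by enumeration, which makes the structure of the gradient visible; in practice the negative phase is estimated with MCMC instead. A minimal sketch with hypothetical parameters (3 visible and 2 hidden units):

```python
import itertools
import math

def score(v, h, W, b, c):
    """Exponent of the RBM joint distribution."""
    s = sum(W[i][j] * v[i] * h[j] for i in range(len(v)) for j in range(len(h)))
    return s + sum(x * y for x, y in zip(b, v)) + sum(x * y for x, y in zip(c, h))

def configs(n):
    return list(itertools.product([0, 1], repeat=n))

def log_likelihood(v, W, b, c):
    """log P(v) = log sum_h exp(score(v, h)) - log Z, computed exactly."""
    H, V = configs(len(c)), configs(len(b))
    log_num = math.log(sum(math.exp(score(v, h, W, b, c)) for h in H))
    log_Z = math.log(sum(math.exp(score(vv, h, W, b, c)) for vv in V for h in H))
    return log_num - log_Z

def grad_w(v, W, b, c, i, j):
    """d log P(v) / d w_ij = E_{P(h|v)}[v_i h_j] - E_{P(v,h)}[v_i h_j]."""
    H, V = configs(len(c)), configs(len(b))
    # Positive (clamped) phase: average over the posterior P(h | v).
    post = [math.exp(score(v, h, W, b, c)) for h in H]
    positive = sum(p * v[i] * h[j] for p, h in zip(post, H)) / sum(post)
    # Negative (free) phase: average over the model joint P(v, h).
    joint = [(vv, h, math.exp(score(vv, h, W, b, c))) for vv in V for h in H]
    Z = sum(p for _, _, p in joint)
    negative = sum(p * vv[i] * h[j] for vv, h, p in joint) / Z
    return positive - negative

W = [[0.5, -0.3], [0.2, 0.1], [-0.4, 0.6]]   # hypothetical 3 visible x 2 hidden
b, c = [0.0, 0.1, -0.1], [0.2, -0.2]
v = (1, 0, 1)
g = grad_w(v, W, b, c, 0, 1)
```

The positive phase only touches the $2^{n_h}$ hidden states for the observed $v$, whereas the negative phase sums over all $2^{n_v + n_h}$ joint states, which is why it is the part that needs MCMC in realistic models.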
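II: Sigmoid Belief Networks
[Figure from Neal, 1992]
• Sigmoid belief nets are simply Bayesian networks over binary variables with conditional probabilities represented by sigmoid functions:
$$P(X_i \mid \pi(X_i)) = \sigma\Big(X_i \sum_{X_j \in \pi(X_i)} w_{ij} X_j\Big)$$
  where $\pi(X_i)$ denotes the parents of $X_i$ and the binary variables are coded as $\pm 1$, so that $P(X_i = +1 \mid \pi(X_i)) + P(X_i = -1 \mid \pi(X_i)) = 1$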
• Bayesian networks exhibit a phenomenon called the “explaining away” effect: if A and B are independent causes of a common effect C, then once C is observed, evidence that A occurred lowers the probability that B occurred. A and B become correlated given C.
Note: Due to the “explaining away” effect, when we condition on the visible layer in belief networks, the hidden variables all become dependent.
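The effect can be reproduced numerically on a v-structure A → C ← B. All probabilities below are hypothetical (a noisy-OR style table where C is likely if either cause is present); queries are answered by enumeration.

```python
import itertools

# Hypothetical model: independent causes A and B, common effect C.
P_A, P_B = 0.1, 0.1

def p_c_given(a, b):
    """P(C = 1 | A = a, B = b): C is likely if either cause is present."""
    return 0.9 if (a or b) else 0.01

def joint(a, b, c):
    pa = P_A if a else 1 - P_A
    pb = P_B if b else 1 - P_B
    pc = p_c_given(a, b) if c else 1 - p_c_given(a, b)
    return pa * pb * pc

def prob_a(evidence):
    """P(A = 1 | evidence) by enumeration; evidence maps 'b' and/or 'c' to 0/1."""
    num = den = 0.0
    for a, b, c in itertools.product([0, 1], repeat=3):
        if any({'b': b, 'c': c}[k] != val for k, val in evidence.items()):
            continue
        den += joint(a, b, c)
        if a:
            num += joint(a, b, c)
    return num / den

prior = prob_a({})                        # 0.1
given_c = prob_a({'c': 1})                # observing the effect raises belief in A
given_c_and_b = prob_a({'c': 1, 'b': 1})  # ...but B "explains away" C
```

With these numbers, observing C raises $P(A = 1)$ from 0.1 to roughly 0.5, and additionally observing B brings it back down to 0.1: A and B, independent a priori, are coupled once C is observed.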
Sigmoid Belief Networks: Learning and Inference
• Neal proposed Monte Carlo methods for learning and inference (Neal, 1992)
[Equations from Neal, 1992, annotated: the log-derivative of the likelihood, approximated with Gibbs sampling; the conditional distributions, derived from the probability of the visibles via marginalization, the Bayes rule plus rearranging sums, and plugging in the actual sigmoid form of the conditional probabilities]
• No negative phase as in RBM!
• Convergence is very slow, especially for large belief nets, due to the intricate “explain-away” effects…
RBMs are infinite belief networks
• Recall the expression for the gradient of the log-likelihood for an RBM:
$$\frac{\partial}{\partial w_{ij}} \log P(v) = \mathbb{E}_{P(h \mid v)}[v_i h_j] - \mathbb{E}_{P(v, h)}[v_i h_j]$$
• To make a gradient update of the model parameters, we need to compute the expectations via sampling.
  • We can sample exactly from the posterior in the first term
  • We run block Gibbs sampling to approximately sample from the joint distribution
[Figure: sampling steps] Images from Marcus Frean, MLSS Tutorial 2010
RBMs are infinite belief networks
• Gibbs sampling: alternate between sampling hidden and visible variables
[Figure: sampling steps]
• Conditional distributions $P(v \mid h)$ and $P(h \mid v)$ are represented by sigmoids
• Thus, we can think of Gibbs sampling from the joint distribution represented by an RBM as a top-down propagation in an infinitely deep sigmoid belief network!
Images from Marcus Frean, MLSS Tutorial 2010
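Block Gibbs sampling can be sketched as follows: because the graph is bipartite, all hidden units are sampled in parallel given $v$, then all visible units given $h$. The function names and the fixed seed are our own choices for illustration.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sample_h_given_v(v, W, c, rng):
    """Hidden units are independent given v: P(h_j = 1 | v) = sigmoid(c_j + sum_i w_ij v_i)."""
    return [1 if rng.random() < sigmoid(c[j] + sum(W[i][j] * v[i] for i in range(len(v)))) else 0
            for j in range(len(c))]

def sample_v_given_h(h, W, b, rng):
    """Visible units are independent given h: P(v_i = 1 | h) = sigmoid(b_i + sum_j w_ij h_j)."""
    return [1 if rng.random() < sigmoid(b[i] + sum(W[i][j] * h[j] for j in range(len(h)))) else 0
            for i in range(len(b))]

def block_gibbs(v0, W, b, c, steps, seed=0):
    """Alternate h ~ P(h | v) and v ~ P(v | h), starting from the visible vector v0.
    Each step applies the same sigmoid conditionals again, which is why the chain
    can be read as a top-down pass through a tied-weight infinite belief net."""
    rng = random.Random(seed)
    v = list(v0)
    h = sample_h_given_v(v, W, c, rng)
    for _ in range(steps):
        v = sample_v_given_h(h, W, b, rng)
        h = sample_h_given_v(v, W, c, rng)
    return v, h
```

Each iteration corresponds to one more pair of tied-weight sigmoid layers in the unrolled picture; running the chain longer is like sampling from a deeper net.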
RBMs are infinite belief networks • RBMs are equivalent to infinitely deep belief networks
• Sampling from this is the same as sampling from the network on the right
images from Marcus Frean, MLSS Tutorial 2010
• When we train an RBM, we are really training an infinitely deep belief net!
  • It is just that the weights of all layers are tied.
  • If the weights are “untied” to some extent, we get a Deep Belief Network.
Images from Marcus Frean, MLSS Tutorial 2010
III: Deep Belief Nets
• DBNs are hybrid graphical models (chain graphs):
  • Exact inference in DBNs is problematic due to the explaining-away effect
  • Training: greedy pre-training + ad-hoc fine-tuning; no proper joint training
  • Approximate inference is feed-forward
Deep Belief Networks
• DBNs represent a joint probability distribution
$$P(v, h^1, h^2, h^3) = P(h^2, h^3)\,P(h^1 \mid h^2)\,P(v \mid h^1)$$
• Note that $P(h^2, h^3)$ is an RBM and the conditionals $P(h^1 \mid h^2)$ and $P(v \mid h^1)$ are represented in the sigmoid form
• The model is trained by optimizing the log-likelihood for the given data, $\log P(v)$
Challenges:
• Exact inference in DBNs is problematic due to the explaining-away effect
• Training is done in two stages: greedy pre-training + ad-hoc fine-tuning; no proper joint training
• Approximate inference is feed-forward (bottom-up)
DBN: Layer-wise pre-training
• Pre-train and freeze the 1st RBM
• Stack another RBM on top and train it
• The weights of layers 2+ remain tied
• We repeat this procedure: pre-train and untie the weights layer-by-layer…
Images from Marcus Frean, MLSS Tutorial 2010
DBN: Layer-wise pre-training
• We repeat this procedure: pre-train and untie the weights layer-by-layer:
  • The weights of layers 3+ remain tied
  • and so forth
• From the optimization perspective, this procedure loosely corresponds to an approximate block-coordinate ascent on the log-likelihood
Images from Marcus Frean, MLSS Tutorial 2010
DBN: Fine-tuning
• Pre-training is quite ad-hoc and is unlikely to lead to a good probabilistic model per se
• However, the layers of representations could perhaps be useful for some other downstream tasks!
• We can further “fine-tune” a pre-trained DBN for some other task
Setting A: Unsupervised learning (DBN → autoencoder)
1. Pre-train a stack of RBMs in a greedy layer-wise fashion
2. “Unroll” the RBMs to create an autoencoder
3. Fine-tune the parameters by optimizing the reconstruction error
Images from Hinton & Salakhutdinov, 2006
Setting B: Supervised learning (DBN → classifier)
1. Pre-train a stack of RBMs in a greedy layer-wise fashion
2. “Unroll” the RBMs to create a feedforward classifier
3. Fine-tune the parameters by optimizing the classification loss (e.g., cross-entropy)
Some intuitions about how pre-training works: Erhan et al.: Why Does Unsupervised Pre-training Help Deep Learning? JMLR, 2010
Deep Belief Nets and Boltzmann Machines
• DBNs are hybrid graphical models (chain graphs):
  • Inference in DBNs is problematic due to the explaining-away effect
  • Training: greedy pre-training + ad-hoc fine-tuning; no proper joint training
  • Approximate inference is feed-forward
• DBMs are fully undirected models (Markov random fields):
  • Can be trained similarly to RBMs via MCMC (Hinton & Sejnowski, 1983)
  • Use a variational (mean-field) approximation of the posterior over hidden units for faster training (Salakhutdinov & Hinton, 2009)
  • Similarly, can be used to initialize other networks for downstream tasks
Graphical models vs. Deep networks
• A few critical points to note about all these models:
  • The primary goal of deep generative models is to represent the distribution of the observable variables. Adding layers of hidden variables makes it possible to represent increasingly complex distributions.
  • Hidden variables are secondary (auxiliary) elements used to facilitate learning of complex dependencies between the observables.
  • Training of the model is ad hoc, but what matters is the quality of the learned hidden representations.
  • Representations are judged by their usefulness on a downstream task (the probabilistic meaning of the model is often discarded at the end).
• In contrast, classical graphical models are often concerned with the correctness of learning and inference of all variables
An old study of belief networks from the GM standpoint [Xing, Russell, Jordan, UAI 2003]
[Figure: mean-field partitions of a sigmoid belief network for subsequent GMF inference; methods compared: GMFb, GMFr, BP]
• The study focused only on inference/learning accuracy, speed, and partition
“Optimize” how to optimize via truncation & re-opt
• Energy-based modeling of the structured output (CRF)
• Unroll the optimization algorithm for a fixed number of steps (Domke, 2012)
[Figure: the unrolled optimization as a chain of iterates $y_1, y_2, \dots$]
• We can backprop through the optimization steps since they are just a sequence of computations
Relevant recent paper: Andrychowicz et al.: Learning to learn by gradient descent by gradient descent. 2016.
Dealing with structured prediction • Energy-based modeling of the structured output (CRF)
• Unroll the optimization algorithm for a fixed number of steps (Domke, 2012)
• We can think of $y^*$ as some non-linear differentiable function of the inputs and weights → impose some loss and optimize it as any other standard computation graph using backprop!
• Similarly, message-passing-based inference algorithms can be truncated and converted into computational graphs (Domke, 2011; Stoyanov et al., 2011)
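The idea can be sketched on a one-dimensional example. We assume a hypothetical quadratic energy $E(y) = \frac{1}{2}(y - \theta x)^2 + \frac{\lambda}{2} y^2$, run a fixed number of gradient steps from $y = 0$, and carry $\partial y / \partial \theta$ through every step with the chain rule. For brevity this propagates the derivative forward; backprop would compute the same quantity in reverse.

```python
def unrolled_inference(theta, x, lam=0.1, lr=0.3, steps=20):
    """Run a fixed number of gradient steps on the hypothetical energy
    E(y) = 0.5 * (y - theta * x)**2 + 0.5 * lam * y**2, starting from y = 0.

    Alongside y we carry dy/dtheta, obtained by differentiating each update
    y <- y - lr * ((y - theta * x) + lam * y) with the chain rule."""
    y, dy_dtheta = 0.0, 0.0
    for _ in range(steps):
        y_new = y - lr * ((y - theta * x) + lam * y)
        dy_dtheta = dy_dtheta - lr * ((dy_dtheta - x) + lam * dy_dtheta)
        y = y_new
    return y, dy_dtheta

def loss_and_grad(theta, x, y_target):
    """Loss on the truncated solution y_T and its derivative w.r.t. theta."""
    y_T, dy = unrolled_inference(theta, x)
    return 0.5 * (y_T - y_target) ** 2, (y_T - y_target) * dy
```

The truncated solution is just a differentiable function of $\theta$, so the outer loss can be trained like any other computation graph, even though $y_T$ has not fully converged to the minimizer $\theta x / (1 + \lambda)$.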
Combining sequential NNs and GMs
slide courtesy: Matt Gormley
Hybrid NNs + conditional GMs
• In a standard CRF, each of the factor cells is a parameter.
• In a hybrid model, these values are computed by a neural network.
slide courtesy: Matt Gormley
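A hybrid linear-chain model can be sketched as follows: a stand-in "network" (a hypothetical one-layer tanh map, not any specific architecture from the slides) produces the per-position label scores, while the CRF adds transition scores and normalizes over all label sequences with the forward algorithm.

```python
import itertools
import math

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def emission_scores(x_t, W):
    """Stand-in for the neural network: a hypothetical tanh layer mapping the scalar
    input x_t to one score per label (a real hybrid would use, e.g., an RNN)."""
    return [math.tanh(w * x_t) for w in W]

def log_partition(unary, trans):
    """Forward algorithm: log Z = log sum over all label sequences y of
    exp(sum_t unary[t][y_t] + sum_t trans[y_{t-1}][y_t])."""
    alpha = list(unary[0])
    for t in range(1, len(unary)):
        alpha = [unary[t][k] + logsumexp([alpha[j] + trans[j][k] for j in range(len(alpha))])
                 for k in range(len(unary[t]))]
    return logsumexp(alpha)

def sequence_score(labels, unary, trans):
    s = sum(unary[t][y] for t, y in enumerate(labels))
    return s + sum(trans[a][b] for a, b in zip(labels, labels[1:]))

def log_likelihood(labels, unary, trans):
    """log P(labels | inputs): the objective that would be backpropagated
    into the network parameters during training."""
    return sequence_score(labels, unary, trans) - log_partition(unary, trans)

xs = [0.5, -1.0, 2.0]                          # hypothetical input sequence
W_NET = [1.0, -0.5, 0.3]                       # hypothetical network weights (3 labels)
TRANS = [[0.2, -0.1, 0.0], [0.0, 0.5, -0.3], [0.1, 0.0, 0.4]]
UNARY = [emission_scores(x, W_NET) for x in xs]
```

The forward algorithm costs $O(T K^2)$ for $T$ positions and $K$ labels, instead of the $K^T$ terms of a brute-force sum.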
Using GMs as Prediction Explanations
• Idea: Use deep neural nets to generate parameters of a graphical model for a given context (e.g., a specific instance or case)
• The produced GMs are used to make the final prediction
• GMs are built on top of interpretable variables (not deep embeddings!) and can be used as contextual explanations for each prediction
Al-Shedivat, Dubey, Xing, arXiv, 2017
Using GMs as Prediction Explanations
[Figure: panels labeled CEN and MoE. A context encoder with attention (dot products) over a dictionary of GM parameters $\theta$ produces a graphical model over the attributes $X_1, \dots, X_4$ and the targets $Y_1, \dots, Y_4$.]
A practical implementation:
• Maintain a (sparse) dictionary of GM parameters
• Process complex inputs (images, text, time series, etc.) using deep nets; use soft attention to either select or combine models from the dictionary
• Use the constructed GMs (e.g., CRFs) to make predictions
• Inspect the GMs to understand the reasoning behind predictions
Al-Shedivat, Dubey, Xing, arXiv, 2017
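The dictionary-plus-attention recipe can be sketched with the simplest possible "GM", a logistic regression over the interpretable attributes; this is a simplification of ours, since the slides use richer models such as CRFs. The context encoder, keys, and dictionary entries are all hypothetical.

```python
import math

def softmax(scores):
    m = max(scores)
    e = [math.exp(s - m) for s in scores]
    z = sum(e)
    return [x / z for x in e]

def context_encoder(context, keys):
    """Hypothetical encoder: scores each dictionary entry by a dot product between
    the (already embedded) context and a per-entry key vector."""
    return [sum(a * b for a, b in zip(context, k)) for k in keys]

def explain_and_predict(context, attributes, keys, dictionary):
    """Soft attention mixes the dictionary of model parameters into one
    parameter vector theta, which is then applied to the attributes."""
    attn = softmax(context_encoder(context, keys))
    theta = [sum(a * entry[d] for a, entry in zip(attn, dictionary))
             for d in range(len(dictionary[0]))]
    logit = sum(t * x for t, x in zip(theta, attributes))
    return 1.0 / (1.0 + math.exp(-logit)), theta, attn
```

The returned `theta` is the per-instance explanation: it lives in the space of the interpretable attributes, not in the space of deep embeddings.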
Bayesian learning of NNs
• A neural network as a probabilistic model:
  • Likelihood: $p(y \mid x, w)$
    • Categorical distribution for classification ⇒ cross-entropy loss
    • Gaussian distribution for regression ⇒ squared loss
  • Prior on parameters: $p(w)$
  • Maximum a posteriori (MAP) solution: $w^{\mathrm{MAP}} = \arg\max_w \log p(y \mid x, w)\, p(w)$
    • Gaussian prior ⇒ L2 regularization
    • Laplace prior ⇒ L1 regularization
  • Bayesian learning [MacKay 1992, Neal 1996, de Freitas 2003]
    • Posterior: $p(w \mid x, y)$
    • Variational inference with approximate posterior $q(w)$
[Figure 1, “Weight Uncertainty” (courtesy: Blundell et al., 2016). Left: each weight has a fixed value, as provided by classical backpropagation. Right: each weight is assigned a distribution, as provided by Bayes by Backprop.]
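The prior-to-regularizer correspondence can be checked directly: up to an additive constant, $-\log \mathcal{N}(w \mid 0, \sigma^2 I) = \|w\|_2^2 / (2\sigma^2)$ and $-\log \mathrm{Laplace}(w \mid 0, s) = \|w\|_1 / s$. The sketch also includes a closed-form MAP estimate for a hypothetical scalar linear regression with Gaussian noise, which reduces to ridge regression.

```python
import math

def neg_log_gaussian_prior(w, sigma):
    """-log N(w | 0, sigma^2 I), summed over weights."""
    return sum(wi ** 2 / (2 * sigma ** 2) + 0.5 * math.log(2 * math.pi * sigma ** 2) for wi in w)

def neg_log_laplace_prior(w, scale):
    """-log Laplace(w | 0, scale), summed over weights."""
    return sum(abs(wi) / scale + math.log(2 * scale) for wi in w)

def l2_penalty(w, sigma):
    return sum(wi ** 2 for wi in w) / (2 * sigma ** 2)

def l1_penalty(w, scale):
    return sum(abs(wi) for wi in w) / scale

def map_scalar_regression(xs, ys, noise_var, prior_var):
    """argmax_w log p(y | x, w) p(w) for y = w x + N(0, noise_var), w ~ N(0, prior_var):
    ridge regression with penalty strength noise_var / prior_var."""
    return sum(x * y for x, y in zip(xs, ys)) / (sum(x * x for x in xs) + noise_var / prior_var)
```

The constants ($\frac{1}{2}\log 2\pi\sigma^2$ and $\log 2s$) do not depend on $w$, so they do not change the argmax; only the penalty terms matter.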
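Bayesian learning of NNs
• Variational inference (in a nutshell):
$$\min_q \mathcal{F}(D, q) = \mathrm{KL}\big(q(w) \,\|\, p(w)\big) - \mathbb{E}_{q(w)}[\log p(D \mid w)]$$
$$\min_q \mathcal{F}(D, q) \approx \mathrm{KL}\big(q(w) \,\|\, p(w)\big) - \frac{1}{S} \sum_{i=1}^{S} \log p(D \mid w_i)$$
  where $w_i \sim q(w)$; the KL term can be approximated similarly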
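The Monte Carlo approximation can be sketched for a hypothetical one-weight model: $y = w x + \mathcal{N}(0, 1)$ noise, prior $p(w) = \mathcal{N}(0, 1)$, and $q(w) = \mathcal{N}(\mu, s^2)$. Samples are drawn as $w_i = \mu + s\,\epsilon_i$ (the reparameterization that Bayes by Backprop differentiates through); here we only evaluate the estimate, and because the model is linear-Gaussian we can compare it with the exact objective. The KL term is used in closed form.

```python
import math
import random

def kl_gauss_std_normal(mu, s):
    """KL(N(mu, s^2) || N(0, 1)) in closed form."""
    return 0.5 * (s * s + mu * mu - 1.0) - math.log(s)

def log_lik(w, xs, ys):
    """log p(D | w) for y = w x + N(0, 1) noise."""
    return sum(-0.5 * math.log(2 * math.pi) - 0.5 * (y - w * x) ** 2 for x, y in zip(xs, ys))

def mc_objective(mu, s, xs, ys, n_samples=1000, seed=0):
    """F(D, q) ~= KL(q || p) - (1/S) sum_i log p(D | w_i), with w_i = mu + s * eps_i."""
    rng = random.Random(seed)
    avg_ll = sum(log_lik(mu + s * rng.gauss(0.0, 1.0), xs, ys) for _ in range(n_samples)) / n_samples
    return kl_gauss_std_normal(mu, s) - avg_ll

def exact_objective(mu, s, xs, ys):
    """Closed form, available only because this toy model is linear-Gaussian:
    E_q[(y - w x)^2] = (y - mu x)^2 + s^2 x^2."""
    exp_ll = sum(-0.5 * math.log(2 * math.pi) - 0.5 * ((y - mu * x) ** 2 + s * s * x * x)
                 for x, y in zip(xs, ys))
    return kl_gauss_std_normal(mu, s) - exp_ll
```

In a real network no closed form exists, so the sampled estimate (and its gradient with respect to $\mu$ and $s$) is all one has.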
• We can define $q(w)$ as a diagonal Gaussian or a full-covariance Gaussian
• Alternatively, $q(w)$ can be defined implicitly, e.g., via dropout [Gal & Ghahramani, 2016]; for a weight matrix $W$:
$$W = M \cdot \mathrm{diag}(z), \qquad z_i \sim \mathrm{Bernoulli}(p)$$
  • Dropping out neurons is equivalent to zeroing out columns of the parameter matrices (i.e., weights)
  • $z_i = 0$ corresponds to the $i$-th column of $M$ being dropped out ⇒ the procedure is equivalent to dropout of unit $i$ [Hinton et al., 2012]
  • Variational parameters are $\{M, p\}$
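The column-masking view of dropout can be sketched as follows. We assume $p$ is the probability of keeping a unit (a convention choice), and treat the columns of $M$ as corresponding to input units; averaging the outputs over sampled masks gives a Monte Carlo estimate of the predictive mean under $q$.

```python
import random

def dropout_weights(M, p, rng):
    """Sample W = M . diag(z) with z_j ~ Bernoulli(p), p = keep probability:
    z_j = 0 zeroes the j-th column of M, i.e. drops input unit j."""
    z = [1 if rng.random() < p else 0 for _ in range(len(M[0]))]
    W = [[m_ij * z[j] for j, m_ij in enumerate(row)] for row in M]
    return W, z

def matvec(A, x):
    return [sum(a * b for a, b in zip(row, x)) for row in A]

def mc_dropout_predict(M, x, p, n_samples=100, seed=0):
    """Average of W x over sampled W ~ q(W): a Monte Carlo estimate of the predictive mean."""
    rng = random.Random(seed)
    outs = [matvec(dropout_weights(M, p, rng)[0], x) for _ in range(n_samples)]
    return [sum(o[k] for o in outs) / n_samples for k in range(len(M))]
```

Multiplying by the masked matrix gives exactly the same result as zeroing the dropped inputs, $W x = M (z \odot x)$, which is why sampling dropout masks at test time can be read as sampling weights from $q(w)$.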
“Infinitely Wide” Deep Models • We have seen that an ”infinitely deep” network can be explained by a proper GM, How about an “infinitely wide” one? • Consider a neural network with a Gaussian prior on its weights an infinitely many hidden neurons in the intermediate layer. • Turns out, if we have a certain Gaussian prior on the weights of such infinite network, it will be equivalent to a Gaussian process [Neal 1996].
Infinitely many hidden units
• Gaussian process (GP) is a distribution over functions:
• When used for prediction, GPs account for correlations between the data points and can output well-calibrated predictive uncertainty estimates.
Gaussian Process and Deep Kernel Learning
• Consider a neural network with a Gaussian prior on its weights and infinitely many hidden neurons in the intermediate layer [Figure: infinitely many hidden units]
• Certain classes of Gaussian priors for neural networks with infinitely many hidden units converge to Gaussian processes [Neal 1996]
• Deep kernel [Wilson et al., 2016]
  • Combines the inductive biases of deep model architectures with the non-parametric flexibility of Gaussian processes
$$p(y \mid f) = \mathcal{N}(y \mid f, \sigma^2 I), \qquad p(f \mid \theta) = \mathcal{N}\big(f \mid \mu(X), K_\theta\big)$$
  • Starting from a base kernel $k(x_i, x_j \mid \theta)$, transform the inputs $x$ as
$$k(x_i, x_j \mid \theta) \to k\big(g(x_i, w), g(x_j, w) \mid \theta, w\big), \quad \text{where } K_{ij} = k(x_i, x_j)$$
  • Learn both kernel and neural parameters $\theta, w$ jointly by optimizing the marginal log-likelihood (or its variational lower bound)
  • Fast learning and inference with local kernel interpolation, structured inducing points, and Monte Carlo approximations
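The deep kernel construction can be sketched by composing a base kernel with a feature map. The following assumptions apply:
- The "network" $g(x, w)$ is a hypothetical one-unit tanh map.
- The base kernel is RBF with $\theta$ = (lengthscale, signal variance).
- The prior mean is zero.

The function returns the GP marginal log-likelihood that would be maximized jointly over $\theta$ and $w$. This is an illustrative sketch, not the implementation of Wilson et al., which adds the scalable approximations listed above.

```python
import math


def feature_map(x, w):
    """Hypothetical tiny network g(x, w) = w2 * tanh(w1 * x + b1) producing a scalar feature."""
    w1, b1, w2 = w
    return w2 * math.tanh(w1 * x + b1)


def deep_kernel(a, b, theta, w):
    """k(g(a, w), g(b, w) | theta, w): an RBF base kernel evaluated on network features."""
    lengthscale, variance = theta
    ga, gb = feature_map(a, w), feature_map(b, w)
    return variance * math.exp(-0.5 * (ga - gb) ** 2 / lengthscale ** 2)


def log_marginal_likelihood(xs, ys, theta, w, noise_var):
    """log N(y | 0, K + noise_var * I) with K_ij = deep_kernel(x_i, x_j), via Cholesky."""
    n = len(xs)
    K = [[deep_kernel(xs[i], xs[j], theta, w) + (noise_var if i == j else 0.0)
          for j in range(n)] for i in range(n)]
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(K[i][i] - s) if i == j else (K[i][j] - s) / L[j][j]
    alpha = []                                    # alpha = L^{-1} y, so y^T K^{-1} y = |alpha|^2
    for i in range(n):
        alpha.append((ys[i] - sum(L[i][k] * alpha[k] for k in range(i))) / L[i][i])
    log_det = 2.0 * sum(math.log(L[i][i]) for i in range(n))
    return -0.5 * sum(a * a for a in alpha) - 0.5 * log_det - 0.5 * n * math.log(2 * math.pi)


theta = (0.8, 1.5)        # hypothetical (lengthscale, signal variance)
w = (2.0, -0.3, 1.2)      # hypothetical network weights (w1, b1, w2)
lml = log_marginal_likelihood([-1.0, 0.2, 1.1], [0.3, -0.2, 0.9], theta, w, noise_var=0.1)
```

In practice $\theta$ and $w$ would be updated by gradient ascent on this quantity, for example with automatic differentiation; the sketch only evaluates it.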
Gaussian Process and Deep Kernel Learning
• By adding a GP as a layer to a deep neural net, we can think of it as adding an infinitely wide hidden layer with a particular prior on the weights
• Deep kernel learning [Wilson et al., 2016]
  • Combines the inductive biases of deep models with the non-parametric flexibility of Gaussian processes
  • GPs add powerful regularization to the network
  • Additionally, they provide predictive uncertainty estimates
Deep kernel learning on sequential data
What if we have data of a sequential nature? Can we still apply the same reasoning and build rich nonparametric models on top of recurrent nets?
Deep kernel learning on sequential data
The answer is YES! By adding a GP layer to a recurrent network, we effectively correlate samples across time and get predictions along with well-calibrated uncertainty estimates. Training such a model with stochastic techniques, however, requires some additional care (see our paper).
Al-Shedivat et al., JMLR, 2017
Deep kernel learning on sequential data
Lane prediction: LSTM vs GP-LSTM
[Figure: panels plotting front distance (m) against side distance (m). Al-Shedivat et al., JMLR, 2017]
Deep kernel learning on sequential data
Lead vehicle prediction: LSTM vs GP-LSTM
[Figure: panels plotting front distance (m) against side distance (m). Al-Shedivat et al., JMLR, 2017]
Conclusion
• DL & GM: the fields are similar in the beginning (structure, energy, etc.), and then diverge into their own signature pipelines
• DL: most effort is directed to comparing different architectures and their components (models are driven by evaluating empirical performance on downstream tasks)
  • DL models are good at learning robust hierarchical representations from the data and are suitable for simple reasoning (call it “low-level cognition”)
• GM: the effort is directed towards improving inference accuracy and convergence speed
  • GMs are best for provably correct inference and suitable for high-level complex reasoning tasks (call it “high-level cognition”)
• Convergence of both fields is very promising!
• Next part: a unified view of deep generative models in the GM interpretation
Part-II Deep Generative Models
Plan • Statistical And Algorithmic Foundation and Insight of Deep Learning
• On Unified Framework of Deep Generative Models
• Computational Mechanisms: Distributed Deep Learning Architectures
Outline
• Overview of advances in deep generative models
• Backgrounds of deep generative models
  • Wake sleep algorithm
  • Variational autoencoders
  • Generative adversarial networks
• A unified view of deep generative models
  • New formulations of deep generative models
  • Symmetric modeling of latent and visible variables
Deep generative models
• Define probabilistic distributions over a set of variables
• "Deep" means multiple layers of hidden variables!
[Figure: multiple layers of hidden variables $z$ above the observed variables $x$]
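Early forms of deep generative models
• Hierarchical Bayesian models
  • Sigmoid belief nets [Neal 1992]
$$p\big(x_{kn} = 1 \mid \theta_k, z_n^{(1)}\big) = \sigma\big(\theta_k^\top z_n^{(1)}\big)$$
$$p\big(z_{in}^{(1)} = 1 \mid \theta_i, z_n^{(2)}\big) = \sigma\big(\theta_i^\top z_n^{(2)}\big)$$
[Figure: a sigmoid belief net with binary units $z_n^{(2)}, z_n^{(1)}, x_n \in \{0, 1\}$]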
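Ancestral (top-down) sampling from a two-layer sigmoid belief net follows the conditionals above directly. In this sketch the layer sizes, the weight values and the Bernoulli probabilities of the top layer $z^{(2)}$ are hypothetical. As in the equations, no bias terms are used.

```python
import math
import random


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def sample_layer(weights, parent, rng):
    """Unit k of the child layer is on with probability sigmoid(theta_k^T parent)."""
    return [1 if rng.random() < sigmoid(sum(w * p for w, p in zip(row, parent))) else 0
            for row in weights]


def ancestral_sample(top_probs, W1, W0, rng):
    """Top-down pass z^(2) -> z^(1) -> x through a two-layer sigmoid belief net."""
    z2 = [1 if rng.random() < p else 0 for p in top_probs]
    z1 = sample_layer(W1, z2, rng)
    x = sample_layer(W0, z1, rng)
    return z2, z1, x


rng = random.Random(0)
top_probs = [0.5, 0.2]                          # hypothetical Bernoulli priors on z^(2)
W1 = [[1.0, -2.0], [0.5, 0.5], [-1.0, 3.0]]     # rows are theta_i for the 3 units of z^(1)
W0 = [[2.0, 0.0, -1.0], [0.0, 1.5, 1.5]]        # rows are theta_k for the 2 units of x
z2, z1, x = ancestral_sample(top_probs, W1, W0, rng)
```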
• Neural network models
  • Helmholtz machines [Dayan et al., 1995]
[Figure: Helmholtz machine, showing the inference weights [Dayan et al. 1995]]
  • Predictability minimization [Schmidhuber 1995]
[Figure: predictability minimization. Figure courtesy: Schmidhuber 1996]
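Early forms of deep generative models
• Training of DGMs via an EM-style framework
  • Sampling / data augmentation: $z = \{z_1, z_2\}$
$$z_1^{\mathrm{new}} \sim p(z_1 \mid z_2, x), \qquad z_2^{\mathrm{new}} \sim p(z_2 \mid z_1^{\mathrm{new}}, x)$$
  • Variational inference
$$\log p(x) \ge \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big) := \mathcal{L}(\theta, \phi; x), \qquad \max_{\theta, \phi} \mathcal{L}(\theta, \phi; x)$$
  • Wake sleep
$$\text{Wake: } \max_\theta \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] \qquad \text{Sleep: } \max_\phi \mathbb{E}_{p_\theta(x \mid z)}\big[\log q_\phi(z \mid x)\big]$$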
Resurgence of deep generative models
• Restricted Boltzmann machines (RBMs) [Smolensky, 1986]
  • Building blocks of deep probabilistic models
• Deep belief networks (DBNs) [Hinton et al., 2006]
  • Hybrid graphical model
  • Inference in DBNs is problematic due to explaining away
• Deep Boltzmann Machines (DBMs) [Salakhutdinov & Hinton, 2009]
  • Undirected model
Resurgence of deep generative models
• Variational autoencoders (VAEs) [Kingma & Welling, 2014] / Neural Variational Inference and Learning (NVIL) [Mnih & Gregor, 2014]
[Figure: $q_\phi(z \mid x)$ inference model; $p_\theta(x \mid z)$ generative model. Figure courtesy: Kingma & Welling, 2014]
• Generative adversarial networks (GANs)
[Figure: $G_\theta$: generative model; $D_\phi$: discriminator]
• Generative moment matching networks (GMMNs) [Li et al., 2015; Dziugaite et al., 2015]
• Autoregressive neural networks
[Figure: an autoregressive neural network]
Synonyms in the literature
• Posterior distribution → Inference model
  • Variational approximation
  • Recognition model
  • Inference network (if parameterized as neural networks)
  • Recognition network (if parameterized as neural networks)
  • (Probabilistic) encoder
• "The Model" (prior + conditional, or joint) → Generative model
  • The (data) likelihood model
  • Generative network (if parameterized as neural networks)
  • Generator
  • (Probabilistic) decoder
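Recap: Variational Inference
• Consider a generative model $p_\theta(x \mid z)$ and prior $p(z)$
  • Joint distribution: $p_\theta(x, z) = p_\theta(x \mid z)\, p(z)$
• Assume a variational distribution $q_\phi(z \mid x)$
• Objective: maximize the lower bound of the log-likelihood
$$\log p(x) = \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big) + \int_z q_\phi(z \mid x) \log \frac{p_\theta(x, z)}{q_\phi(z \mid x)} \ge \int_z q_\phi(z \mid x) \log \frac{p_\theta(x, z)}{q_\phi(z \mid x)} := \mathcal{L}(\theta, \phi; x)$$
• Equivalently, minimize the free energy
$$F(\theta, \phi; x) = -\log p(x) + \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big)$$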
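The identity between the bound and the free energy can be checked on a toy conjugate model, assumed here for illustration: $z \sim \mathcal{N}(0,1)$ and $x \mid z \sim \mathcal{N}(z,1)$, so $p(x) = \mathcal{N}(x \mid 0, 2)$ and $p(z \mid x) = \mathcal{N}(x/2, 1/2)$. With $q_\phi(z \mid x) = \mathcal{N}(m, s^2)$, every term is available in closed form. The bound equals $\log p(x)$ exactly when $q$ matches the posterior.

```python
import math

# Assumed toy model: z ~ N(0, 1), x | z ~ N(z, 1)  =>  p(x) = N(0, 2), p(z | x) = N(x / 2, 1 / 2)


def log_normal(v, mean, var):
    return -0.5 * math.log(2 * math.pi * var) - 0.5 * (v - mean) ** 2 / var


def elbo(x, m, s2):
    """L = E_q[log p(x | z)] - KL(q(z | x) || p(z)) for q(z | x) = N(m, s2), in closed form."""
    expected_loglik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)
    kl_prior = 0.5 * (m ** 2 + s2 - 1.0 - math.log(s2))
    return expected_loglik - kl_prior


def kl_to_posterior(x, m, s2):
    """KL(N(m, s2) || p(z | x)) with the exact posterior N(x / 2, 1 / 2)."""
    pm, pv = x / 2, 0.5
    return 0.5 * (math.log(pv / s2) + (s2 + (m - pm) ** 2) / pv - 1.0)


def free_energy(x, m, s2):
    """F = -L, which should equal -log p(x) + KL(q(z | x) || p(z | x))."""
    return -elbo(x, m, s2)


x = 1.3
log_px = log_normal(x, 0.0, 2.0)
```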
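Recap: Variational Inference
Maximize the variational lower bound $\mathcal{L}(\theta, \phi; x)$
• E-step: maximize $\mathcal{L}$ wrt. $\phi$ with $\theta$ fixed
$$\max_\phi \mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)$$
  • If a closed-form solution exists: $q_\phi^*(z \mid x) \propto \exp\big[\log p_\theta(x, z)\big]$
• M-step: maximize $\mathcal{L}$ wrt. $\theta$ with $\phi$ fixed
$$\max_\theta \mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)$$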
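When the latent variable is discrete, the closed-form E-step reduces to normalizing $\exp[\log p_\theta(x, z)]$ over $z$. The sketch uses a hypothetical two-component Gaussian mixture with unit variances and arbitrarily chosen prior and means. It subtracts the maximum before exponentiating for numerical stability. The bound evaluated at $q^*$ recovers $\log p(x)$.

```python
import math

# Hypothetical two-component model: p(z) = prior[z], p(x | z) = N(x | means[z], 1)
prior = [0.3, 0.7]
means = [-1.0, 2.0]


def log_joint(x, z):
    return math.log(prior[z]) - 0.5 * math.log(2 * math.pi) - 0.5 * (x - means[z]) ** 2


def closed_form_e_step(x):
    """q*(z | x) proportional to exp[log p(x, z)], normalized over z (max subtracted for stability)."""
    logs = [log_joint(x, z) for z in range(len(prior))]
    top = max(logs)
    unnorm = [math.exp(v - top) for v in logs]
    total = sum(unnorm)
    return [u / total for u in unnorm]


def lower_bound(x, q):
    """L = sum_z q(z) [log p(x, z) - log q(z)]."""
    return sum(qz * (log_joint(x, z) - math.log(qz)) for z, qz in enumerate(q) if qz > 0)


x = 0.4
q_star = closed_form_e_step(x)
log_px = math.log(sum(math.exp(log_joint(x, z)) for z in range(len(prior))))
```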
Recap: Amortized Variational Inference
• Variational distribution as an inference model $q_\phi(z \mid x)$ with parameters $\phi$
• Amortize the cost of inference by learning a single data-dependent inference model
• The trained inference model can be used for quick inference on new data
• Maximize the variational lower bound $\mathcal{L}(\theta, \phi; x)$
  • E-step: maximize $\mathcal{L}$ wrt. $\phi$ with $\theta$ fixed
  • M-step: maximize $\mathcal{L}$ wrt. $\theta$ with $\phi$ fixed
Deep generative models with amortized inference
• Helmholtz machines
• Variational autoencoders (VAEs) / Neural Variational Inference and Learning (NVIL)
• We will see later that adversarial approaches are also included in the list
  • Predictability minimization (PM)
  • Generative adversarial networks (GANs)
Wake Sleep Algorithm
• [Hinton et al., Science 1995]
• Train a separate inference model along with the generative model
• Generally applicable to a wide range of generative models, e.g., Helmholtz machines
• Consider a generative model $p_\theta(x \mid z)$ and prior $p(z)$
  • Joint distribution $p_\theta(x, z) = p_\theta(x \mid z)\, p(z)$
  • E.g., multi-layer belief nets
• Inference model $q_\phi(z \mid x)$
• Maximize the data log-likelihood with two steps of loss relaxation:
  • Maximize the lower bound of the log-likelihood, or equivalently, minimize the free energy
$$F(\theta, \phi; x) = -\log p(x) + \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big)$$
  • Minimize a different objective (reversed KLD) wrt. $\phi$ to ease the optimization
    • Disconnected from the original variational lower bound loss
$$F'(\theta, \phi; x) = -\log p(x) + \mathrm{KL}\big(p_\theta(z \mid x) \,\|\, q_\phi(z \mid x)\big)$$
Wake Sleep Algorithm
• Free energy:
$$F(\theta, \phi; x) = -\log p(x) + \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big)$$
• Minimize the free energy wrt. $\theta$ of $p_\theta$ → wake phase
$$\max_\theta \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x, z)\big]$$
  • Get samples from $q_\phi(z \mid x)$ through inference on hidden variables
  • Use the samples as targets for updating the generative model $p_\theta(x \mid z)$
  • Corresponds to the variational M step
[Figure: network with connections labeled R1, R2 above $x$. Figure courtesy: Maei’s slides]
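Wake Sleep Algorithm
• Free energy: $F(\theta, \phi; x) = -\log p(x) + \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big)$
• Minimize the free energy wrt. $\phi$ of $q_\phi(z \mid x)$
  • Corresponds to the variational E step
  • Difficulties:
    • The optimal $q_\phi^*(z \mid x) = \dfrac{p_\theta(z, x)}{\int p_\theta(z, x)\, dz}$ is intractable
    • High variance of the direct gradient estimate
$$\nabla_\phi F(\theta, \phi; x) = \cdots - \nabla_\phi \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(z, x)\big] + \cdots$$
    • Gradient estimate with the log-derivative trick:
$$\nabla_\phi \mathbb{E}_{q_\phi}[\log p_\theta] = \int \nabla_\phi q_\phi \log p_\theta = \int q_\phi \log p_\theta\, \nabla_\phi \log q_\phi = \mathbb{E}_{q_\phi}\big[\log p_\theta\, \nabla_\phi \log q_\phi\big]$$
    • Monte Carlo estimation:
$$\nabla_\phi \mathbb{E}_{q_\phi}[\log p_\theta] \approx \frac{1}{N}\sum_{i=1}^{N} \log p_\theta(x, z_i)\, \nabla_\phi \log q_\phi(z_i \mid x), \qquad z_i \sim q_\phi(z \mid x)$$
    • The scale factor $\log p_\theta$ of the derivative $\nabla_\phi \log q_\phi$ can have arbitrarily large magnitude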
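The log-derivative (score-function) estimator, and its sensitivity to the scale of $\log p_\theta$, can be illustrated on a one-dimensional toy problem. We assume $q_\phi(z) = \mathcal{N}(\phi, 1)$ and use $-\tfrac{1}{2}(z - 3)^2$ as a stand-in for $\log p_\theta(x, z)$. Shifting that function by a constant leaves the true gradient unchanged but inflates the variance of the estimator, which illustrates the scale-factor problem above.

```python
import random


def score_function_grad(phi, log_p, num_samples, rng):
    """Mean and variance of the terms log_p(z) * d/dphi log q_phi(z), z ~ q_phi = N(phi, 1)."""
    terms = []
    for _ in range(num_samples):
        z = rng.gauss(phi, 1.0)
        terms.append(log_p(z) * (z - phi))   # d/dphi log N(z | phi, 1) = z - phi
    mean = sum(terms) / num_samples
    var = sum((t - mean) ** 2 for t in terms) / (num_samples - 1)
    return mean, var


def log_p(z):
    """Hypothetical stand-in for log p_theta(x, z), as a function of z only."""
    return -0.5 * (z - 3.0) ** 2


def log_p_shifted(z):
    """Same function shifted by a large constant: identical gradient, larger scale factor."""
    return log_p(z) - 100.0


phi = 1.0
true_grad = 3.0 - phi     # d/dphi E_{N(phi,1)}[-0.5 (z - 3)^2] = -(phi - 3)
g_plain, var_plain = score_function_grad(phi, log_p, 20000, random.Random(0))
g_shift, var_shift = score_function_grad(phi, log_p_shifted, 20000, random.Random(0))
```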
Wake Sleep Algorithm
• Free energy: $F(\theta, \phi; x) = -\log p(x) + \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big)$
• WS works around the difficulties with the sleep phase approximation
• Minimize the following objective → sleep phase
$$F'(\theta, \phi; x) = -\log p(x) + \mathrm{KL}\big(p_\theta(z \mid x) \,\|\, q_\phi(z \mid x)\big)$$
$$\max_\phi \mathbb{E}_{p_\theta(z, x)}\big[\log q_\phi(z \mid x)\big]$$
  • "Dreaming" up samples from $p_\theta(x \mid z)$ through a top-down pass
  • Use the samples as targets for updating the inference model
• (Instead of the sleep phase, recent approaches reduce the variance of the gradient estimate: see later slides)
[Figure: network with connections labeled R1, R2, G1, G2 above $x$. Figure courtesy: Maei’s slides]
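Wake Sleep Algorithm
Wake sleep
• Parametrized inference model $q_\phi(z \mid x)$
• Wake phase:
  • minimize $\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big)$ wrt. $\theta$
  • $\mathbb{E}_{q_\phi(z \mid x)}\big[\nabla_\theta \log p_\theta(x \mid z)\big]$
• Sleep phase:
  • minimize $\mathrm{KL}\big(p_\theta(z \mid x) \,\|\, q_\phi(z \mid x)\big)$ wrt. $\phi$
  • $\mathbb{E}_{p_\theta(z, x)}\big[\nabla_\phi \log q_\phi(z \mid x)\big]$
  • Low variance
  • Learning with generated samples of $x$
• Two objectives, not guaranteed to converge
Variational EM
• Variational distribution $q_\phi(z \mid x)$
• Variational M step:
  • minimize $\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big)$ wrt. $\theta$
  • $\mathbb{E}_{q_\phi(z \mid x)}\big[\nabla_\theta \log p_\theta(x \mid z)\big]$
• Variational E step:
  • minimize $\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big)$ wrt. $\phi$
  • $q_\phi^* \propto \exp[\log p_\theta]$ if a closed form exists
  • $\nabla_\phi \mathbb{E}_{q_\phi}\big[\log p_\theta(z, x)\big]$: needs variance reduction in practice
  • Learning with real data $x$
• Single objective, guaranteed to converge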
Variational Autoencoders (VAEs)
• [Kingma & Welling, 2014]
• Use variational inference with an inference model
  • Enjoys similar applicability to the wake-sleep algorithm
• Generative model $p_\theta(x \mid z)$, and prior $p(z)$
  • Joint distribution $p_\theta(x, z) = p_\theta(x \mid z)\, p(z)$
• Inference model $q_\phi(z \mid x)$
[Figure: $q_\phi(z \mid x)$ inference model; $p_\theta(x \mid z)$ generative model. Figure courtesy: Kingma & Welling, 2014]
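Variational Autoencoders (VAEs)
• Variational lower bound
$$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)$$
• Optimize $\mathcal{L}(\theta, \phi; x)$ wrt. $\theta$ of $p_\theta(x \mid z)$
  • The same as the wake phase
• Optimize $\mathcal{L}(\theta, \phi; x)$ wrt. $\phi$ of $q_\phi(z \mid x)$
$$\nabla_\phi \mathcal{L}(\theta, \phi; x) = \cdots + \nabla_\phi \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] + \cdots$$
  • Use the reparameterization trick to reduce variance
  • Alternatives: use control variates as in reinforcement learning [Mnih & Gregor, 2014; Paisley et al., 2012]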
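Reparametrized gradient
• Optimize $\mathcal{L}(\theta, \phi; x)$ wrt. $\phi$ of $q_\phi(z \mid x)$
• Recap: gradient estimate with the log-derivative trick:
$$\nabla_\phi \mathbb{E}_{q_\phi}\big[\log p_\theta(x, z)\big] = \mathbb{E}_{q_\phi}\big[\log p_\theta(x, z)\, \nabla_\phi \log q_\phi\big]$$
• High variance: $\nabla_\phi \mathbb{E}_{q_\phi}[\log p_\theta] \approx \frac{1}{N}\sum_i \log p_\theta(x, z_i)\, \nabla_\phi \log q_\phi(z_i \mid x)$, with $z_i \sim q_\phi(z \mid x)$
  • The scale factor $\log p_\theta(x, z_i)$ of the derivative $\nabla_\phi \log q_\phi$ can have arbitrarily large magnitude
• Gradient estimate with the reparameterization trick:
$$z \sim q_\phi(z \mid x) \;\Leftrightarrow\; z = g_\phi(\epsilon, x), \quad \epsilon \sim p(\epsilon)$$
$$\nabla_\phi \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x, z)\big] = \mathbb{E}_{\epsilon \sim p(\epsilon)}\big[\nabla_\phi \log p_\theta\big(x, z_\phi(\epsilon)\big)\big]$$
• (Empirically) lower variance of the gradient estimate
• E.g., $z \sim \mathcal{N}\big(\mu(x), \sigma(x)\sigma(x)^\top\big) \;\Leftrightarrow\; \epsilon \sim \mathcal{N}(0, 1), \; z = \mu(x) + \sigma(x)\,\epsilon$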
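The variance difference can be seen on a one-dimensional toy problem. We assume $q_\phi(z) = \mathcal{N}(\phi, 1)$ and $f(z) = -\tfrac{1}{2}(z - 3)^2$ as a stand-in for $\log p_\theta(x, z)$. The sketch compares per-sample gradient terms from the log-derivative trick with those from the reparameterization $z = \phi + \epsilon$, $\epsilon \sim \mathcal{N}(0, 1)$. Both are unbiased for the true gradient $3 - \phi$, and in this example the reparameterized terms have much smaller variance.

```python
import random


def compare_estimators(phi, num_samples, seed=0):
    """Per-sample gradient terms for d/dphi E_{z ~ N(phi, 1)}[f(z)], f(z) = -0.5 (z - 3)^2."""
    rng = random.Random(seed)
    score_terms, reparam_terms = [], []
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)
        z = phi + eps                                  # z = g_phi(eps) with eps ~ N(0, 1)
        f_z = -0.5 * (z - 3.0) ** 2
        score_terms.append(f_z * (z - phi))            # log-derivative trick
        reparam_terms.append(-(z - 3.0))               # f'(z) * dz/dphi, with dz/dphi = 1

    def mean_var(values):
        m = sum(values) / len(values)
        return m, sum((v - m) ** 2 for v in values) / (len(values) - 1)

    return mean_var(score_terms), mean_var(reparam_terms)


(score_mean, score_var), (rep_mean, rep_var) = compare_estimators(phi=1.0, num_samples=20000)
```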
VAEs: algorithm
[Figure: the VAE training algorithm [Kingma & Welling, 2014]]
VAEs: example results
• VAEs tend to generate blurred images due to the mode covering behavior (more later)
[Figure: Celebrity faces [Radford 2015]]
• Latent interpolation and sentence generation from VAEs [Bowman et al., 2015]
  • "i want to talk to you ." → "i want to be with you ." → "i do n’t want to be with you ." → "i do n’t want to be with you ." → "she did n’t want to be with him ."
  • "he was silent for a long moment ." → "he was silent for a moment ." → "it was quiet for a moment ." → "it was dark and cold ." → "there was a pause ." → "it was my turn ."
  • "we looked out at the setting sun ." → "they were laughing at the same time ." → "ill see you in the early morning ." → "i looked up at the blue sky ." → "it was down on the dance floor ."
  • Table 7 [Bowman et al., 2015]: three sentences used as inputs, with decodes from the mean of the posterior distribution and from three samples (columns: input, mean, samp. 1, samp. 2, samp. 3)
Generative Adversarial Nets (GANs)
• [Goodfellow et al., 2014]
• Generative model $x = G_\theta(z)$, $z \sim p(z)$
  • Map noise variable $z$ to data space $x$
  • Define an implicit distribution over $x$: $p_{g_\theta}(x)$
    • A stochastic process to simulate data $x$
    • Intractable to evaluate the likelihood
• Discriminator $D_\phi(x)$
  • Outputs the probability that $x$ came from the data rather than the generator
• No explicit inference model
• No obvious connection to previous models with inference networks like VAEs
  • We will build formal connections between GANs and VAEs later
Generative Adversarial Nets (GANs)
• Learning
  • A minimax game between the generator and the discriminator
  • Train $D$ to maximize the probability of assigning the correct label to both training examples and generated samples
  • Train $G$ to fool the discriminator
$$\max_D \mathcal{L}_D = \mathbb{E}_{x \sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{x = G(z), z \sim p(z)}\big[\log(1 - D(x))\big]$$
$$\min_G \mathcal{L}_G = \mathbb{E}_{x = G(z), z \sim p(z)}\big[\log(1 - D(x))\big]$$
[Figure courtesy: Kim’s slides]
Generative Adversarial Nets (GANs)
• Learning
  • The original loss suffers from vanishing gradients when $D$ is too strong
  • Instead, use the following in practice:
$$\max_D \mathcal{L}_D = \mathbb{E}_{x \sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{x = G(z), z \sim p(z)}\big[\log(1 - D(x))\big]$$
$$\max_G \mathcal{L}_G = \mathbb{E}_{x = G(z), z \sim p(z)}\big[\log D(x)\big]$$
[Figure courtesy: Kim’s slides]
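The vanishing-gradient issue shows up when each generator loss is differentiated with respect to the discriminator's logit $a$ on a generated sample, with $D(x) = \sigma(a)$ (a sigmoid output layer is an assumption here). The original term $\log(1 - D)$ has gradient $-\sigma(a)$. That gradient vanishes when a strong discriminator pushes $D(x)$ toward 0. The practical term $\log D$ has gradient $1 - \sigma(a)$, which stays near 1 in the same regime. Both losses share the same chain rule from $a$ back to the generator parameters, so comparing them at the logit is enough.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def saturating_grad(logit):
    """d/d(logit) of log(1 - D(x)), D = sigmoid(logit): the term G minimizes in the original loss."""
    return -sigmoid(logit)


def non_saturating_grad(logit):
    """d/d(logit) of log D(x): the term G maximizes in the practical loss."""
    return 1.0 - sigmoid(logit)


def numeric_grad(f, a, h=1e-5):
    """Central finite difference, used to double-check the closed forms."""
    return (f(a + h) - f(a - h)) / (2 * h)


# A strong discriminator assigns generated samples a very negative logit (D(x) close to 0)
strong_d_logit = -8.0
weak_signal = saturating_grad(strong_d_logit)        # magnitude sigmoid(-8), about 3e-4
strong_signal = non_saturating_grad(strong_d_logit)  # about 1
```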
Generative Adversarial Nets (GANs)
• Learning
  • Aim to achieve equilibrium of the game
  • Optimal state:
$$p_{g_\theta}(x) = p_{data}(x), \qquad D(x) = \frac{p_{data}(x)}{p_{data}(x) + p_{g_\theta}(x)} = \frac{1}{2}$$
[Figure courtesy: Kim’s slides]
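The optimal discriminator can be verified on a finite sample space, using two hypothetical distributions. For fixed $p_{g_\theta}$, $\mathcal{L}_D$ decomposes over $x$. Each term $p_{data}(x)\log D + p_{g_\theta}(x)\log(1 - D)$ is maximized at $D = p_{data}/(p_{data} + p_{g_\theta})$. When $p_{g_\theta} = p_{data}$, this gives $1/2$.

```python
import math


def discriminator_objective(D, p_data, p_g):
    """L_D = sum_x p_data(x) log D(x) + p_g(x) log(1 - D(x)) over a finite sample space."""
    return sum(pd * math.log(d) + pg * math.log(1 - d) for d, pd, pg in zip(D, p_data, p_g))


def optimal_discriminator(p_data, p_g):
    return [pd / (pd + pg) for pd, pg in zip(p_data, p_g)]


p_data = [0.5, 0.3, 0.2]   # hypothetical distributions on three outcomes
p_g = [0.2, 0.2, 0.6]
D_star = optimal_discriminator(p_data, p_g)
```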
GANs: example results
[Figure: generated bedrooms [Radford et al., 2016]]
Alchemy vs. Modern Chemistry
Outline • Overview of advances in deep generative models • Backgrounds of deep generative models • Wake sleep algorithm • Variational autoencoders • Generative adversarial networks
• A unified view of deep generative models • New formulations of deep generative models • Symmetric modeling of latent and visible variables
Z. Hu, Z. Yang, R. Salakhutdinov, E. Xing, "On Unifying Deep Generative Models", arXiv:1706.00550
A unified view of deep generative models
• The literature has viewed these DGM approaches as distinct model training paradigms
  • GANs: achieve an equilibrium between generator and discriminator
  • VAEs: maximize the lower bound of the data likelihood
• Let's study a new formulation for DGMs
  • Connects GANs, VAEs, and other variants under a unified view
  • Links them back to inference and learning of Graphical Models, and the wake-sleep heuristic that approximates this
  • Provides a tool to analyze many GAN-/VAE-based algorithms
  • Encourages mutual exchange of ideas from each individual class of models
Adversarial domain adaptation (ADA)
• Let’s start from ADA
  • The application of the adversarial approach to domain adaptation
• We then show GANs can be seen as a special case of ADA
• Correspondence of elements (GANs / ADA):
  • $x$: data/generation / features
  • $z$: code vector / data from source/target domains
  • $y$: real/fake indicator / source/target domain indicator
[Figure: GANs and ADA]
Adversarial domain adaptation (ADA)
• Data $z$ from two domains indicated by $y \in \{0, 1\}$
  • Source domain ($y = 1$)
  • Target domain ($y = 0$)
• ADA transfers prediction knowledge learned from the source domain to the target domain
  • Learn a feature extractor $G_\theta$: $x = G_\theta(z)$
  • Wants $x$ to be indistinguishable by a domain discriminator $D_\phi(x)$
• Application in classification
  • E.g., we have labels of the source domain data
  • Train a classifier over $x$ of source domain data to predict the labels
  • $x$ is domain invariant ⇒ $x$ is predictive for target domain data
ADA: conventional formulation
• Train $D_\phi$ to distinguish between domains, i.e., maximize the binary classification accuracy of recognizing the feature domains:
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{x = G_\theta(z), z \sim p(z \mid y = 1)}\big[\log D_\phi(x)\big] + \mathbb{E}_{x = G_\theta(z), z \sim p(z \mid y = 0)}\big[\log(1 - D_\phi(x))\big]$$
• Train $G_\theta$ to fool $D_\phi$:
$$\max_\theta \mathcal{L}_\theta = \mathbb{E}_{x = G_\theta(z), z \sim p(z \mid y = 1)}\big[\log(1 - D_\phi(x))\big] + \mathbb{E}_{x = G_\theta(z), z \sim p(z \mid y = 0)}\big[\log D_\phi(x)\big]$$
• Here we omit the additional loss on $\theta$ that fits the features to the data-label pairs of the source domain
ADA: new formulation
• To reveal the connections to conventional variational approaches, let's rewrite the objectives in a format that resembles variational EM
• Implicit distribution over $x \sim p_\theta(x \mid y)$: $x = G_\theta(z)$, $z \sim p(z \mid y)$
• Discriminator distribution $q_\phi(y \mid x)$
  • $q_\phi^r(y \mid x) = q_\phi(1 - y \mid x)$
• Rewrite the objective in the new form (up to a constant scale factor)
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{p_\theta(x \mid y)p(y)}\big[\log q_\phi(y \mid x)\big]$$
$$\max_\theta \mathcal{L}_\theta = \mathbb{E}_{p_\theta(x \mid y)p(y)}\big[\log q_\phi^r(y \mid x)\big]$$
  • $z$ is encapsulated in the implicit distribution $p_\theta(x \mid y)$
  • E.g., with a uniform prior $p(y)$:
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{p_\theta(x \mid y = 0)p(y = 0)}\big[\log q_\phi(y = 0 \mid x)\big] + \mathbb{E}_{p_\theta(x \mid y = 1)p(y = 1)}\big[\log q_\phi(y = 1 \mid x)\big] = \frac{1}{2}\mathbb{E}_{x = G_\theta(z), z \sim p(z \mid y = 0)}\big[\log(1 - D_\phi(x))\big] + \frac{1}{2}\mathbb{E}_{x = G_\theta(z), z \sim p(z \mid y = 1)}\big[\log D_\phi(x)\big]$$
  • (Ignore the constant scale factor 1/2)
ADA: new formulation
• New formulation
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{p_\theta(x \mid y)p(y)}\big[\log q_\phi(y \mid x)\big]$$
$$\max_\theta \mathcal{L}_\theta = \mathbb{E}_{p_\theta(x \mid y)p(y)}\big[\log q_\phi^r(y \mid x)\big]$$
• The only difference between the objectives of $\theta$ and $\phi$: $q_\phi$ vs. $q_\phi^r$
  • This is where the adversarial mechanism comes about
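The new ADA formulation can be evaluated directly from discriminator outputs. In this sketch, `d_source` and `d_target` are hypothetical values of $D_\phi(x) = q_\phi(y = 1 \mid x)$ on features of source ($y = 1$) and target ($y = 0$) samples. The prior $p(y)$ is assumed uniform. The only change between $\mathcal{L}_\phi$ and $\mathcal{L}_\theta$ is the swap from $q_\phi$ to $q_\phi^r$, and the result matches the conventional formulation term by term, scaled by $1/2$.

```python
import math


def q_phi(y, d):
    """Discriminator distribution q_phi(y | x), where d = D_phi(x) = q_phi(y=1 | x)."""
    return d if y == 1 else 1.0 - d


def q_phi_reversed(y, d):
    """Reversed distribution q^r_phi(y | x) = q_phi(1 - y | x)."""
    return q_phi(1 - y, d)


def ada_objectives(d_source, d_target):
    """Monte Carlo estimates of L_phi and L_theta with a uniform prior p(y) = 1/2.

    d_source: D_phi(x) on features of source samples (y = 1)
    d_target: D_phi(x) on features of target samples (y = 0)
    """
    def avg(values):
        return sum(values) / len(values)

    L_phi = 0.5 * avg([math.log(q_phi(1, d)) for d in d_source]) \
        + 0.5 * avg([math.log(q_phi(0, d)) for d in d_target])
    L_theta = 0.5 * avg([math.log(q_phi_reversed(1, d)) for d in d_source]) \
        + 0.5 * avg([math.log(q_phi_reversed(0, d)) for d in d_target])
    return L_phi, L_theta


d_source = [0.9, 0.8, 0.7]   # hypothetical discriminator outputs
d_target = [0.2, 0.3, 0.1]
L_phi, L_theta = ada_objectives(d_source, d_target)
```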
ADA vs. Variational EM
Variational EM
• Objectives:
$$\max_\phi \mathcal{L}_{\phi, \theta} = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)$$
$$\max_\theta \mathcal{L}_{\phi, \theta} = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)$$
• Single objective for both $\theta$ and $\phi$
• Extra prior regularization by $p(z)$
ADA
• Objectives:
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{p_\theta(x \mid y)p(y)}\big[\log q_\phi(y \mid x)\big]$$
$$\max_\theta \mathcal{L}_\theta = \mathbb{E}_{p_\theta(x \mid y)p(y)}\big[\log q_\phi^r(y \mid x)\big]$$
• Two objectives
• Have a global optimal state in the game-theoretic view
ADA vs. Variational EM
Variational EM
• The reconstruction term: maximize the conditional log-likelihood of $x$ with the generative distribution $p_\theta(x \mid z)$ conditioning on the latent code $z$ inferred by $q_\phi(z \mid x)$
• $p_\theta(x \mid z)$ is the generative model
• $q_\phi(z \mid x)$ is the inference model
ADA
• The objectives: maximize the conditional log-likelihood of $y$ (or $1 - y$) with the distribution $q_\phi(y \mid x)$ conditioning on the latent feature $x$ inferred by $p_\theta(x \mid y)$
• Interpret $q_\phi(y \mid x)$ as the generative model
• Interpret $p_\theta(x \mid y)$ as the inference model
We frame our new interpretation of ADA, and review conventional formu materials. To make clear notational correspondence to other models in th a figure drawing a graphical model here for ADA.] let z be a data e or target domain, and y 2 {0, 1} be the domain indicator with y = 0 in and y = 1 the source domain. The data distributions conditioning on th as p(z|y). Let p(y) be the prior distribution (e.g., uniform) of the dom Define: extractor maps z to representations x = G✓ (z) with parameters ✓. The d deterministic • Solid-line arrows (X → `): transformation G✓ together form an implicit distribution which is intractable to evaluate likelihood but easy to sample from: • generative process
ADA: graphical model
To enforce • Dashed-line arrows (y, z domain → X): invariance of feature x, a discriminator is trained
between the two domains, which defines a conditional distribution q (y r the feature extractor is optimized to fool the discriminator. Let q (y|x) = • Hollow arrows (z → X): over domains. The objectives of ADA are therefore given as • deterministic distribution transformation • inference
• leading to implicit distributions
• Blue arrows (X → `): • adversarial mechanism ∏ • involves both hì (`|ç) and hì (`|ç)
max L = Ep✓ (x|y)p(y) [log q (y|x)] ⇥ ⇤ r max✓ L✓ = Ep✓ (x|y)p(y) log q (y|x) ,
where we omit the additional loss of ✓ to fit to the data label pairs of sour for more details). In conventional view, the first equation minimizes the © Petuum,Inc. 126 entropy with respect to discriminative parameter , while the second
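The two objectives in Eq.(1) can be estimated from samples once the discriminator's outputs are known. The sketch below assumes a uniform $p(y)$ and takes the values $q_\phi(y=1|x)$ on features from each domain as plain Python floats; the function name `ada_objectives` is hypothetical. It makes explicit that the two objectives differ only by swapping $q_\phi(y|x)$ for $q^r_\phi(y|x) = q_\phi(1-y|x)$.

```python
import math


def mean(values):
    return sum(values) / len(values)


def ada_objectives(q1_target, q1_source):
    """Monte Carlo estimates of L_phi and L_theta in Eq.(1), uniform p(y).

    q1_target: q_phi(y=1|x) for features x = G_theta(z), z ~ p(z|y=0)
    q1_source: q_phi(y=1|x) for features x = G_theta(z), z ~ p(z|y=1)
    """
    # Discriminator: log-likelihood of the true label.
    # y = 0 (target) uses log(1 - q1); y = 1 (source) uses log(q1).
    L_phi = 0.5 * mean([math.log(1.0 - q) for q in q1_target]) \
        + 0.5 * mean([math.log(q) for q in q1_source])
    # Feature extractor: same expectation under the reversed q^r(y|x) = q(1-y|x),
    # i.e. each domain is scored with the opposite label.
    L_theta = 0.5 * mean([math.log(q) for q in q1_target]) \
        + 0.5 * mean([math.log(1.0 - q) for q in q1_source])
    return L_phi, L_theta
```

For example, a maximally confused discriminator that outputs 0.5 everywhere gives both objectives the value $\log(1/2)$.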
GANs: a variant of ADA
• Transfer the properties of source domain to target domain
  • Source domain: e.g. real image, $y = 1$
  • Target domain: e.g. generated image, $y = 0$
[Figure: graphical models of ADA and GANs]
GANs: a variant of ADA
• Implicit distribution over $x \sim p_\theta(x|y)$:
$$p_\theta(x|y) = \begin{cases} p_{g_\theta}(x) & y = 0 \quad \text{(distribution of generated images)} \\ p_{data}(x) & y = 1 \quad \text{(distribution of real images)} \end{cases} \quad (5)$$
  • $x \sim p_{g_\theta}(x) \iff x = G_\theta(z),\ z \sim p(z|y=0)$
  • $x \sim p_{data}(x)$: the code space of $z$ is degenerated; sample directly from data

For the generated sample domain ($y = 0$), the implicit distribution $p_\theta(x|y=0)$ is defined by the prior of $z$ and the generator $G_\theta(z)$, which is also denoted as $p_{g_\theta}(x)$ in the literature. For the real example domain ($y = 1$), the code space and generator are degenerated, and we are directly presented with a distribution $p(x|y=1)$, which is just the real data distribution $p_{data}(x)$. Note that $p_{data}(x)$ is an implicit distribution allowing efficient empirical sampling. In summary, the distribution over $x$ is constructed as in Eq.(5). The parameters $\theta$ are only associated with $p_{g_\theta}(x)$ of the generated sample domain, while $p_{data}(x)$ is constant. As in ADA, the discriminator $D_\phi$ is simultaneously trained to infer the probability that $x$ comes from the real data domain. That is, $q_\phi(y=1|x) = D_\phi(x)$. With the established correspondence between GANs and ADA, we can see that the objectives of GANs are precisely expressed as Eq.(4) and as the graphical model in Figure 1(c).
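The piecewise definition in Eq.(5) can be read as a sampling procedure. The sketch below makes two assumptions the slides do not fix: the prior $p(z|y=0)$ is a standard normal, and the real data is a finite list of examples. The function `sample_x_given_y` and the toy generators passed to it are hypothetical.

```python
import random


def sample_x_given_y(y, generator, data, rng):
    """Draw one x ~ p_theta(x|y) from the implicit distribution of a GAN.

    y = 0: x = G_theta(z) with z ~ p(z|y=0), assumed here to be N(0, 1).
    y = 1: the code space is degenerated; x is drawn directly from the data.
    """
    if y == 0:
        z = rng.gauss(0.0, 1.0)   # assumed prior p(z|y=0) = N(0, 1)
        return generator(z)       # deterministic transformation G_theta
    return rng.choice(data)       # empirical sample from p_data(x)
```

Only the $y = 0$ branch touches the generator, which is why $\theta$ is associated with $p_{g_\theta}(x)$ alone.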
GANs: new formulation
• Recap: conventional formulation:
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{x=G_\theta(z),\,z\sim p(z|y=0)}\left[\log(1-D_\phi(x))\right] + \mathbb{E}_{x\sim p_{data}(x)}\left[\log D_\phi(x)\right]$$
$$\max_\theta \mathcal{L}_\theta = \mathbb{E}_{x=G_\theta(z),\,z\sim p(z|y=0)}\left[\log D_\phi(x)\right] + \mathbb{E}_{x\sim p_{data}(x)}\left[\log(1-D_\phi(x))\right] = \mathbb{E}_{x=G_\theta(z),\,z\sim p(z|y=0)}\left[\log D_\phi(x)\right]$$
  (the $p_{data}$ term does not depend on $\theta$)
• Rewrite in the new form:
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q_\phi(y|x)\right]$$
$$\max_\theta \mathcal{L}_\theta = \mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q^r_\phi(y|x)\right]$$
  • Exactly the same as ADA!
  • The same correspondence to variational EM!

Expanding the new-form discriminator objective gives
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{p_\theta(x|y=0)p(y=0)}\left[\log q_\phi(y=0|x)\right] + \mathbb{E}_{p_\theta(x|y=1)p(y=1)}\left[\log q_\phi(y=1|x)\right]$$
$$= \frac{1}{2}\mathbb{E}_{x=G_\theta(z),\,z\sim p(z|y=0)}\left[\log(1-D_\phi(x))\right] + \frac{1}{2}\mathbb{E}_{x=G_\theta(z),\,z\sim p(z|y=1)}\left[\log D_\phi(x)\right],$$
where the prior $p(y)$ is uniform as is widely set, resulting in the constant scale factor $1/2$. Note that here the generator is trained using the unsaturated objective [16], which is commonly used in practice.
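The remark about the unsaturated objective can be checked with a one-line derivative. Write the discriminator output on a generated sample as $D_\phi(x) = \sigma(a)$ for a logit $a$; this parameterization is an assumption. The minimax generator loss $\log(1 - D)$ then has derivative $-\sigma(a)$ with respect to $a$. The unsaturated objective $\log D$, which is what $\max_\theta \mathcal{L}_\theta$ above contains, has derivative $1 - \sigma(a)$. The sketch below (function names hypothetical) evaluates both.

```python
import math


def sigmoid(a):
    """Logistic function, used here as D_phi(x) = sigmoid(logit)."""
    return 1.0 / (1.0 + math.exp(-a))


def generator_logit_grads(logit):
    """Derivatives w.r.t. the discriminator logit a of a generated sample.

    Returns (d/da log(1 - D), d/da log D) with D = sigmoid(a):
      d/da log(1 - sigmoid(a)) = -sigmoid(a)
      d/da log(sigmoid(a))     = 1 - sigmoid(a)
    """
    d = sigmoid(logit)
    return -d, 1.0 - d


# Early in training the discriminator confidently rejects fakes (a << 0):
saturating, unsaturated = generator_logit_grads(-10.0)
# saturating is about -4.5e-5 (vanishing), unsaturated is close to 1.
```

The $\log(1 - D)$ form therefore stalls exactly when the generator is poor, which is the usual motivation for the unsaturated variant.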
GANs vs. Variational EM
Variational EM
• Objectives
$$\max_\eta \mathcal{L}_{\theta,\eta} = \mathbb{E}_{q_\eta(z|x)}\left[\log p_\theta(x|z)\right] - \mathrm{KL}\left(q_\eta(z|x)\,\|\,p(z)\right)$$
$$\max_\theta \mathcal{L}_{\theta,\eta} = \mathbb{E}_{q_\eta(z|x)}\left[\log p_\theta(x|z)\right] - \mathrm{KL}\left(q_\eta(z|x)\,\|\,p(z)\right)$$
• Single objective for both $\theta$ and $\eta$
• Extra prior regularization by $p(z)$
• The reconstruction term: maximize the conditional log-likelihood of $x$ with the generative distribution $p_\theta(x|z)$ conditioning on the latent code $z$ inferred by $q_\eta(z|x)$
• $p_\theta(x|z)$ is the generative model
• $q_\eta(z|x)$ is the inference model

GAN
• Objectives
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q_\phi(y|x)\right]$$
$$\max_\theta \mathcal{L}_\theta = \mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q^r_\phi(y|x)\right]$$
• Two objectives
• Have global optimal state in the game theoretic view
• The objectives: maximize the conditional log-likelihood of $y$ (or $1-y$) with the distribution $q_\phi(y|x)$ conditioning on the data/generation $x$ inferred by $p_\theta(x|y)$
• Interpret $x$ as latent variables
• Interpret generation of $x$ as performing inference over latent
• Interpret $q_\phi(y|x)$ as the generative model
• Interpret $p_\theta(x|y)$ as the inference model
GANs: minimizing KLD
• As in Variational EM, we can further rewrite in the form of minimizing KLD to reveal more insights into the optimization problem
• For each optimization step of $p_\theta(x|y)$ at point $(\theta = \theta_0, \phi = \phi_0)$, let
  • $p(y)$: uniform prior distribution
  • $p_{\theta=\theta_0}(x) = \mathbb{E}_{p(y)}\left[p_{\theta=\theta_0}(x|y)\right]$
  • $q^r(x|y) \propto q^r_{\phi=\phi_0}(y|x)\, p_{\theta=\theta_0}(x)$
• Lemma 1: The updates of $\theta$ at $\theta_0$ have
$$\nabla_\theta \mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q^r_{\phi=\phi_0}(y|x)\right]\Big|_{\theta=\theta_0} = -\nabla_\theta\Big[\mathbb{E}_{p(y)}\left[\mathrm{KL}\left(p_\theta(x|y)\,\|\,q^r(x|y)\right)\right] - \mathrm{JSD}\left(p_\theta(x|y=0)\,\|\,p_\theta(x|y=1)\right)\Big]\Big|_{\theta=\theta_0} \quad (7)$$
  • KL: KL divergence
  • JSD: Jensen-Shannon divergence

We now take a closer look at the form of Eq.(4), which is essentially reconstructing the real/fake indicator $y$ (or its reverse $1-y$) conditioned on $x$. For each optimization step of $p_\theta(x|y)$ at point $(\theta_0, \phi_0)$ in the parameter space, Lemma 1 holds, where $\mathrm{KL}(\cdot\|\cdot)$ and $\mathrm{JSD}(\cdot\|\cdot)$ are the KL and Jensen-Shannon divergences, respectively.
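To make $q^r(x|y) \propto q^r_{\phi_0}(y|x)\, p_{\theta_0}(x)$ concrete, the sketch below evaluates it on a finite grid of $x$ values. The grid is an assumption made only so that the normalizer is a plain sum. The discriminator is given by its values $q_{\phi_0}(y=1|x)$ at the grid points, and the function name is hypothetical. With a discriminator that outputs 1/2 everywhere, $q^r(x|y)$ collapses to the prior $p_{\theta_0}(x)$.

```python
def reversed_posterior(q1_given_x, p_theta0, y):
    """q^r(x|y) proportional to q^r_{phi0}(y|x) * p_{theta0}(x) on a finite grid.

    q1_given_x[i]: q_{phi0}(y=1 | x_i)
    p_theta0[i]:   p_{theta0}(x_i) = (p_{g,theta0}(x_i) + p_data(x_i)) / 2
    The reversed distribution is q^r(y|x) = q(1 - y|x), so
    q^r(y=0|x) = q(y=1|x) and q^r(y=1|x) = q(y=0|x).
    """
    qr = [q1 if y == 0 else 1.0 - q1 for q1 in q1_given_x]
    unnormalized = [a * b for a, b in zip(qr, p_theta0)]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


# A maximally confused discriminator leaves the prior unchanged.
prior = [0.2, 0.3, 0.5]
posterior = reversed_posterior([0.5, 0.5, 0.5], prior, y=0)
```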
We then obtain the conventional formulation of adversarial domain adaptation used or similar in [3, 4, 5, 2].
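Proof of Lemma 1
Proof.
$$\mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q^r(y|x)\right] = -\mathbb{E}_{p(y)}\left[\mathrm{KL}\left(p_\theta(x|y)\,\|\,q^r(x|y)\right) - \mathrm{KL}\left(p_\theta(x|y)\,\|\,p_{\theta_0}(x)\right)\right] + \text{const}, \quad (3)$$
where const does not depend on $\theta$ (it comes from the normalizer of $q^r(x|y)$), and
$$\mathbb{E}_{p(y)}\left[\mathrm{KL}\left(p_\theta(x|y)\,\|\,p_{\theta_0}(x)\right)\right] = p(y=0)\cdot\mathrm{KL}\left(p_\theta(x|y=0)\,\Big\|\,\frac{p_{\theta_0}(x|y=0)+p_{\theta_0}(x|y=1)}{2}\right) + p(y=1)\cdot\mathrm{KL}\left(p_\theta(x|y=1)\,\Big\|\,\frac{p_{\theta_0}(x|y=0)+p_{\theta_0}(x|y=1)}{2}\right). \quad (4)$$
Note that $p_\theta(x|y=0) = p_{g_\theta}(x)$ and $p_\theta(x|y=1) = p_{data}(x)$. Let $p_{M_\theta} = \frac{p_{g_\theta}+p_{data}}{2}$. Eq.(4) can be simplified as:
$$\mathbb{E}_{p(y)}\left[\mathrm{KL}\left(p_\theta(x|y)\,\|\,p_{\theta_0}(x)\right)\right] = \frac{1}{2}\mathrm{KL}\left(p_{g_\theta}\,\|\,p_{M_{\theta_0}}\right) + \frac{1}{2}\mathrm{KL}\left(p_{data}\,\|\,p_{M_{\theta_0}}\right). \quad (5)$$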
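Proof of Lemma 1 (cont.)
On the other hand,
$$\begin{aligned}
\mathrm{JSD}(p_{g_\theta}\|p_{data}) &= \frac{1}{2}\mathbb{E}_{p_{g_\theta}}\left[\log\frac{p_{g_\theta}}{p_{M_\theta}}\right] + \frac{1}{2}\mathbb{E}_{p_{data}}\left[\log\frac{p_{data}}{p_{M_\theta}}\right] \\
&= \frac{1}{2}\mathbb{E}_{p_{g_\theta}}\left[\log\frac{p_{g_\theta}}{p_{M_{\theta_0}}}\right] + \frac{1}{2}\mathbb{E}_{p_{g_\theta}}\left[\log\frac{p_{M_{\theta_0}}}{p_{M_\theta}}\right] + \frac{1}{2}\mathbb{E}_{p_{data}}\left[\log\frac{p_{data}}{p_{M_{\theta_0}}}\right] + \frac{1}{2}\mathbb{E}_{p_{data}}\left[\log\frac{p_{M_{\theta_0}}}{p_{M_\theta}}\right] \\
&= \frac{1}{2}\mathbb{E}_{p_{g_\theta}}\left[\log\frac{p_{g_\theta}}{p_{M_{\theta_0}}}\right] + \frac{1}{2}\mathbb{E}_{p_{data}}\left[\log\frac{p_{data}}{p_{M_{\theta_0}}}\right] + \mathbb{E}_{p_{M_\theta}}\left[\log\frac{p_{M_{\theta_0}}}{p_{M_\theta}}\right] \\
&= \frac{1}{2}\mathrm{KL}\left(p_{g_\theta}\,\|\,p_{M_{\theta_0}}\right) + \frac{1}{2}\mathrm{KL}\left(p_{data}\,\|\,p_{M_{\theta_0}}\right) - \mathrm{KL}\left(p_{M_\theta}\,\|\,p_{M_{\theta_0}}\right). \quad (6)
\end{aligned}$$
Note that
$$\nabla_\theta \mathrm{KL}\left(p_{M_\theta}\,\|\,p_{M_{\theta_0}}\right)\Big|_{\theta=\theta_0} = 0, \quad (7)$$
since this KL is nonnegative and attains its minimum value 0 at $\theta = \theta_0$.
Taking derivatives of Eq.(5) w.r.t. $\theta$ at $\theta_0$, and using Eq.(6) and Eq.(7), we get
$$\begin{aligned}
\nabla_\theta \mathbb{E}_{p(y)}\left[\mathrm{KL}\left(p_\theta(x|y)\,\|\,p_{\theta_0}(x)\right)\right]\Big|_{\theta=\theta_0} &= \nabla_\theta\left(\frac{1}{2}\mathrm{KL}\left(p_{g_\theta}\,\|\,p_{M_{\theta_0}}\right) + \frac{1}{2}\mathrm{KL}\left(p_{data}\,\|\,p_{M_{\theta_0}}\right)\right)\Big|_{\theta=\theta_0} \\
&= \nabla_\theta \mathrm{JSD}(p_{g_\theta}\|p_{data})\Big|_{\theta=\theta_0}. \quad (8)
\end{aligned}$$
Taking derivatives of both sides of Eq.(3) w.r.t. $\theta$ at $\theta_0$ and plugging in the last equation of Eq.(8), we obtain the desired result.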
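The decomposition in Eq.(6) holds for any reference point $\theta_0$, not only at $\theta = \theta_0$, and this is easy to confirm numerically on discrete distributions. The sketch below uses made-up three-point distributions, purely for illustration, and the hypothetical helper names `kl`, `jsd`, and `eq6_sides`.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0.0:
            if qi == 0.0:
                return math.inf
            total += pi * math.log(pi / qi)
    return total


def jsd(p, q):
    m = [(a + b) / 2 for a, b in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def eq6_sides(p_g, p_data, p_g0):
    """Left and right sides of Eq.(6) for generator p_g and reference p_g0."""
    m = [(a + b) / 2 for a, b in zip(p_g, p_data)]     # p_{M_theta}
    m0 = [(a + b) / 2 for a, b in zip(p_g0, p_data)]   # p_{M_theta0}
    lhs = jsd(p_g, p_data)
    rhs = 0.5 * kl(p_g, m0) + 0.5 * kl(p_data, m0) - kl(m, m0)
    return lhs, rhs


lhs, rhs = eq6_sides([0.6, 0.3, 0.1], [0.2, 0.3, 0.5], [0.4, 0.4, 0.2])
```

At $\theta = \theta_0$ the last KL term vanishes together with its gradient. That is the step that turns Eq.(5) into the JSD gradient in Eq.(8).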
GANs: minimizing KLD
• Lemma 1: the updates of $\theta$ at $\theta_0$ satisfy Eq.(7)
• Connection to variational inference
  • See $x$ as latent variables, $y$ as visible
  • $p_{\theta=\theta_0}(x)$: prior distribution
  • $q^r(x|y) \propto q^r_{\phi=\phi_0}(y|x)\, p_{\theta=\theta_0}(x)$: posterior distribution
  • $p_\theta(x|y)$: variational distribution
  • Amortized inference: updates model parameter $\theta$
• Suggests relations to VAEs, as we will explore shortly

Eq.(7) offers several insights into the generator learning in GANs.
GANs: minimizing KLD
• Lemma 1: the updates of $\theta$ at $\theta_0$ satisfy Eq.(7)
• Minimizing the KLD drives $p_{g_\theta}(x)$ to $p_{data}(x)$
  • By definition: $p_{\theta=\theta_0}(x) = \mathbb{E}_{p(y)}\left[p_{\theta=\theta_0}(x|y)\right] = \big(p_{g_{\theta=\theta_0}}(x) + p_{data}(x)\big)/2$
  • $\mathrm{KL}\left(p_\theta(x|y=1)\,\|\,q^r(x|y=1)\right) = \mathrm{KL}\left(p_{data}(x)\,\|\,q^r(x|y=1)\right)$: constant, no free parameters
  • $\mathrm{KL}\left(p_\theta(x|y=0)\,\|\,q^r(x|y=0)\right) = \mathrm{KL}\left(p_{g_\theta}(x)\,\|\,q^r(x|y=0)\right)$: parameter $\theta$ to optimize
  • $q^r(x|y=0) \propto q^r_{\phi=\phi_0}(y=0|x)\, p_{\theta=\theta_0}(x)$
    • seen as a mixture of $p_{g_{\theta=\theta_0}}(x)$ and $p_{data}(x)$
    • mixing weights induced from $q^r_{\phi=\phi_0}(y=0|x)$
  • Drives $p_\theta(x|y=0)$ to the mixture of $p_{g_{\theta=\theta_0}}(x)$ and $p_{data}(x)$ ⇒ drives $p_{g_\theta}(x)$ to $p_{data}(x)$
[Figure: $p_{\theta_0}(x|y=1) = p_{data}(x)$; $p_{\theta_0}(x|y=0) = p_{g_{\theta_0}}(x)$; $q^r(x|y=0)$; $p_{\theta_{new}}(x|y=0) = p_{g_{\theta_{new}}}(x)$]
GANs: minimizing KLD
• Lemma 1: the updates of $\theta$ at $\theta_0$ satisfy Eq.(7)
• Missing mode phenomena of GANs
  • Asymmetry of KLD
$$\mathrm{KL}\left(p_{g_\theta}(x)\,\|\,q^r(x|y=0)\right) = \int p_{g_\theta}(x)\log\frac{p_{g_\theta}(x)}{q^r(x|y=0)}\,dx$$
    • Concentrates $p_\theta(x|y=0)$ to large modes of $q^r(x|y=0)$ ⇒ $p_{g_\theta}(x)$ misses modes of $p_{data}(x)$
    • Large positive contribution to the KLD in the regions of $x$ space where $q^r(x|y=0)$ is small, unless $p_{g_\theta}(x)$ is also small
    • ⇒ $p_{g_\theta}(x)$ tends to avoid regions where $q^r(x|y=0)$ is small
  • Symmetry of JSD
    • Does not affect the behavior of mode missing
[Figure: $p_{\theta_0}(x|y=1) = p_{data}(x)$; $p_{\theta_0}(x|y=0) = p_{g_{\theta_0}}(x)$; $q^r(x|y=0)$; $p_{\theta_{new}}(x|y=0) = p_{g_{\theta_{new}}}(x)$; missed mode]

• Resemblance to variational inference. If we treat $y$ as visible and $x$ as latent (as in ADA), it is straightforward to see the connections to the variational inference algorithm, where $q^r(x|y)$ plays the role of the posterior, $p_{\theta_0}(x)$ the prior, and $p_\theta(x|y)$ the variational distribution that approximates the posterior. Optimizing the generator $G_\theta$ is equivalent to minimizing the KL divergence between the variational distribution and the posterior, minus a JSD between the distributions $p_{g_\theta}(x)$ and $p_{data}(x)$. The Bayesian interpretation further reveals the connections to VAEs, as we discuss in the next section.
• Training dynamics. By definition, $p_{\theta_0}(x) = \big(p_{g_{\theta_0}}(x) + p_{data}(x)\big)/2$ is a mixture of $p_{g_{\theta_0}}(x)$ and …
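The asymmetry argument can be seen on a four-point toy example (made-up numbers, hypothetical helper `kl`). The target $q^r(x|y=0)$ has two large modes at the ends and little mass in between. A generator that covers only one mode pays a moderate $\mathrm{KL}(p_{g_\theta}\|q^r)$, while one that spreads mass into the low-density middle pays more. The reverse direction, by contrast, is infinite for the one-mode generator.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions; infinite if p puts mass where q has none."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0.0:
            if qi == 0.0:
                return math.inf
            total += pi * math.log(pi / qi)
    return total


q_target = [0.49, 0.01, 0.01, 0.49]   # two large modes, low density in the middle
p_one_mode = [0.98, 0.01, 0.01, 0.0]  # misses the right-hand mode
p_spread = [0.25, 0.25, 0.25, 0.25]   # puts mass where q_target is small

forward_one_mode = kl(p_one_mode, q_target)   # 0.98 * log 2, about 0.68
forward_spread = kl(p_spread, q_target)       # about 1.27
reverse_one_mode = kl(q_target, p_one_mode)   # inf: q_target has mass where p_one_mode has none
```

$\mathrm{KL}(p_{g_\theta}\|q^r)$ heavily penalizes mass that $p_{g_\theta}$ places where $q^r$ is small. It barely penalizes mass that $p_{g_\theta}$ fails to place, so dropping a mode is cheap.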
GANs: minimizing KLD
• Lemma 1: the updates of $\theta$ at $\theta_0$ satisfy Eq.(7)
• No assumption on optimal discriminator $q_\phi(y|x)$
  • Previous results usually rely on (near) optimal discriminator
    • $q^*(y=1|x) = p_{data}(x)/\big(p_{data}(x) + p_{g_\theta}(x)\big)$
  • Optimality assumption is impractical: limited expressiveness of $D_\phi$ [Arora et al 2017]
  • Our result is a generalization of the previous theorem [Arjovsky & Bottou 2017]
    • Plug the optimal discriminator into the above equation, we recover the theorem
$$\nabla_\theta\left[\mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q^r_{\phi_0}(y|x)\right]\right]\Big|_{\theta=\theta_0} = -\nabla_\theta\left[\frac{1}{2}\mathrm{KL}\left(p_{g_\theta}\,\|\,p_{data}\right) - \mathrm{JSD}\left(p_{g_\theta}\,\|\,p_{data}\right)\right]\Big|_{\theta=\theta_0} \quad (8)$$
    • Gives insights on the generator training when the discriminator is optimal

The previous theorem [1] gives simplified explanations of the training dynamics and the missing mode issue only when the discriminator meets certain optimality criteria. Our generalized result enables understanding of broader situations. For instance, when the discriminator distribution $q_{\phi_0}(y|x)$ gives uniform guesses, the gradients of the KL and JSD terms in Eq.(7) cancel out, disabling the learning of the generator.
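Plugging in the optimal discriminator can also be checked numerically. Assume discrete $p_{g_{\theta_0}}$ and $p_{data}$ with no common zeros. Then $q^*(y=1|x) = p_{data}/(p_{data} + p_{g_{\theta_0}})$ and $p_{\theta_0} = (p_{g_{\theta_0}} + p_{data})/2$ give $q^r(x|y=0) = p_{data}$ exactly. The KL term of Lemma 1 therefore becomes $\mathrm{KL}(p_{g_\theta}\|p_{data})$, the $\frac{1}{2}\mathrm{KL}$ term of Eq.(8). The function name and the toy distributions are hypothetical.

```python
def reversed_posterior_optimal(p_g0, p_data):
    """q^r(x|y=0) under the optimal discriminator, on a finite grid."""
    # Optimal discriminator: q*(y=1|x) = p_data(x) / (p_data(x) + p_g0(x)).
    d_star = [pd / (pd + pg) for pg, pd in zip(p_g0, p_data)]
    # Mixture p_theta0(x) = (p_g0(x) + p_data(x)) / 2.
    p_theta0 = [(pg + pd) / 2 for pg, pd in zip(p_g0, p_data)]
    # q^r(y=0|x) = q(y=1|x) = D*(x); multiply by the prior and normalize.
    unnormalized = [d * p for d, p in zip(d_star, p_theta0)]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


p_g0 = [0.7, 0.2, 0.1]
p_data = [0.1, 0.3, 0.6]
q_r = reversed_posterior_optimal(p_g0, p_data)
```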
GANs: minimizing KLD In summary: • Reveal connection to variational inference • Build connections to VAEs (slides soon) • Inspire new model variants based on the connections
• Offer insights into the generator training • Formal explanation of the missing mode behavior of GANs • Still hold when the discriminator does not achieve its optimum at each iteration
$$\mathrm{KL}\left(p_\theta(y|x)\,\|\,q(y|x)\right), \quad (14)$$
or
$$\mathrm{KL}\left(p_\theta(x|y)\,\|\,q(x|y)\right). \quad (15)$$
We can see GANs and VAEs (Variational Auto-encoders, Kingma & Welling (2013)) as extending the sleep and wake phases, respectively. In particular, VAEs extend the wake phase by minimizing Eq.(12) w.r.t. both $\phi$ and $\theta$. GANs extend the sleep phase by minimizing Eq.(15) w.r.t. $\phi$, and minimizing the $y$-switched objective $\mathrm{KL}\left(p_\theta(x|y)\,\|\,q^r(x|y)\right) - \mathrm{JSD}$ of Eq.(7) w.r.t. $\theta$.

Variant of GAN: InfoGAN
• GANs don't offer the functionality of inferring code $z$ given data $x$
• InfoGAN [Chen et al., 2016]
  • Introduce inference model $Q_\eta(z|x)$ with parameters $\eta$
  • Augment the objectives of GANs by additionally inferring $z$
$$\max_D \mathcal{L}_D = \mathbb{E}_{x\sim p_{data}(x)}\left[\log D(x)\right] + \mathbb{E}_{x\sim G(z),\,z\sim p(z)}\left[\log(1-D(x))\right],$$
$$\max_{G,Q} \mathcal{L}_{G,Q} = \mathbb{E}_{x\sim G(z),\,z\sim p(z)}\left[\log D(x) + \log Q(z|x)\right]. \quad (16)$$
[Figure: graphical models of GANs and InfoGAN]
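Here is a minimal sketch of Monte Carlo estimates of the two objectives in Eq.(16). The slides leave $Q(z|x)$ unspecified. For illustration only, it is assumed to be a one-dimensional Gaussian $\mathcal{N}(\mu(x), \sigma^2)$ whose means on the generated samples are given. The function names are hypothetical.

```python
import math


def gaussian_log_pdf(z, mu, sigma):
    """log N(z; mu, sigma^2) for scalar z."""
    return -0.5 * math.log(2.0 * math.pi * sigma ** 2) - (z - mu) ** 2 / (2.0 * sigma ** 2)


def mean(values):
    return sum(values) / len(values)


def infogan_objectives(d_real, d_fake, codes, q_means, q_sigma=1.0):
    """Estimates of L_D and L_{G,Q} from Eq.(16).

    d_real:  D(x) on real examples x ~ p_data(x)
    d_fake:  D(x) on generated examples x = G(z), z ~ p(z)
    codes:   the codes z that produced each generated example
    q_means: mean of the assumed Gaussian Q(z|x) for each generated example
    """
    L_D = mean([math.log(d) for d in d_real]) + mean([math.log(1.0 - d) for d in d_fake])
    L_GQ = mean([math.log(d) + gaussian_log_pdf(z, m, q_sigma)
                 for d, z, m in zip(d_fake, codes, q_means)])
    return L_D, L_GQ
```

The extra $\log Q(z|x)$ term rewards the generator (and $Q$) when the code that produced a sample can be recovered from it.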
Moreover, the Bayesian interpretation of our result enables us to discover connections to VAEs, as we discuss in the next section.

InfoGAN: new formulation
• Defines conditional $q_\eta(z|x,y)$
• $q_\eta(z|x,y=1)$ is fixed without free parameters to learn
  • As GANs assume the code space of real data is degenerated
• Parameters $\eta$ are only associated with $q_\eta(z|x,y=0)$
• Rewrite in the new form:
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q_\eta(z|x,y)\,q_\phi(y|x)\right]$$
$$\max_{\theta,\eta} \mathcal{L}_{\theta,\eta} = \mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q_\eta(z|x,y)\,q^r_\phi(y|x)\right] \quad (5)$$

Chen et al. [6] developed InfoGAN for disentangled representation learning, which additionally recovers (part of) the latent code $z$ given example $x$. This can be straightforwardly formulated in our framework by introducing an extra conditional $q_\eta(z|x,y)$ parameterized by $\eta$. As discussed above, GANs assume a degenerated code space for real examples, thus $q_\eta(z|x,y=1)$ is fixed without free parameters to learn, and $\eta$ is only associated with $y = 0$. InfoGAN is then recovered by combining $q_\eta(z|x,y)$ with $q_\phi(y|x)$ in Eq.(1) to perform full reconstruction of both $z$ and $y$, as in Eq.(5), where the ground-truth $z$ to reconstruct is sampled from the prior $p(z|y)$ and encapsulated in the implicit distribution $p_\theta(x|y)$. Let $q^r(x|z,y) \propto q_{\eta_0}(z|x,y)\,q^r_{\phi_0}(y|x)\,p_{\theta_0}(x)$; then the result of Lemma 1 still holds by replacing $q^r_{\phi_0}(y|x)$ with $q_{\eta_0}(z|x,y)\,q^r_{\phi_0}(y|x)$, and $q^r(x|y)$ with $q^r(x|z,y)$.
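GANs vs InfoGAN
• GANs:
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q_\phi(y|x)\right], \qquad \max_\theta \mathcal{L}_\theta = \mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q^r_\phi(y|x)\right] \quad (1)$$
• InfoGAN:
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q_\eta(z|x,y)\,q_\phi(y|x)\right], \qquad \max_{\theta,\eta} \mathcal{L}_{\theta,\eta} = \mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q_\eta(z|x,y)\,q^r_\phi(y|x)\right]$$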
InfoGAN: new formulation
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q_\eta(z|x,y)\,q_\phi(y|x)\right]$$
$$\max_{\theta,\eta} \mathcal{L}_{\theta,\eta} = \mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q_\eta(z|x,y)\,q^r_\phi(y|x)\right] \quad (9)$$
• Similar results as in GANs hold:
  • Let $q^r(x|z,y) \propto q_{\eta_0}(z|x,y)\,q^r_{\phi_0}(y|x)\,p_{\theta_0}(x)$
  • We have:
$$\nabla_\theta \mathbb{E}_{p_\theta(x|y)p(y)}\left[\log q_{\eta_0}(z|x,y)\,q^r_{\phi_0}(y|x)\right]\Big|_{\theta=\theta_0} = -\nabla_\theta\Big[\mathbb{E}_{p(y)}\left[\mathrm{KL}\left(p_\theta(x|y)\,\|\,q^r(x|z,y)\right)\right] - \mathrm{JSD}\left(p_\theta(x|y=0)\,\|\,p_\theta(x|y=1)\right)\Big]\Big|_{\theta=\theta_0} \quad (10)$$
• Next we show correspondences between GANs/InfoGAN and VAEs

The model is expressed as the graphical model in Figure 3(a).

AAE/PM/CycleGAN. As a side result, the idea of interpreting data space $x$ as latent immediately reveals relations between InfoGAN and the Adversarial Autoencoder (AAE) [35] and Predictability Minimization (PM) [50]. That is, InfoGAN is precisely an AAE that treats the data space $x$ as latent and adversarially regularized, and the code space $z$ as visible. Figure 3(c) shows the graphical model of AAE obtained by simply swapping $x$ and $z$ in InfoGAN. We defer the detailed formulations of AAE to the supplementary materials. Further, instead of considering $x$ and $z$ as data …
Relates VAEs with GANs
• Resemblance of GAN generator learning to variational inference
• Suggest strong relations between VAEs and GANs
• Indeed, VAEs are basically minimizing KLD with an opposite direction, and with a degenerated adversarial discriminator
[Figure: InfoGAN vs. VAEs; VAEs swap the generation (solid-line) and inference (dashed-line) processes of InfoGAN, with a degenerated discriminator]
The resemblance of GAN generator learning to variational inference as shown in Eq.(4) suggests strong relations between VAEs [25] and GANs. We build correspondence between the two approaches, and show that VAEs are basically minimizing a KL divergence with an opposite direction, with a degenerated adversarial discriminator.

Recap: conventional formulation of VAEs
• Objective:
$$\max_{\theta,\eta} \mathcal{L}^{vae}_{\theta,\eta} = \mathbb{E}_{p_{data}(x)}\left[\mathbb{E}_{\tilde{q}_\eta(z|x)}\left[\log \tilde{p}_\theta(x|z)\right] - \mathrm{KL}\left(\tilde{q}_\eta(z|x)\,\|\,\tilde{p}(z)\right)\right]$$
• $\tilde{p}(z)$: prior over $z$
• $\tilde{p}_\theta(x|z)$: generative model
• $\tilde{q}_\eta(z|x)$: inference model
• Only uses real examples from $p_{data}(x)$, lacks adversarial mechanism
• To align with GANs, let's introduce the real/fake indicator $y$ and adversarial discriminator

The conventional definition of VAEs is written as above, where $\tilde{p}_\theta(x|z)$ is the generator, $\tilde{q}_\eta(z|x)$ the inference network, and $\tilde{p}(z)$ the prior over $z$. The parameters to learn are intentionally denoted with the notations of corresponding modules in GANs. VAEs appear, at first glance, to differ from GANs greatly as they use only real examples and lack an adversarial mechanism. However, our interpretation shows VAEs indeed include a degenerated adversarial discriminator that blocks out generated samples from contributing to training. Specifically, we again introduce the real/fake variable $y$, and assume a perfect discriminator $q_*$ which always predicts $y = 1$ with probability 1 given real examples, and $y = 0$ given generated samples. Again, for notational simplicity, let $q^r_*(y|x) = q_*(1-y|x)$ be the reversed distribution.
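Consider the common special case in which $\tilde{q}_\eta(z|x)$ is a diagonal Gaussian and $\tilde{p}(z) = \mathcal{N}(0, I)$; both are assumptions, since the slides do not fix these choices. The KL term then has a closed form, and the reconstruction term can be estimated with the reparameterization $z = \mu + \sigma \cdot \epsilon$. The sketch below (hypothetical function names, plain Python lists for vectors) computes a one-sample estimate of the objective for a single example $x$.

```python
import math


def kl_diag_gaussian_std_normal(mu, log_var):
    """KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, log_var))


def elbo_one_sample(x, mu, log_var, eps, log_lik):
    """One-sample estimate of E_q[log p(x|z)] - KL(q(z|x) || p(z)).

    mu, log_var: encoder outputs for x; eps: a draw from N(0, I);
    log_lik(x, z): decoder log-likelihood log p_theta(x|z).
    """
    z = [m + math.exp(0.5 * lv) * e for m, lv, e in zip(mu, log_var, eps)]
    return log_lik(x, z) - kl_diag_gaussian_std_normal(mu, log_var)


def squared_error_log_lik(x, z):
    # Toy Gaussian decoder with mean z and variance 1/2 per dimension,
    # dropping the additive constant.
    return -sum((a - b) ** 2 for a, b in zip(x, z))
```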
VAEs: new formulation
• Assume a perfect discriminator $q_*(y|x)$
  • $q_*(y=1|x) = 1$ if $x$ is a real example
  • $q_*(y=0|x) = 1$ if $x$ is a generated sample
  • $q^r_*(y|x) := q_*(1-y|x)$
• Generative distribution
$$p_\theta(x|z,y) = \begin{cases} p_\theta(x|z) & y = 0 \\ p_{data}(x) & y = 1 \end{cases} \quad (9)$$
• Let $p_\theta(z,y|x) \propto p_\theta(x|z,y)\,p(z|y)\,p(y)$
• Lemma 2
$$\begin{aligned}
\mathcal{L}^{vae}_{\theta,\eta} &= 2\cdot\mathbb{E}_{p_{\theta_0}(x)}\left[\mathbb{E}_{q_\eta(z|x,y)q^r_*(y|x)}\left[\log p_\theta(x|z,y)\right] - \mathrm{KL}\left(q_\eta(z|x,y)\,q^r_*(y|x)\,\|\,p(z|y)\,p(y)\right)\right] \\
&= 2\cdot\mathbb{E}_{p_{\theta_0}(x)}\left[-\mathrm{KL}\left(q_\eta(z|x,y)\,q^r_*(y|x)\,\|\,p_\theta(z,y|x)\right)\right]. \quad (8)
\end{aligned}$$

Here most of the components have exact correspondences (and the same definitions) in GANs and InfoGAN (Table 1), except that the generation distribution $p_\theta(x|z,y)$ differs slightly from its counterpart $p_\theta(x|y)$ in Eq.(2) to additionally account for the uncertainty of generating $x$ given $z$, as in Eq.(9). The resulting KL divergence closely relates to that in GANs (Eq.4) and InfoGAN (Eq.6), with the generative module $p_\theta(x|z,y)$ and inference networks $q_\eta(z|x,y)\,q^r(y|x)$ placed in the opposite directions, and with inverted hidden/visible treatments of $(z, y)$ and $x$. In section 6, we give a general discussion that the difference between GANs and VAEs in hidden/visible treatments is relatively minor. The proof is provided in the supplementary materials. Intuitively, recall that for the real example …

Figure 3: (a) Graphical model of InfoGAN (Eq.10), which, compared to GANs (Figure 1(c)), adds conditional generation of code $z$ with distribution $q_\eta(z|x,y)$. See the captions of Figure 1 for the meaning of different types of arrows. (b) VAEs (Eq.13), which is obtained by swapping the generation and inference processes of InfoGAN, i.e., in terms of the graphical model, swapping solid-line arrows (generative process) and dashed-line arrows (inference) of (a). (c) Adversarial Autoencoder (AAE), which is obtained by swapping data $x$ and code $z$ in InfoGAN (see the supplements for more details).
Lemma 2: sketch of proof
• Lemma 2 (Eq.(8))
• Proof
  1) Expand $\mathbb{E}_{p_{\theta_0}(x)}[\cdot] = \frac{1}{2}\mathbb{E}_{p_{\theta_0}(x|y=1)}[\cdot] + \frac{1}{2}\mathbb{E}_{p_{\theta_0}(x|y=0)}[\cdot]$
  2) $\frac{1}{2}\mathbb{E}_{p_{\theta_0}(x|y=0)}[\cdot]$ is constant
    • Due to the perfect discriminator $q^r_*(y|x)$
    • Blocks out generated samples in the training loss
  3) $\frac{1}{2}\mathbb{E}_{p_{\theta_0}(x|y=1)}[\cdot] = \frac{1}{2}\mathbb{E}_{p_{data}(x)}[\cdot]$
    • Recovers the conventional formulation
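C Proof of Lemma 2
Proof. For the reconstruction term:
$$\begin{aligned}
\mathbb{E}_{p_{\theta_0}(x)}\Big[\mathbb{E}_{q_\eta(z|x,y)q^r_*(y|x)}[\log p_\theta(x|z,y)]\Big]
&= \tfrac{1}{2}\mathbb{E}_{p_{\theta_0}(x|y=1)}\Big[\mathbb{E}_{q_\eta(z|x,y=0),\,y=0\sim q^r_*(y|x)}[\log p_\theta(x|z,y=0)]\Big] \\
&\quad + \tfrac{1}{2}\mathbb{E}_{p_{\theta_0}(x|y=0)}\Big[\mathbb{E}_{q_\eta(z|x,y=1),\,y=1\sim q^r_*(y|x)}[\log p_\theta(x|z,y=1)]\Big] \\
&= \tfrac{1}{2}\mathbb{E}_{p_{data}(x)}\Big[\mathbb{E}_{\tilde q_\eta(z|x)}[\log \tilde p_\theta(x|z)]\Big] + \text{const},
\end{aligned} \tag{25}$$
where $y=0\sim q^r_*(y|x)$ means $q^r_*(y|x)$ predicts $y=0$ with probability 1. Note that both $q_\eta(z|x,y=1)$ and $p_\theta(x|z,y=1)$ are constant distributions without free parameters to learn; $q_\eta(z|x,y=0)=\tilde q_\eta(z|x)$, and $p_\theta(x|z,y=0)=\tilde p_\theta(x|z)$.
For the KL prior regularization term:
$$\begin{aligned}
\mathbb{E}_{p_{\theta_0}(x)}\Big[\mathrm{KL}\big(q_\eta(z|x,y)q^r_*(y|x)\,\|\,p(z|y)p(y)\big)\Big]
&= \mathbb{E}_{p_{\theta_0}(x)}\Big[\int q^r_*(y|x)\,\mathrm{KL}\big(q_\eta(z|x,y)\,\|\,p(z|y)\big)\,dy + \mathrm{KL}\big(q^r_*(y|x)\,\|\,p(y)\big)\Big] \\
&= \tfrac{1}{2}\mathbb{E}_{p_{\theta_0}(x|y=1)}\Big[\mathrm{KL}\big(q_\eta(z|x,y=0)\,\|\,p(z|y=0)\big) + \text{const}\Big] + \tfrac{1}{2}\mathbb{E}_{p_{\theta_0}(x|y=0)}[\text{const}] \\
&= \tfrac{1}{2}\mathbb{E}_{p_{data}(x)}\Big[\mathrm{KL}\big(\tilde q_\eta(z|x)\,\|\,\tilde p(z)\big)\Big] + \text{const}.
\end{aligned} \tag{26}$$
Combining Eq.(25) and Eq.(26), we recover the conventional VAE objective in Eq.(7) in the paper.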
GANs vs VAEs side by side
Here $x$ denotes a real example or a generated sample, and $z$ is the respective latent code. For the generated sample domain ($y = 0$), the implicit distribution $p_\theta(x|y=0)$ is defined by the prior of $z$ and the generator $G_\theta(z)$. In the literature it is also denoted as $p_{g_\theta}(x)$. For the real example domain ($y = 1$), the code space and generator are degenerated. We are then directly presented with a fixed distribution $p(x|y=1)$, which is just the real data distribution $p_{data}(x)$. Note that $p_{data}(x)$ is also an implicit distribution allowing efficient empirical sampling. In summary, the distribution over $x$ is constructed as
$$p_\theta(x|y) = \begin{cases} p_{g_\theta}(x) & y = 0 \\ p_{data}(x) & y = 1. \end{cases} \tag{5}$$
Free parameters $\theta$ are only associated with $p_{g_\theta}(x)$ of the generated sample domain, while $p_{data}(x)$ is constant. As in ADA, the discriminator $D_\phi$ is simultaneously trained to infer the probability that $x$ comes from the real data domain. That is, $q_\phi(y=1|x) = D_\phi(x)$.
Most of the components of VAEs have exact correspondences (and the same definitions) in GANs and InfoGAN (Table 1). The exception is the generation distribution $p_\theta(x|z,y)$, which differs slightly from its GAN counterpart $p_\theta(x|y)$ to additionally account for the uncertainty of generating $x$ given $z$:
$$p_\theta(x|z,y) = \begin{cases} \tilde p_\theta(x|z) & y = 0 \\ p_{data}(x) & y = 1. \end{cases} \tag{9}$$
• Generative distribution
  • GANs (InfoGAN): $p_\theta(x|y)$, i.e., $p_{g_\theta}(x)$ for $y = 0$ and $p_{data}(x)$ for $y = 1$
  • VAEs: $p_\theta(x|z,y)$, i.e., $\tilde p_\theta(x|z)$ for $y = 0$ and $p_{data}(x)$ for $y = 1$
• Discriminator distribution
  • GANs (InfoGAN): $q_\phi(y|x)$
  • VAEs: $q_*(y|x)$, perfect, degenerated
• Latent code inference
  • GANs (InfoGAN): $q_\eta(z|x,y)$
  • VAEs: $q_\eta(z|x,y)$
• KLD to minimize
  • GANs (InfoGAN): $\min_\theta \mathrm{KL}\big(p_\theta(x|y)\,\|\,q^r(x|z,y)\big) \sim \min_\theta \mathrm{KL}(P_\theta\,\|\,Q)$
  • VAEs: $\min_\theta \mathrm{KL}\big(q_\eta(z|x,y)q^r_*(y|x)\,\|\,p_\theta(z,y|x)\big) \sim \min_\theta \mathrm{KL}(Q\,\|\,P_\theta)$
With the established correspondence between GANs and ADA, we can see that the objectives of GANs are precisely expressed as Eq.(4) and as the graphical model in Figure 1(c).
Link back to wake sleep algorithm
• Denote
  • Latent variables $h$
  • Parameters $\lambda$
• Recap: wake sleep algorithm
Wake: $\max_\theta \mathbb{E}_{q_\lambda(h|x)p_{data}(x)}[\log p_\theta(x|h)]$
Sleep: $\max_\lambda \mathbb{E}_{p_\theta(x|h)p(h)}[\log q_\lambda(h|x)]$ (10)
Wake Sleep Algorithm (WS). We next discuss the connections of GANs and VAEs to the classic wake-sleep algorithm [18], which was proposed for learning deep generative models such as Helmholtz machines [9]. WS consists of a wake phase and a sleep phase, which optimize the generative model and the inference model, respectively. We follow the above notations, and introduce new notations: $h$ denotes general latent variables and $\lambda$ denotes general parameters. The wake-sleep algorithm is thus written as Eq.(10).
Relations between VAEs and WS are clear in previous discussions [3, 25]. Indeed, WS was originally proposed to minimize the variational lower bound as in VAEs (Eq.7) with sleep phase approximation [18]. Alternatively, VAEs can be seen as extending the wake phase. Specifically, if we instantiate $h$ with $z$ and $\lambda$ with $\eta$, the wake phase objective recovers VAEs (Eq.7) in terms of generator optimization (i.e., optimizing $\theta$). Therefore, we can see VAEs as generalizing the wake phase by also optimizing the inference model $q_\eta$, with additional prior regularization on latents $z$.
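The two phases of Eq.(10) can be sketched on a tiny model. This is a hypothetical setup: the latent $h$ is discrete with a fixed uniform prior, and both $p_\theta(x|h)$ and $q_\lambda(h|x)$ are tables of softmax logits. The sketch also departs from the usual algorithm in one deliberate way: instead of sampling, each phase takes an exact expected-gradient step on its own objective, which keeps the toy deterministic.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def wake_sleep(p_data, n_latent=2, steps=5000, lr=0.5, seed=0):
    """Wake-sleep on a hypothetical tabular model: discrete h with a fixed uniform
    prior p(h), generator p_theta(x|h) and recognition model q_lambda(h|x), both
    parameterized by logits. Exact expected gradients replace the usual sampling."""
    rng = np.random.default_rng(seed)
    n_x = p_data.size
    prior = np.full(n_latent, 1.0 / n_latent)          # p(h), fixed
    G = 0.1 * rng.standard_normal((n_latent, n_x))     # theta: logits of p_theta(x|h)
    R = 0.1 * rng.standard_normal((n_x, n_latent))     # lambda: logits of q_lambda(h|x)
    for _ in range(steps):
        P = softmax(G, axis=1)                         # P[h, x] = p_theta(x|h)
        Q = softmax(R, axis=1)                         # Q[x, h] = q_lambda(h|x)
        # Wake: ascend E_{q_lambda(h|x) p_data(x)}[log p_theta(x|h)] w.r.t. theta.
        W = p_data[:, None] * Q                        # W[x, h] = p_data(x) q_lambda(h|x)
        G += lr * (W.T - W.sum(axis=0)[:, None] * P)
        # Sleep: ascend E_{p_theta(x|h) p(h)}[log q_lambda(h|x)] w.r.t. lambda.
        P = softmax(G, axis=1)
        J = (prior[:, None] * P).T                     # J[x, h] = p(h) p_theta(x|h)
        R += lr * (J - J.sum(axis=1, keepdims=True) * Q)
    p_model = prior @ softmax(G, axis=1)               # model marginal p_theta(x)
    return p_model, softmax(R, axis=1)

p_data = np.array([0.45, 0.45, 0.05, 0.05])            # hypothetical data distribution
p_model, q = wake_sleep(p_data)
print(np.round(p_model, 3))
```

Note that the wake step only touches $\theta$ and the sleep step only touches $\lambda$. Each phase fits its own module against samples drawn from the other one.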
… graphical model of AAE obtained by simply swapping $x$ and $z$ in InfoGAN. We defer the detailed formulations of AAE to the supplementary materials. Further, instead of considering $x$ and $z$ as data and code respectively, we can instantiate $x$ and $z$ as data spaces of two modalities. If we then combine the objectives of InfoGAN and AAE as a joint model, we recover the cycleGAN model [56], which performs transformation between the two modalities. More details are provided in the supplements.
3.3 Variational Autoencoders (VAEs)
We next explore the second family of deep generative model learning algorithms. The resemblance of GAN generator learning to variational inference, as shown earlier, suggests strong relations between VAEs [28] and GANs. We build a correspondence between the two approaches, and show that VAEs are basically minimizing a KL divergence in the opposite direction, with a degenerated adversarial discriminator.
VAEs vs. Wake-sleep
• Wake sleep algorithm
Wake: $\max_\theta \mathbb{E}_{q_\lambda(h|x)p_{data}(x)}[\log p_\theta(x|h)]$
Sleep: $\max_\lambda \mathbb{E}_{p_\theta(x|h)p(h)}[\log q_\lambda(h|x)]$
• Let $h$ be $z$, and $\lambda$ be $\eta$
⇒ $\max_\theta \mathbb{E}_{q_\eta(z|x)p_{data}(x)}[\log p_\theta(x|z)]$, which recovers the VAE objective of optimizing $\theta$
• VAEs extend the wake phase by also learning the inference model ($\eta$)
$$\max_{\theta,\eta}\mathcal{L}^{vae}_{\theta,\eta} = \mathbb{E}_{q_\eta(z|x)p_{data}(x)}[\log p_\theta(x|z)] - \mathbb{E}_{p_{data}(x)}\big[\mathrm{KL}\big(q_\eta(z|x)\,\|\,p(z)\big)\big]$$
  • Minimize the KLD in the original variational free energy wrt. $\eta$
  • Stick to minimizing the wake-phase KLD wrt. both $\theta$ and $\eta$
  • Do not involve the sleep-phase objective
• Recall: the sleep phase minimizes the reverse KLD in the variational free energy
The definition of the conventional VAEs is written as:
$$\max_{\theta,\eta}\mathcal{L}^{vae}_{\theta,\eta} = \mathbb{E}_{p_{data}(x)}\Big[\mathbb{E}_{\tilde q_\eta(z|x)}[\log \tilde p_\theta(x|z)] - \mathrm{KL}\big(\tilde q_\eta(z|x)\,\|\,\tilde p(z)\big)\Big],$$
where $\tilde p_\theta(x|z)$ is the generator, $\tilde q_\eta(z|x)$ the inference model, and $\tilde p(z)$ the prior over $z$. The parameters to learn are intentionally denoted with the notations of the corresponding modules in GANs. At first glance, VAEs appear to differ greatly from GANs, as they use only real examples and lack the adversarial mechanism. However, our interpretation shows VAEs indeed include a degenerated …
GANs vs. Wake-sleep
• Wake sleep algorithm
Wake: $\max_\theta \mathbb{E}_{q_\lambda(h|x)p_{data}(x)}[\log p_\theta(x|h)]$
Sleep: $\max_\lambda \mathbb{E}_{p_\theta(x|h)p(h)}[\log q_\lambda(h|x)]$
• Let $h$ be $y$, and $\lambda$ be $\phi$
⇒ $\max_\phi \mathbb{E}_{p_\theta(x|y)p(y)}[\log q_\phi(y|x)]$, which recovers the GAN objective of optimizing $\phi$
• GANs extend the sleep phase by also learning the generative model ($\theta$)
  • Directly extending the sleep phase: $\max_\theta \mathbb{E}_{p_\theta(x|y)p(y)}[\log q_\phi(y|x)]$
  • GANs: $\max_\theta \mathbb{E}_{p_\theta(x|y)p(y)}[\log q^r_\phi(y|x)]$
  • The only difference is replacing $q_\phi$ with $q^r_\phi$
  • This is where the adversarial mechanism comes about!
• GANs stick to minimizing the sleep-phase KLD
  • Do not involve the wake-phase objective
On the other hand, our interpretation of GANs reveals a close resemblance to the sleep phase. To make this clearer, we instantiate $h$ with $y$ and $\lambda$ with $\phi$, resulting in a sleep phase objective identical …
Recall the conventional ADA formulation. The discriminator is trained with
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{x=G_\theta(z),\,z\sim p(z|y=1)}[\log D_\phi(x)] + \mathbb{E}_{x=G_\theta(z),\,z\sim p(z|y=0)}[\log(1-D_\phi(x))],$$
and the feature extractor $G_\theta$ is then trained to fool the discriminator:
$$\max_\theta \mathcal{L}_\theta = \mathbb{E}_{x=G_\theta(z),\,z\sim p(z|y=1)}[\log(1-D_\phi(x))] + \mathbb{E}_{x=G_\theta(z),\,z\sim p(z|y=0)}[\log D_\phi(x)].$$
Here we omit the additional loss on $\theta$ that fits the features to the data-label pairs of the source domain (see the supplementary materials for details). With the background of the conventional formulation, we now frame our new interpretation of ADA. The data distribution $p(z|y)$ and the deterministic transformation $G_\theta$ together form an implicit distribution over $x$, denoted as $p_\theta(x|y)$. Its likelihood is intractable to evaluate, but it is easy to sample from. Let $p(y)$ be the prior distribution of the domain indicator $y$, e.g., a uniform distribution. The discriminator defines a conditional distribution $q_\phi(y=1|x) = D_\phi(x)$. Let $q^r_\phi(y|x) = q_\phi(1-y|x)$ be the reversed distribution over domains. The objectives of ADA are therefore rewritten as (up to a constant scale factor 2):
$$\max_\phi \mathcal{L}_\phi = \mathbb{E}_{p_\theta(x|y)p(y)}[\log q_\phi(y|x)], \qquad \max_\theta \mathcal{L}_\theta = \mathbb{E}_{p_\theta(x|y)p(y)}[\log q^r_\phi(y|x)].$$
The above objectives can be interpreted as maximizing the log likelihood of $y$ (or $1-y$) with the "generative distribution" $q_\phi(y|x)$ conditioning on the latent code $x$ inferred by $p_\theta(x|y)$. Note that the only (but critical) difference between the objectives of $\theta$ and $\phi$ is the replacement of $q(y|x)$ with $q^r(y|x)$. This is where the adversarial mechanism comes about.
Graphical model representation. Figure 1(c) illustrates the graphical model of the formulation in Eq.(4). In the new view, solid-line arrows denote the generative process, while dashed-line arrows denote the inference process. We introduce new visual elements, e.g., hollow arrows for expressing implicit distributions, and blue arrows for the adversarial mechanism. As noted above, adversarial modeling is achieved by swapping between $q(y|x)$ and $q^r(y|x)$ when training the respective modules.
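The "only difference" between directly extending the sleep phase and the GAN generator update is a label flip inside the same discriminator. A minimal sketch follows, assuming $d = D_\phi(x) = q_\phi(y=1|x)$ and a hypothetical minibatch of discriminator outputs on generated samples, which belong to domain $y = 0$.

```python
import numpy as np

def log_q(y, d):
    """log q_phi(y|x) for a binary domain indicator y, with d = D_phi(x) = q_phi(y=1|x)."""
    return np.where(y == 1, np.log(d), np.log1p(-d))

def log_q_reversed(y, d):
    """log q^r_phi(y|x) = log q_phi(1 - y|x): the same discriminator with the label flipped."""
    return log_q(1 - y, d)

# Hypothetical discriminator outputs on a minibatch of generated samples (domain y = 0).
d_fake = np.array([0.2, 0.5, 0.8])
y_fake = np.zeros_like(d_fake, dtype=int)

sleep_obj = float(log_q(y_fake, d_fake).mean())          # extended sleep phase: mean log(1 - D)
gan_obj = float(log_q_reversed(y_fake, d_fake).mean())   # GAN generator: mean log D
print(sleep_obj, gan_obj)
```

Maximizing `sleep_obj` with respect to the generator would favor samples that the discriminator already rejects. Swapping in $q^r_\phi$ turns the same likelihood into the adversarial update.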
Mutual exchanges of ideas: augment the loss functions
KLD to minimize
• GANs (InfoGAN): $\min_\theta \mathrm{KL}\big(p_\theta(x|y)\,\|\,q^r(x|z,y)\big) \sim \min_\theta \mathrm{KL}(P_\theta\,\|\,Q)$
• VAEs: $\min_\theta \mathrm{KL}\big(q_\eta(z|x,y)q^r_*(y|x)\,\|\,p_\theta(z,y|x)\big) \sim \min_\theta \mathrm{KL}(Q\,\|\,P_\theta)$
• Asymmetry of KLDs inspires combination of GANs and VAEs
  • GANs: $\min_\theta \mathrm{KL}(P_\theta\|Q)$ tends to miss modes
  • VAEs: $\min_\theta \mathrm{KL}(Q\|P_\theta)$ tends to cover regions with small values of $p_{data}$
• Augment VAEs with GAN loss [Larsen et al., 2016]
  • Alleviate the mode covering issue of VAEs
  • Improve the sharpness of VAE generated images
• Augment GANs with VAE loss [Che et al., 2017]
  • Alleviate the mode missing issue of GANs
[Figure courtesy: PRML. (a) Mode covering; (b), (c) Mode missing]
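The asymmetry can be seen numerically in the textbook picture (as in the PRML figure), not in either model's actual objective. In the sketch below, a single Gaussian $q$ is fit to a hypothetical bimodal density $p$ by grid search, once under each KL direction. $\mathrm{KL}(p\|q)$, with the model in the second slot as in the VAE column, spreads $q$ over both modes. $\mathrm{KL}(q\|p)$, with the model in the first slot as in the GAN column, locks $q$ onto a single mode. The grid ranges and the mixture are assumptions chosen for illustration.

```python
import numpy as np

def log_normal(x, mu, sigma):
    """Log density of N(mu, sigma^2) evaluated at x."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

# Integration grid and a hypothetical bimodal "data" density p = 0.5 N(-3, 1) + 0.5 N(3, 1).
xs = np.linspace(-12.0, 12.0, 2001)
dx = xs[1] - xs[0]
log_p = np.log(0.5) + np.logaddexp(log_normal(xs, -3.0, 1.0), log_normal(xs, 3.0, 1.0))

def kl(log_a, log_b):
    """KL(a || b) approximated by a Riemann sum on the grid."""
    return float(np.sum(np.exp(log_a) * (log_a - log_b)) * dx)

def fit_gaussian(direction):
    """Grid-search the single Gaussian q = N(mu, sigma^2) that minimizes
    KL(p || q) ("forward", model in the second slot) or KL(q || p) ("reverse")."""
    best = (np.inf, None, None)
    for mu in np.linspace(-5.0, 5.0, 41):
        for sigma in np.linspace(0.5, 5.0, 46):
            log_q = log_normal(xs, mu, sigma)
            d = kl(log_p, log_q) if direction == "forward" else kl(log_q, log_p)
            if d < best[0]:
                best = (d, float(mu), float(sigma))
    return best

for direction in ("forward", "reverse"):
    d, mu, sigma = fit_gaussian(direction)
    print(f"{direction}: KL={d:.3f}, mu={mu:.2f}, sigma={sigma:.2f}")
```

The forward fit lands near the moment-matched $\mu = 0$, $\sigma = \sqrt{10}$, placing mass between the modes where $p$ is small. The reverse fit sits on one mode with $\sigma \approx 1$.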
Mutual exchanges of ideas: augment the graphical model
Discriminator distribution
• GANs (InfoGAN): $q_\phi(y|x)$
• VAEs: $q_*(y|x)$, perfect, degenerated
• Activate the adversarial mechanism in VAEs
  • Enable adaptive incorporation of fake samples for learning
  • Straightforward derivation by making a symbolic analog to GANs!
[Figure: graphical models of Vanilla VAEs and Adversary Activated VAEs]
Adversary Activated VAEs (AAVAE)
• Vanilla VAEs:
$$\max_{\theta,\eta}\mathcal{L}^{vae}_{\theta,\eta} = \mathbb{E}_{p_{\theta_0}(x)}\Big[\mathbb{E}_{q_\eta(z|x,y)q^r_*(y|x)}[\log p_\theta(x|z,y)] - \mathrm{KL}\big(q_\eta(z|x,y)q^r_*(y|x)\,\|\,p(z|y)p(y)\big)\Big]$$
• Replace $q_*(y|x)$ with a learnable one $q_\phi(y|x)$ with parameters $\phi$
• As usual, denote the reversed distribution $q^r_\phi(y|x) = q_\phi(1-y|x)$
• AAVAE:
$$\max_{\theta,\eta}\mathcal{L}^{aavae}_{\theta,\eta} = \mathbb{E}_{p_{\theta_0}(x)}\Big[\mathbb{E}_{q_\eta(z|x,y)q^r_\phi(y|x)}[\log p_\theta(x|z,y)] - \mathrm{KL}\big(q_\eta(z|x,y)q^r_\phi(y|x)\,\|\,p(z|y)p(y)\big)\Big]$$
4.2 Adversary Activated VAEs (AAVAE)
In our formulation, VAEs include a degenerated adversarial discriminator which blocks out generated samples from contributing to model learning. We enable adaptive incorporation of fake samples by activating the adversarial mechanism. Again, derivations are straightforward by making a symbolic analog to GANs.
VAE/GAN Joint Models. Previous works have explored the combination of VAEs and GANs. This can be naturally motivated by the asymmetric behaviors of the KL divergences that the two algorithms aim to optimize, respectively. Specifically, the VAE/GAN model [32], which improves the sharpness of VAE generated images, can alternatively be motivated by remedying the mode covering behavior of the KL in VAEs. That is, the KL tends to drive the generative model to cover all modes of the data distribution as well as regions with small values of $p_{data}$, resulting in blurred, implausible samples.
AAVAE: adaptive data selection
We replace the perfect discriminator $q_*(y|x)$ in vanilla VAEs with the discriminator network $q_\phi(y|x)$ parameterized with $\phi$ as in GANs, resulting in an adapted version of the vanilla VAE objective:
$$\max_{\theta,\eta}\mathcal{L}^{aavae}_{\theta,\eta} = \mathbb{E}_{p_{\theta_0}(x)}\Big[\mathbb{E}_{q_\eta(z|x,y)q^r_\phi(y|x)}[\log p_\theta(x|z,y)] - \mathrm{KL}\big(q_\eta(z|x,y)q^r_\phi(y|x)\,\|\,p(z|y)p(y)\big)\Big]. \tag{16}$$
The form of Eq.(16) is precisely symmetric to the objective of InfoGAN, with the additional KL prior regularization.
• An effective data selection mechanism:
  • Both generated samples and real examples are weighted by $q^r_\phi(y=0|x) = q_\phi(y=1|x)$
  • Only samples that resemble real data and fool the discriminator will be used for training
  • A real example receiving a large weight $q^r_\phi(y|x)$
    ⇒ Easily recognized by the discriminator as real
    ⇒ Hard to be simulated by the generator
    ⇒ Hard examples get larger weights
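AAVAE: discriminator learning
Before analyzing the effect of adding the learnable discriminator, we first look at how the discriminator is learned. In analogy to GANs and InfoGAN, the objective of optimizing $\phi$ is obtained by simply replacing the inverted distribution $q^r_\phi(y|x)$ with $q_\phi(y|x)$:
$$\max_\phi \mathcal{L}^{aavae}_\phi = \mathbb{E}_{p_{\theta_0}(x)}\Big[\mathbb{E}_{q_\eta(z|x,y)q_\phi(y|x)}[\log p_\theta(x|z,y)] - \mathrm{KL}\big(q_\eta(z|x,y)q_\phi(y|x)\,\|\,p(z|y)p(y)\big)\Big]. \tag{17}$$
Intuitively, the discriminator is trained to distinguish between real and fake instances by predicting the appropriate $y$, which selects the components of $q_\eta(z|x,y)$ and $p_\theta(x|z,y)$ that best reconstruct $x$. The difficulty of Eq.(17) is that $p_\theta(x|z,y=1) = p_{data}(x)$ is an implicit distribution, which is intractable for likelihood evaluation. We thus use the alternative objective as in GANs to train a binary classifier:
• Use the binary classification objective as in GANs
$$\max_\phi \mathcal{L}^{aavae}_\phi = \mathbb{E}_{p_\theta(x|z,y)p(z|y)p(y)}[\log q_\phi(y|x)].$$
The activated discriminator enables an effective data selection mechanism. First, AAVAE uses not only real examples, but also generated samples for training. Each sample is weighted by the inverted discriminator $q^r_\phi(y|x)$, so that only those samples that resemble real data and successfully fool the discriminator will be incorporated for training. This is consistent with the importance weighting strategy in IWGAN. Second, real examples are also weighted by $q^r_\phi(y|x)$. An example that receives a large weight is easily recognized by the discriminator, which further indicates that the example is hard to simulate with the generator. That is, AAVAE puts more emphasis on harder examples.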
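The weighting just described can be sketched as a per-batch surrogate for the $(\theta, \eta)$ objective. This is a minimal illustration under stated assumptions. The per-sample lower-bound values are hypothetical inputs, assumed to be computed elsewhere with the learnable $y = 0$ components ($\tilde q_\eta$, $\tilde p_\theta$). Terms without free parameters $\theta, \eta$ (the $y = 1$ branch and the KL over $y$) are dropped as constants. The batch is assumed to be half real and half generated, matching a uniform $p(y)$.

```python
import numpy as np

def aavae_weighted_elbo(elbo, d_prob):
    """Hypothetical per-batch surrogate of the AAVAE objective for (theta, eta).

    elbo[i]   : E_{q(z|x_i)}[log p(x_i|z)] - KL(q(z|x_i) || p(z)) for sample x_i,
                computed with the learnable y = 0 components (assumed given).
    d_prob[i] : D(x_i) = q_phi(y=1 | x_i), for real and generated samples alike.

    Each sample is weighted by q^r_phi(y=0|x_i) = q_phi(y=1|x_i) = D(x_i). The factor 2
    mirrors the 2 * E_{p_theta0(x)} scaling in Lemma 2 for a half-real, half-fake batch.
    """
    elbo = np.asarray(elbo, dtype=float)
    w = np.asarray(d_prob, dtype=float)
    return float(2.0 * np.mean(w * elbo))

# Two real examples followed by two generated samples (hypothetical values).
elbo = np.array([-90.0, -110.0, -300.0, -250.0])
perfect = np.array([1.0, 1.0, 0.0, 0.0])        # q_*: D = 1 on real, D = 0 on fake
learned = np.array([0.9, 0.6, 0.3, 0.1])        # a learned, imperfect discriminator
print(aavae_weighted_elbo(elbo, perfect), aavae_weighted_elbo(elbo, learned))
```

With the perfect discriminator, the surrogate reduces to the average lower bound over real examples, which is the vanilla VAE case. With a learned discriminator, generated samples that fool it start to contribute.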
AAVAE: empirical results
• Applied the adversary activating method on
  • vanilla VAEs
  • class-conditional VAEs (CVAE)
  • semi-supervised VAEs (SVAE)
AAVAE: empirical results
• Evaluated test-set variational lower bound on MNIST
  • The higher the better
  • X-axis: the ratio of training data for learning (0.01, 0.1, 1.)
  • Y-axis: value of test-set lower bound
Figure 1: Lower bound values on the MNIST test set. The X-axis represents the ratio of training data used for learning (0.01, 0.1, and 1.). The Y-axis represents the value of the lower bound. Solid lines represent the base models; dashed lines represent the adversary activated models. Left: VAE vs. AA-VAE. Middle: CVAE vs. AA-CVAE. Right: SVAE vs. AA-SVAE, where remaining training data are used …
AAVAE: empirical results
• Evaluated classification accuracy of SVAE and AA-SVAE
  • Used 1% and 10% data labels in MNIST
$$\nabla_\theta \mathcal{L}_k(y) = \mathbb{E}_{z_1,\dots,z_k}\Big[\sum_{i=1}^{k}\frac{q^r(y|x_i)}{q(y|x_i)}\,\nabla_\theta \log q^r\big(y|x(z_i,\theta)\big)\Big] \tag{34}$$
Experimental Results of SVAE. Table 3 shows the results.
Model | 1% labels | 10% labels
SVAE | 0.9412±.0039 | 0.9768±.0009
AASVAE | 0.9425±.0045 | 0.9797±.0010
Table 3: Classification accuracy of semi-supervised VAEs and the adversary activated extension on the MNIST test set, with varying size of real labeled training examples.
Mutual exchanges of ideas
• AAVAE enhances VAEs with ideas from GANs
• We can also enhance GANs with ideas from VAEs
  • VAEs maximize a variational lower bound of the log likelihood
  • Importance weighted VAE (IWAE) [Burda et al., 2016]
    • Maximizes a tighter lower bound through importance sampling
• The variational inference interpretation of GANs allows the importance weighting method to be applied straightforwardly to GANs
  • Just copy the derivations of IWAE side by side with little adaptation!
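The claim that importance sampling tightens the bound can be checked on a toy model where $\log p(x)$ is known in closed form. Assumptions for this sketch: $p(z) = \mathcal{N}(0, 1)$, $p(x|z) = \mathcal{N}(z, 1)$, so that $p(x) = \mathcal{N}(0, 2)$, and a deliberately crude proposal $q(z|x) = p(z)$. The $k$-sample bound $\mathcal{L}_k = \mathbb{E}\big[\log \frac{1}{k}\sum_i p(x, z_i)/q(z_i|x)\big]$ is estimated by Monte Carlo.

```python
import numpy as np

def iwae_bound(x, k, n_rep=20000, seed=0):
    """Monte Carlo estimate of the k-sample importance weighted bound
        L_k = E[ log (1/k) sum_i p(x, z_i) / q(z_i | x) ],   z_i ~ q(z | x),
    for the toy model p(z) = N(0, 1), p(x|z) = N(z, 1) and the crude
    proposal q(z | x) = N(0, 1) (an assumption of this sketch)."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((n_rep, k))                               # z_i ~ q(z|x)
    log_joint = -np.log(2 * np.pi) - 0.5 * z**2 - 0.5 * (x - z) ** 2  # log p(z) + log p(x|z)
    log_q = -0.5 * np.log(2 * np.pi) - 0.5 * z**2
    log_w = log_joint - log_q                                         # log importance weights
    m = log_w.max(axis=1, keepdims=True)                              # stable log-mean-exp over k
    log_mean_w = m[:, 0] + np.log(np.mean(np.exp(log_w - m), axis=1))
    return float(log_mean_w.mean())

x = 2.0
exact = -0.5 * np.log(2 * np.pi * 2.0) - x**2 / 4.0   # log p(x), since p(x) = N(0, 2)
for k in (1, 5, 50):
    print(f"k={k:3d}  L_k={iwae_bound(x, k):.3f}  log p(x)={exact:.3f}")
```

With $k = 1$, this reduces to the ordinary variational lower bound. As $k$ grows, the estimate climbs toward $\log p(x)$ without exceeding it.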
Importance weighted GANs (IWGAN)
• Generator learning in vanilla GANs
$$\max_\theta \mathbb{E}_{x\sim p_\theta(x|y)p(y)}\big[\log q^r_{\phi_0}(y|x)\big]$$
• Generator learning in IWGAN
$$\max_\theta \mathbb{E}_{x_1,\dots,x_k\sim p_\theta(x|y)p(y)}\Big[\sum_{i=1}^{k}\frac{q^r_{\phi_0}(y|x_i)}{q_{\phi_0}(y|x_i)}\log q^r_{\phi_0}(y|x_i)\Big]$$
  • Assigns higher weights to samples that are more realistic and fool the discriminator better
As in GANs, only $y = 0$ (i.e., generated samples) is effective for learning the parameters $\theta$. Intuitively, the algorithm assigns higher weights to those samples that are more realistic and fool the discriminator better. This is consistent with IWAE, which puts more emphasis on code states that provide better reconstructions. In practice, the $k$ samples in the IWGAN objective correspond to a minibatch of samples in a standard GAN update. Thus the only computational cost added by the importance weighting method is evaluating the weight for each sample, which is generally negligible. The discriminator is trained in the same way as in standard GANs.
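For generated samples ($y = 0$), the weight ratio has a simple form. Take $D(x) = q_\phi(y=1|x)$. Then $q^r_{\phi_0}(y=0|x) = D(x)$ and $q_{\phi_0}(y=0|x) = 1 - D(x)$, so the ratio is $D/(1-D)$, the exponential of the discriminator logit. The sketch below computes these weights for a minibatch. It *assumes* they are self-normalized over the $k$ samples, as in IWAE. The discriminator outputs are hypothetical.

```python
import numpy as np

def iwgan_weights(d_fake):
    """Importance weights for a minibatch of k generated samples (a sketch).

    d_fake[i] = D(x_i) = q_phi(y=1|x_i). For y = 0, the ratio
    q^r(y=0|x_i) / q(y=0|x_i) = D / (1 - D). Normalizing over the batch is an
    assumption here, following IWAE-style self-normalized weights."""
    d = np.clip(np.asarray(d_fake, dtype=float), 1e-7, 1 - 1e-7)
    log_w = np.log(d) - np.log1p(-d)       # log[D / (1 - D)], i.e. the discriminator logit
    log_w -= log_w.max()                   # stabilize before exponentiating
    w = np.exp(log_w)
    return w / w.sum()

def iwgan_generator_objective(d_fake):
    """Weighted generator objective sum_i w_i log q^r(y=0|x_i) = sum_i w_i log D(x_i).
    In an autodiff framework, the weights would be treated as constants (no gradient)."""
    d = np.clip(np.asarray(d_fake, dtype=float), 1e-7, 1 - 1e-7)
    return float(np.sum(iwgan_weights(d) * np.log(d)))

w = iwgan_weights([0.1, 0.5, 0.9])
print(np.round(w, 3), iwgan_generator_objective([0.1, 0.5, 0.9]))
```

Samples that the discriminator already scores as realistic dominate the update. This is the "fool the discriminator better" behavior described above.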
IWGAN: empirical results
• Applied the importance weighting method to
  • vanilla GANs
  • class-conditional GANs (CGAN)
    • CGAN adds one dimension to code $z$ to represent the class label
    • The derivations of the IW extension remain the same as in vanilla GANs
IWGAN: empirical results
• Evaluated on MNIST and SVHN
• Used pretrained NN to evaluate:
  • Inception scores of samples from GANs and IW-GAN
    • Confidence of a pre-trained classifier on generated samples + diversity of generated samples
  • Classification accuracy of samples from CGAN and IW-CGAN
Inception scores (left):
Model | MNIST | SVHN
GAN | 8.34±.03 | 5.18±.03
IWGAN | 8.45±.04 | 5.34±.03
Classification accuracy of generated samples (middle):
Model | MNIST | SVHN
CGAN | 0.985±.002 | 0.797±.005
IWCGAN | 0.987±.002 | 0.798±.006
Classification accuracy on the MNIST test set (right):
Model | 1% labels | 10% labels
SVAE | 0.9412 | 0.9768
AASVAE | 0.9425 | 0.9797
Table 2: Left: Inception scores of vanilla GANs and the importance weighted extension. Middle: Classification accuracy of the generations by class-conditional GANs and the IW extension. Right: Classification accuracy of semi-supervised VAEs and the adversary activated extension on the MNIST test set, with varying size of real labeled training examples.
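Recap: Variational Inference
Maximize the variational lower bound $\mathcal{L}(\theta,\eta;x)$, or equivalently, minimize the free energy
$$F(\theta,\eta;x) = -\log p(x) + \mathrm{KL}\big(q_\eta(z|x)\,\|\,p_\theta(z|x)\big)$$
• E-step: maximize $\mathcal{L}$ wrt. $\eta$ with $\theta$ fixed
$$\max_\eta \mathcal{L}(\theta,\eta;x) = \mathbb{E}_{q_\eta(z|x)}[\log p_\theta(x|z)] - \mathrm{KL}\big(q_\eta(z|x)\,\|\,p(z)\big)$$
  • If closed-form solutions exist: $q^*_\eta(z|x) \propto \exp[\log p_\theta(x,z)]$
• M-step: maximize $\mathcal{L}$ wrt. $\theta$ with $\eta$ fixed
$$\max_\theta \mathcal{L}(\theta,\eta;x) = \mathbb{E}_{q_\eta(z|x)}[\log p_\theta(x|z)] - \mathrm{KL}\big(q_\eta(z|x)\,\|\,p(z)\big)$$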
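When the E-step has the closed form $q^*(z|x) \propto \exp[\log p_\theta(x,z)]$, the alternation above is classic EM. Here is a minimal sketch on a hypothetical two-component 1-D Gaussian mixture with equal weights and unit variances, where only the means ($\theta$) are learned. The E-step sets $q$ to the exact posterior, which makes $\mathrm{KL}(q\|p_\theta(z|x)) = 0$, so the bound touches $\log p(x)$. The M-step maximizes $\mathbb{E}_q[\log p_\theta(x,z)]$.

```python
import numpy as np

def log_normal(x, mu):
    """log N(x; mu, 1): unit-variance Gaussian log density."""
    return -0.5 * np.log(2 * np.pi) - 0.5 * (x - mu) ** 2

def em_two_gaussians(x, mu, steps=100):
    """EM for a hypothetical 2-component 1-D mixture with equal weights and
    unit variances; only the component means (theta) are learned."""
    mu = np.asarray(mu, dtype=float)
    log_liks = []
    for _ in range(steps):
        # log p_theta(x, z) for z in {0, 1}, shape (N, 2)
        log_joint = np.log(0.5) + log_normal(x[:, None], mu[None, :])
        log_px = np.logaddexp(log_joint[:, 0], log_joint[:, 1])
        log_liks.append(float(log_px.sum()))
        # E-step: closed-form q*(z|x) ∝ exp[log p_theta(x, z)], the exact posterior,
        # so KL(q || p_theta(z|x)) = 0 and the lower bound equals log p(x).
        q = np.exp(log_joint - log_px[:, None])
        # M-step: maximize E_q[log p_theta(x, z)] w.r.t. the means (posterior-weighted averages).
        mu = (q * x[:, None]).sum(axis=0) / q.sum(axis=0)
    return mu, log_liks

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2.0, 1.0, 200), rng.normal(2.0, 1.0, 200)])
mu_hat, log_liks = em_two_gaussians(x, mu=[-0.5, 0.5])
print(np.sort(mu_hat), log_liks[0], log_liks[-1])
```

Because each E-step makes the bound tight and each M-step raises it, the recorded log-likelihood never decreases.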
Discussion: Modeling latent vs. visible variables
• Latent and visible variables are traditionally distinguished clearly and modeled in very different ways
• A key thought in the new formulation:
  • It is not necessary to draw a clear boundary between latent and visible variables,
  • nor between inference and generation
  • Instead, treat them as a symmetric pair
Symmetric modeling of latent & visible variables
• Help with modeling and understanding:
  • Treating the generation space $x$ in GANs as latent
    • reveals the connection between GANs and ADA
    • provides a variational inference interpretation of generation
[Figure: ADA performs inference on features; GANs treat generation of $x$ as performing inference.]
Symmetric modeling of latent & visible variables
• Help with modeling and understanding:
  • Treating the generation space $x$ in GANs as latent
    • reveals the connection between GANs and ADA
    • provides a variational inference interpretation of generation
  • Wake sleep algorithm
    • wake phase reconstructs visible variables based on latents
    • sleep phase reconstructs latent variables based on visibles
    • latent and visible variables are treated in a completely symmetric way
Wake: $\max_\theta \mathbb{E}_{q_\lambda(z|x)p_{data}(x)}[\log p_\theta(x|z)]$
Sleep: $\max_\lambda \mathbb{E}_{p_\theta(x|z)p(z)}[\log q_\lambda(z|x)]$
Symmetric modeling of latent & visible variables
• New modeling approaches narrow the gap
Empirical distributions over visible variables
  • Cannot be an explicit distribution
    • The only information we have is the observed data examples
    • We do not know the true parametric form of the data distribution
  • Naturally an implicit distribution
    • Easy to sample from, hard to evaluate likelihood
Prior distributions over latent variables
  • Traditionally defined as explicit distributions, e.g., a Gaussian prior distribution
    • Amenable to likelihood evaluation
    • We can assume the parametric form according to our prior knowledge
• New tools to allow implicit priors and models
  • GANs, density ratio estimation, approximate Bayesian computation
  • E.g., the adversarial autoencoder [Makhzani et al., 2015] replaces the Gaussian prior of vanilla VAEs with implicit priors
Symmetric modeling of latent & visible variables
• No difference in terms of formulations
  • with implicit distributions and black-box NN models
  • just swap the symbols $x$ and $z$
[Figure: Generation model: $z \sim p_{prior}(z)$ (prior distr.), $x \sim f_{transform}(z)$. Inference model: $x \sim p_{data}(x)$ (data distr.), $z \sim f'_{transform}(x)$.]
Figure 5: Symmetric view of generation and inference. There is little difference between the two processes in terms of formulation: with implicit distribution modeling, both processes only need to perform simulation through black-box neural transformations between the latent and visible spaces.
Symmetric modeling of latent & visible variables
… activating the adversary mechanism on the VAE models. We see that the adversary activated models consistently outperform their respective base models. Generally, larger improvements are obtained with smaller sets of real training data. Table 2, right panel, further shows the classification accuracy of the semi-supervised VAE and its adversary activated variant with different sizes of labeled training data. We observe improved performance of the AA-SVAE model. The full results with standard deviations are reported in the supplementary materials.
• No difference in terms of formulations
  • with implicit distributions and black-box NN models
• Difference in terms of space complexity
  • depends on the problem at hand
  • choose appropriate tools:
    • implicit/explicit distribution, adversarial/maximum-likelihood optimization, …
[Figure: generation model (from the prior distr.) and inference model (from the data distr.), each trained with either an adversarial loss or a maximum likelihood loss.]
6 Discussions
Our new interpretations of GANs and VAEs have revealed strong connections between them, and linked the emerging new approaches to the classic wake-sleep algorithm. The generality of the proposed formulation offers a unified statistical insight into the broad landscape of deep generative modeling, and encourages mutual exchange of improvement ideas across research lines. It is interesting to further generalize the framework to connect to other learning paradigms such as reinforcement learning, as previous works have started to explore [14, 44]. GANs simultaneously learn a metric (defined by the discriminator) to guide the generator learning. This resembles the iterative teacher-student distillation framework [23, 24], where a teacher network is simultaneously learned from structured knowledge (e.g., logic rules) and provides knowledge-informed learning signals for student networks of interest. It will be intriguing to build formal connections between these approaches and enable incorporation of structured knowledge in deep generative modeling.
Symmetric view of generation and inference. Traditional modeling approaches usually distinguish clearly between latent and visible variables and treat them in very different ways. One of the key thoughts in our formulation is that it is not necessary to draw a clear boundary between the two types of variables (and between generation and inference). Instead, treating them as a symmetric pair helps with modeling and understanding. For instance, we treat the generation space $x$ in GANs as latent, which immediately reveals the connection between GANs and adversarial domain adaptation, and provides a variational inference interpretation of the generation. A second example is the classic wake-sleep algorithm. There, the wake phase reconstructs visibles conditioned on latents, while the sleep phase reconstructs latents conditioned on visibles (i.e., generated samples). Hence, visible and …
Part-II: Conclusions
Z Hu, Z YANG, R Salakhutdinov, E Xing, “On Unifying Deep Generative Models”, arxiv 1706.00550
• Deep generative models research has a long history • Deep belief nets / Helmholtz machines / Predictability Minimization / …
• Unification of deep generative models • GANs and VAEs are essentially minimizing KLD in opposite directions • They extend the two phases of the classic wake-sleep algorithm, respectively
• A general formulation framework useful for • Analyzing a broad class of existing DGMs and variants: ADA/InfoGAN/Joint-models/…
• Inspiring new models and algorithms by borrowing ideas across research fields
• Symmetric view of latent/visible variables • No difference in formulation with implicit prior distributions and black-box NN transformations • Difference in space complexity: choose appropriate tools
Part-III (1) Inference and Learning
Outline • Deep Learning as Dataflow Graphs • Auto-differentiable Libraries
A Computational Layer in DL • A layer in a neural network is composed of a few finer computational operations • A layer l has input x and output y, and transforms x into y following: $z = Wx + b$, $y = \mathrm{ReLU}(z)$ • Denote the transformation of layer l as $f_l$, which can be represented as a dataflow graph: the input x flows through the layer
[Figure: x → $f_l$ → y]
From Layers to Networks • A neural network is thus a few stacked layers $l = 1, \dots, L$, where every layer represents a function transform $f_l$ • The forward computation proceeds by sequentially executing $f_1, f_2, f_3, \dots, f_L$
[Figure: $f_1$ → $f_2$ → $f_3$ → ⋯ → $f_L$]
• Training the neural network involves deriving the gradient of its parameters with a backward pass (next slides)
A Computational Layer in DL • Denote the backward pass through a layer l as $b_l$
• $b_l$ derives the gradient of the input x (dx), given the gradient of y as dy, as well as the gradients of the parameters W, b • dx will be the backward input of its previous layer $l - 1$ • The backward pass can be thought of as a backward dataflow where the gradient flows through the layer
[Figure: dx ← $b_l$ ← dy]
Backpropagation through a NN • The backward computation proceeds by sequentially executing $b_L, b_{L-1}, b_{L-2}, \dots, b_1$
[Figure: $b_1$ ← $b_2$ ← ⋯ ← $b_L$]
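To make $f_l$ and $b_l$ concrete, here is a minimal pure-Python sketch, assuming a fully-connected ReLU layer as described above. The names `DenseReLU`, `forward_network` and `backward_network` are hypothetical, not any toolkit's API. The forward pass runs $f_1, \dots, f_L$ in order; the backward pass runs $b_L, \dots, b_1$, each mapping dy to dx while storing dW and db.

```python
import random


class DenseReLU:
    """Layer l: f_l computes y = ReLU(W x + b); b_l maps dy to dx and fills dW, db."""

    def __init__(self, n_in, n_out, rng):
        self.W = [[rng.uniform(-1.0, 1.0) for _ in range(n_in)] for _ in range(n_out)]
        self.b = [0.0] * n_out

    def forward(self, x):
        # f_l: z = W x + b, y = ReLU(z); cache x and z for the backward pass
        self.x = list(x)
        self.z = [sum(w * xi for w, xi in zip(row, x)) + bj
                  for row, bj in zip(self.W, self.b)]
        return [max(0.0, zj) for zj in self.z]

    def backward(self, dy):
        # b_l: back through ReLU, then through the affine map
        dz = [g if zj > 0 else 0.0 for g, zj in zip(dy, self.z)]
        self.dW = [[g * xi for xi in self.x] for g in dz]   # dW = dz x^T
        self.db = dz                                         # db = dz
        # dx = W^T dz becomes the backward input of layer l - 1
        return [sum(self.W[j][i] * dz[j] for j in range(len(dz)))
                for i in range(len(self.x))]


def forward_network(layers, x):
    for layer in layers:              # f_1, f_2, ..., f_L
        x = layer.forward(x)
    return x


def backward_network(layers, dy):
    for layer in reversed(layers):    # b_L, b_{L-1}, ..., b_1
        dy = layer.backward(dy)
    return dy


rng = random.Random(0)
net = [DenseReLU(3, 4, rng), DenseReLU(4, 2, rng)]
x = [0.5, -1.0, 2.0]
y = forward_network(net, x)
dx = backward_network(net, [1.0] * len(y))   # gradient of loss = sum(y)
```

A finite-difference check on any single weight (perturb it, rerun `forward_network`, compare with the stored `dW`) is a quick way to confirm that each $b_l$ matches its $f_l$.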
A Layer as a Dataflow Graph • Given the forward computation flow, gradients can be computed by auto differentiation • Automatically derive the backward gradient flow graph from the forward dataflow graph
Photo from TensorFlow website
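The idea of deriving the backward graph automatically can be sketched with a tiny scalar reverse-mode autodiff in plain Python. This is an illustrative toy (the `Value` class is hypothetical and is not how TensorFlow is implemented): each operation records its inputs while the forward graph is built, and `backward()` walks that graph in reverse topological order, applying the chain rule at every node.

```python
class Value:
    """A scalar node in a forward dataflow graph that knows how to pass gradients back."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward_fn = None   # pushes self.grad into the parents' grads

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))

        def backward_fn():
            self.grad += out.grad
            other.grad += out.grad
        out._backward_fn = backward_fn
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))

        def backward_fn():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward_fn = backward_fn
        return out

    def relu(self):
        out = Value(max(0.0, self.data), (self,))

        def backward_fn():
            self.grad += (1.0 if self.data > 0 else 0.0) * out.grad
        out._backward_fn = backward_fn
        return out

    def backward(self):
        # Topologically sort the forward graph, then visit it in reverse:
        # this is the automatically derived backward gradient flow.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)
        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            if node._backward_fn is not None:
                node._backward_fn()


# One-unit layer: z = w * x + b, y = ReLU(z)
w, x, b = Value(2.0), Value(3.0), Value(-1.0)
y = (w * x + b).relu()
y.backward()   # fills w.grad, x.grad, b.grad
```

Here $z = 5 > 0$, so the gradients are simply $\partial y/\partial w = x = 3$, $\partial y/\partial x = w = 2$ and $\partial y/\partial b = 1$.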
A Network as a Dataflow Graph • Gradients can be computed by auto differentiation • Automatically derive the gradient flow graph from the forward dataflow graph
[Figure: forward $f_1$ → $f_2$ → ⋯ → $f_L$, with the derived backward flow $b_L$ → ⋯ → $b_2$ → $b_1$]
Photo from TensorFlow website
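Gradient Descent via Backpropagation • The computational workflow of deep learning • Forward, which we usually also call inference: forward dataflow • Backward, which derives the gradients: backward gradient flow • Apply/update gradients and repeat
• Mathematically, the model parameters θ are updated iteratively as $\theta^{(t)} = F\big(\theta^{(t-1)}, \nabla_{\mathcal{L}}(\theta^{(t-1)}, D^{(t)})\big)$, where the forward and backward passes compute the gradient $\nabla_{\mathcal{L}}$ of the loss on data $D^{(t)}$ and F applies it to the parameters (e.g. $F(\theta, g) = \theta - \epsilon g$ for plain SGD)
[Figure: Data → Forward → Backward → Model parameters]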
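The forward / backward / apply cycle can be sketched in a few lines, assuming a toy one-dimensional linear model $y \approx wx + b$ with a mean-squared-error loss and plain SGD as F (all of these are illustrative choices, not taken from the slides):

```python
import random


def forward(theta, batch):
    # Forward: predictions and mean-squared-error loss on the mini-batch
    w, b = theta
    preds = [w * x + b for x, _ in batch]
    loss = sum((p - y) ** 2 for p, (_, y) in zip(preds, batch)) / len(batch)
    return preds, loss


def backward(batch, preds):
    # Backward: gradient of the MSE loss with respect to (w, b)
    n = len(batch)
    dw = sum(2 * (p - y) * x for p, (x, y) in zip(preds, batch)) / n
    db = sum(2 * (p - y) for p, (_, y) in zip(preds, batch)) / n
    return dw, db


def apply_update(theta, grad, lr):
    # Apply: F(theta, g) = theta - lr * g (plain SGD)
    return tuple(t - lr * g for t, g in zip(theta, grad))


rng = random.Random(0)
data = [(x, 2.0 * x + 1.0) for x in (rng.uniform(-1.0, 1.0) for _ in range(256))]
theta = (0.0, 0.0)
for _ in range(500):                        # ... and repeat
    batch = rng.sample(data, 32)
    preds, loss = forward(theta, batch)
    grad = backward(batch, preds)
    theta = apply_update(theta, grad, lr=0.1)
```

Because the toy targets are noise-free, `theta` ends up very close to the generating values (2, 1).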
Program a neural network • Define a neural network • Define operations and layers: fully-connected? Convolution? Recurrent? • Define the data I/O: read what data from where? • Define a loss function/optimization objective: L2 loss? Softmax? Ranking loss? • Define an optimization algorithm: SGD? Momentum SGD? etc.
• Auto-differentiation libraries will then take over • Connect operations, data I/O, loss functions and trainer • Build the forward dataflow graph and the backward gradient flow graph • Perform training and apply updates
Auto-differentiation Libraries • An auto-differentiation library automatically derives the gradients following the backpropagation rule • Many auto-differentiation libraries have been developed • so-called deep learning toolkits
Deep Learning Toolkits • They are adopted differently in different domains • For example
[Figure: toolkit adoption in Vision and in NLP]
Deep Learning Toolkits • They are also designed differently • Symbolic vs. imperative programming
[Figure: imperative vs. symbolic code examples]
Deep Learning Toolkits • Symbolic vs. imperative programming • Symbolic: write symbols to assemble the network first, evaluate later • Imperative: immediate evaluation
[Figure: symbolic vs. imperative code examples]
Deep Learning Toolkits • Symbolic • Good
• easy to optimize (e.g. distributed, batching, parallelization) for developers • More efficient
• Bad
• The way of programming might be counter-intuitive • Hard to debug for user programs • Less flexible: you need to write symbols before actually doing anything
• Imperative: • Good
• More flexible: write one line, evaluate one line • Easy to program and easy to debug, because it matches the way we use C++ or Python
• Bad
• Less efficient • More difficult to optimize
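The contrast can be illustrated in plain Python. This is a toy: `Sym`, `Const` and `evaluate` are hypothetical names, not any toolkit's API. The imperative lines compute values immediately; the symbolic lines only assemble a graph, which is evaluated later once inputs are fed in, and which a toolkit would be free to optimize before running.

```python
import operator

# Imperative: every line is evaluated immediately.
a = 2.0
b = a * 3.0        # b already holds 6.0
c = b + 1.0        # c holds 7.0 and can be inspected right away


# Symbolic: first assemble a graph of symbols, evaluate it later.
class Sym:
    def __init__(self, op=None, args=(), name=None):
        self.op, self.args, self.name = op, args, name

    def __mul__(self, other):
        return Sym(operator.mul, (self, _lift(other)))

    def __add__(self, other):
        return Sym(operator.add, (self, _lift(other)))


class Const(Sym):
    def __init__(self, value):
        super().__init__()
        self.value = value


def _lift(v):
    return v if isinstance(v, Sym) else Const(v)


def evaluate(node, feed):
    # Values only flow once inputs are fed in; a toolkit could optimize the graph first
    if isinstance(node, Const):
        return node.value
    if node.op is None:
        return feed[node.name]     # a placeholder such as A
    return node.op(*(evaluate(arg, feed) for arg in node.args))


A = Sym(name="A")                  # placeholder symbol
C = A * 3.0 + 1.0                  # nothing computed yet: C is a graph
result = evaluate(C, {"A": 2.0})
```

The same graph `C` can be re-run with a different feed, which is the "define once, run many times" property; the price is that an error inside `C` only surfaces at `evaluate` time, far from the line that built it.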
Deep Learning Toolkits • They are also designed differently • For another example, dataflow graphs vs. layer-by-layer construction
[Figure: layer-by-layer construction vs. dataflow graphs]
Good and Bad of Dataflow Graphs • Dataflow graphs seem to be the dominant choice for representing deep learning models • What's good for dataflow graphs
• Good for static workflows: define once, run for arbitrary batches/data • Programming convenience: easy to program once you get used to it • Easy to parallelize/batch for a fixed graph • Easy to optimize: many off-the-shelf graph optimization techniques
• What's bad for dataflow graphs
• Not good for dynamic workflows: need to define a graph for every training sample → overheads • Hard to program dynamic neural networks: how can you define dynamic graphs using a language for static graphs? (e.g. LSTM, tree-LSTM) • Not easy to debug • Difficult to parallelize/batch across multiple graphs: every graph is different, so there is no natural batching
Static vs. Dynamic Dataflow Graphs • Static dataflow graphs • Define once, execute many times • For example: convolutional neural networks
• Execution: once defined, all subsequent computation follows the defined graph • Advantages • No extra effort for batching optimization, because the graph is batched by nature • A static dataflow graph is easy to handle in all aspects because of its fixed structure • Node placement, distributed runtime, memory management, etc.
• Benefits the developers
Static vs. Dynamic Dataflow Graphs • Dynamic dataflow graphs • When do we need them? • In all cases where static dataflow graphs do not work well
• Variably sized inputs • Variably structured inputs • Nontrivial inference algorithms • Variably structured outputs • Etc.
Static vs. Dynamic Dataflow Graphs • Can we handle dynamic dataflow graphs? Using static methods (or declaration) causes many problems • Difficulty in expressing complex flow-control logic • Complexity of the computation graph implementation • Difficulty in debugging
Introducing DyNet • Designed for dynamic deep learning workflow, e.g. • Tree-LSTM for neural machine translation, where each sentence defines a structure that corresponds to the computational flow • Graph-LSTM for image parsing, where each image has a specific connection between segments • etc.
Key Ingredients in DyNet • Concept • Separate parameter declaration and graph construction • Declare trainable parameters and construct models first • Parameters, e.g. the weight matrices in an LSTM unit • Construct a model as a collection of trainable parameters
• Construct computation graphs • Allocate a few nodes for the computation (a node can be seen as a layer in a NN) • Specify the dataflow graph by connecting nodes together • If necessary, build different graphs for different input samples
• Conclusion: define parameters once, but define graphs dynamically depending on the inputs
Key Ingredients in DyNet • Backend and programming model • Graph construction • In TensorFlow, constructing a graph has considerable overhead • TensorFlow users therefore avoid defining graphs repeatedly
• DyNet: highly optimized graph definition • Little overhead in defining a graph: good for dynamic neural networks • Easy to write recursive programs to define graphs (very effective for many dynamic networks, such as tree-LSTM or graph-LSTM)
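The "define parameters once, build a graph per input" pattern can be mimicked in plain Python. This is a toy illustration, not DyNet's API: the parameter names and the tanh recursive combiner are assumptions. The recursive `encode` function follows each input tree's own structure, so two trees produce two differently shaped computations that share the same parameters.

```python
import math

# Declared once and shared by every graph (hypothetical toy parameters).
PARAMS = {"w_leaf": 0.5, "w_left": 0.8, "w_right": -0.3, "bias": 0.1}


def encode(tree, params, trace):
    """Build and evaluate the computation for one input tree, following its structure.

    A tree is a float (leaf) or a pair (left, right). `trace` collects the
    operations executed, i.e. the dataflow graph built for this particular input.
    """
    if isinstance(tree, float):
        trace.append("leaf")
        return math.tanh(params["w_leaf"] * tree + params["bias"])
    left, right = tree
    h_left = encode(left, params, trace)
    h_right = encode(right, params, trace)
    trace.append("combine")
    return math.tanh(params["w_left"] * h_left + params["w_right"] * h_right + params["bias"])


trace_a, trace_b = [], []
h_a = encode((1.0, 2.0), PARAMS, trace_a)                        # 3 operations
h_b = encode(((1.0, 2.0), (3.0, (4.0, 5.0))), PARAMS, trace_b)   # 9 operations
```

The recursion is the whole "graph definition", which is why a toolkit with cheap graph construction makes tree-structured models short to write.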
Key Ingredients in DyNet • A visual comparison
[Figure: DyNet TreeLSTM (30 LoC) vs. TensorFlow TreeLSTM (200 LoC)]
Part-III (2) Distributed Deep Learning
Outline • Overview: Distributed Deep Learning on GPUs • Challenge 1: Addressing the communication bottleneck • Challenge 2: Handling the limited GPU memory
Review – DL toolkits on a single machine • Using GPUs is a must • A small number of GPU-equipped machines can achieve satisfactory speedup compared to CPU clusters with thousands of cores
[Figure: a cluster of 8 GPU-equipped machines vs. a cluster of 2000 CPU cores]
• More readily available to researchers
Review – DL toolkits on a single machine • However, using a single GPU is far from sufficient • Average-sized deep networks can take days to train on a single GPU when faced with 100s of GBs to TBs of data • There is demand for faster training of neural networks on ever-larger datasets
[Figure: AlexNet, 5–7 days; GoogLeNet, 10+ days]
• However, current distributed DL implementations (e.g. in TensorFlow) can scale poorly due to substantial parameter synchronization over the network (we will show later)
Challenges • Communication challenges
• GPUs are at least one order of magnitude faster than CPUs
[Figure: GPUs are faster → high comm load; with limited network bandwidth and bursty communication, the network becomes the bottleneck → low GPU utilization and poor scalability with additional machines]
• High communication load makes network communication the main bottleneck, given the limited bandwidth of commodity Ethernet • Managing the computation and communication in a distributed GPU cluster often complicates the algorithm design
Let’s see what causes the problem • Deep learning on a single node: an iterative-convergent formulation
[Figure: Data → Forward → Backward → Apply gradients → Model parameters]
Zhang et al., 2017
• Forward and backward passes are the main computational workload (99%) of deep learning programs.
Distributed Deep Learning • Distributed DL: parallelize DL training using multiple machines • i.e. we want to distribute the heaviest workload (the forward and backward passes, boxed in the figure) across multiple machines
[Figure: Data → Forward → Backward, with forward and backward boxed]
Data parallelism with stochastic gradient descent • We usually seek a parallelization strategy called data parallelism, based on SGD • We partition the data into different parts • Let different machines compute the gradient updates on different data partitions • Then aggregate/sync
[Figure: Workers 1–4, each with its own data partition, synchronized through one or more machines]
Data Parallel SGD • Data parallel stochastic gradient descent over P workers in total:
$\theta^{(t)} = F\big(\theta^{(t-1)}, \sum_{p=1}^{P} \nabla_{\mathcal{L}}(\theta^{(t-1)}, D_p^{(t)})\big)$
• Each gradient $\nabla_{\mathcal{L}}(\theta^{(t-1)}, D_p^{(t)})$ on data partition p is computed locally on its worker • The sum over workers is collected and aggregated before F applies it, which requires communication • Data parallelism requires every worker to have read and write access to the shared model parameters θ, which causes communication among workers
Zhang et al., 2015, Zhang et al. 2017
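A single-process simulation of this update, assuming a toy linear model with MSE loss and an F that averages the P worker gradients before taking an SGD step (averaging instead of summing only rescales the learning rate). In a real system each `local_gradient` call would run on a different worker, and the aggregation line is where communication happens.

```python
def local_gradient(theta, partition):
    # Runs locally on worker p: MSE gradient of y ~ w*x + b on its partition D_p
    w, b = theta
    n = len(partition)
    dw = sum(2 * (w * x + b - y) * x for x, y in partition) / n
    db = sum(2 * (w * x + b - y) for x, y in partition) / n
    return dw, db


def data_parallel_sgd(data, num_workers, lr, iterations):
    partitions = [data[p::num_workers] for p in range(num_workers)]  # worker p owns D_p
    theta = (0.0, 0.0)
    for _ in range(iterations):
        # In a real system these P calls run in parallel on different machines
        grads = [local_gradient(theta, D_p) for D_p in partitions]
        # Collect and aggregate: the step that requires communication
        agg = tuple(sum(g[i] for g in grads) / num_workers for i in range(2))
        theta = tuple(t - lr * g for t, g in zip(theta, agg))
    return theta


data = [(i / 100, 3.0 * (i / 100) - 2.0) for i in range(-100, 100)]
theta = data_parallel_sgd(data, num_workers=4, lr=0.3, iterations=500)
```

With equally sized partitions, the average of the per-partition mean gradients equals the full-data mean gradient, so the 4-worker run follows the same trajectory as a single worker.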
How to communicate • Parameter server, e.g. Bosen, SSP • A parameter server (PS) is a shared memory system that provides shared access to the global model parameters θ
• Deep learning can be trivially data-parallelized over distributed workers using a PS in 3 steps: • Each worker computes the gradients (∇L) on its own data partition ($D_p$) and sends them to remote servers; • servers receive the updates and apply (+) them to the globally shared parameters; • each worker pulls back the updated parameters ($\theta_t$) Ho et al., 2013, Wei et al. 2015
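The three steps can be simulated in one process with a toy parameter server. `ParameterServer`, `push` and `pull` are hypothetical names for illustration, not the API of Bosen or any other system, and the model is again a linear fit with MSE loss. All workers compute their gradients on the same pulled parameters before pulling again, i.e. a bulk-synchronous schedule.

```python
class ParameterServer:
    """Toy in-process PS holding the globally shared parameters theta."""

    def __init__(self, theta, lr):
        self.theta = list(theta)
        self.lr = lr

    def push(self, grad):
        # Step 2: apply a received update to the shared parameters (SGD step)
        self.theta = [t - self.lr * g for t, g in zip(self.theta, grad)]

    def pull(self):
        # Step 3: hand a worker a copy of the current parameters
        return list(self.theta)


def worker_gradient(theta, partition):
    # Step 1: MSE gradient of y ~ w*x + b on this worker's partition D_p
    w, b = theta
    n = len(partition)
    return [sum(2 * (w * x + b - y) * x for x, y in partition) / n,
            sum(2 * (w * x + b - y) for x, y in partition) / n]


data = [(i / 50, 0.5 * (i / 50) + 4.0) for i in range(-50, 50)]
partitions = [data[p::4] for p in range(4)]       # 4 workers
ps = ParameterServer(theta=[0.0, 0.0], lr=0.05)
local = [ps.pull() for _ in partitions]           # each worker's local replica
for _ in range(400):
    for p, D_p in enumerate(partitions):
        ps.push(worker_gradient(local[p], D_p))   # send gradients to the server
    local = [ps.pull() for _ in partitions]       # pull back updated parameters
```

Every iteration moves a full gradient into the server and a full parameter copy out of it per worker, which is exactly the traffic that becomes the bottleneck once the workers are fast GPUs.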
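How PS works
[Figure: Workers 1–4 push their updates $\Delta\theta_1$, $\Delta\theta_2$, $\Delta\theta_3$, $\Delta\theta_4$ to the PS, which holds θ and sends the updated θ back to every worker]
Ho et al., 2013, Wei et al. 2015, Zhang et al., 2015, Zhang et al. 2017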
Parameter Server • Parameter servers have been successful for CPU-based deep learning • Google DistBelief, Dean et al. 2012 • Scales up to thousands of CPU machines and 16000 CPU cores
• SSPTable, Ho et al, 2013 • Stale-synchronous parallel consistency model
• Microsoft Adam, Chilimbi et al. 2014 • 63 machines, state-of-the-art results on ImageNet 22K • Bosen, Wei et al. 2015 • Managed communication
Parameter Server on GPUs • Directly applying a parameter server to GPU-based distributed deep learning will underperform (as we will show later) • GPUs are too fast • Ethernet bandwidth is limited and has latency
• For example • AlexNet: 61.5M float parameters, 0.25 s/iteration on GeForce Titan X (batch size = 256) • Gradient generation rate: 240M floats/(s·GPU)
• Parallelize it over 8 machines, each with one GPU, using a PS • To ensure the computation is not blocked on the GPU (i.e. linear speed-up with additional nodes) • As a worker: send 240M floats/s and pull back 240M floats/s (at least) • As a server: receive 240M × 8 floats/s and send back 240M × 8 floats/s (at least) Zhang et al., 2015, Zhang et al. 2017
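The arithmetic behind these numbers can be checked with a short script, assuming 4-byte floats and that each direction must sustain the full rate on its own; `ps_bandwidth` is a hypothetical helper. The exact rate is 61.5M / 0.25 s = 246M floats/s, which the slide rounds to 240M.

```python
def ps_bandwidth(num_params, sec_per_iter, num_machines, bytes_per_float=4):
    """Back-of-the-envelope bandwidth a PS setup needs so that no GPU ever waits.

    Assumes every worker pushes a full gradient and pulls the full parameters
    once per iteration, and that one server node serves all machines.
    """
    floats_per_sec = num_params / sec_per_iter       # gradient generation rate per GPU

    def gbps(floats):
        return floats * bytes_per_float * 8 / 1e9    # floats/s -> Gbit/s

    worker = gbps(floats_per_sec)                    # each direction, per worker
    server = gbps(floats_per_sec * num_machines)     # each direction, at the server
    return floats_per_sec, worker, server


rate, worker, server = ps_bandwidth(61.5e6, 0.25, 8)
print(f"{rate / 1e6:.0f}M floats/s per GPU, worker {worker:.1f} Gbit/s, server {server:.1f} Gbit/s (each way)")
# prints: 246M floats/s per GPU, worker 7.9 Gbit/s, server 63.0 Gbit/s (each way)
```

The server figure grows linearly with the number of machines, so 32 or more nodes would push it into the hundreds of Gbit/s under the same assumptions.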
Parameter Server on GPUs • Let’s see where we are
[Figure: table of Ethernet standards, annotated: "This is what the GPU workstation in your lab has"; "One of the most expensive instances AWS could provide you ($18/h?)"; "Specialized hardware! No longer commodity, unaffordable"]
Parameter Server on GPUs • The problem is more severe than described above • We only use 8 nodes (which is small). How about 32, 128, or even 256? • We haven't considered other issues (which might also be troublesome), e.g. • Memory copy between DRAM and GPU has a non-trivial cost • The Ethernet might be shared with other tasks, i.e. the available bandwidth is even lower • Bursty communication happens very often on GPUs (which we will explain later)
Address the Communication Bottleneck • A simple fact: • Communication time may be reduced, but cannot be eliminated (of course)
• Therefore, possible ideas to address the communication bottleneck • Hide the communication time by overlapping it with the computation time • Reduce the size of messages that need to be communicated
Overlap Computation and Communication • Revisit the computation flow of BP on a single node • $b_l$: backpropagation computation through layer l • $C_t$: forward and backward computation at iteration t
[Figure: timeline of iterations $C_t$, $C_{t+1}$, $C_{t+2}$, ⋯, each running the backward operations $b_L$, ⋯, $b_2$, $b_1$ in sequence]
Zhang et al., 2015, Zhang et al. 2017
Overlap Computation and Communication • On multiple nodes, when communication is involved • Introduce two communication operations
• $o_l$: send out the gradients in layer l to the remote • $i_l$: pull back the globally shared parameters of layer l from the remote • $O_t$: the set $\{o_l\}_{l=1}^{L}$ at iteration t • $I_t$: the set $\{i_l\}_{l=1}^{L}$ at iteration t
[Figure: timeline $C_t$ → $O_t$ → $I_t$ → $C_{t+1}$ → $O_{t+1}$ → $I_{t+1}$, where $C_t$ runs $b_L$, ⋯, $b_2$, $b_1$]
Computation and communication happen sequentially!
Zhang et al., 2015, Zhang et al. 2017
Overlap Computation and Communication • Note the following independence • The send-out operation $o_l$ is independent of the backward operations • The read-in operation $i_l$ can update the layer parameters as soon as $b_l$ has finished, without blocking the subsequent backward operations $b_i$ ($i < l$)
• Idea: overlap computation and communication by utilizing concurrency • Pipeline the updates and computation operations
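The benefit of this reordering can be estimated with a toy timing model (an assumption for illustration, not the measurement methodology of Zhang et al.): each $b_l$ has a compute time, each layer's $o_l$/$i_l$ pair has a communication time, and a single communication channel serves one layer at a time. The sequential schedule pays for all computation and then all communication; the wait-free schedule starts each layer's communication as soon as $b_l$ is done.

```python
def iteration_time(backward_times, comm_times, wait_free):
    """Duration of one backward pass plus parameter sync (toy timing model).

    backward_times[l] and comm_times[l] are the durations of b_l and of layer l's
    o_l/i_l communication (index 0 = bottom layer). One communication channel
    serves a single layer at a time.
    """
    if not wait_free:
        # Sequential: b_L, ..., b_1 first, then all communication
        return sum(backward_times) + sum(comm_times)
    t_compute = 0.0
    t_comm = 0.0
    for l in reversed(range(len(backward_times))):   # b_L, b_{L-1}, ..., b_1
        t_compute += backward_times[l]               # b_l finishes at t_compute
        # o_l/i_l start once b_l is done and the channel is free
        t_comm = max(t_comm, t_compute) + comm_times[l]
    return max(t_compute, t_comm)


backward = [4.0, 3.0, 2.0, 1.0, 0.5]   # bottom layers: most of the computation
comm = [0.2, 0.3, 0.5, 2.0, 3.0]       # top layers: most of the communication
sequential = iteration_time(backward, comm, wait_free=False)
overlapped = iteration_time(backward, comm, wait_free=True)
```

With these made-up durations the sequential schedule takes 16.5 time units and the wait-free one about 10.7, because the top-layer communication hides under bottom-layer computation. With `backward_times=[1.0, 1.0]` and `comm_times=[5.0, 5.0]` the saving shrinks to 12 → 11: when communication dominates, overlapping alone is not enough.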
WFBP: Wait-free backpropagation • Idea: overlap computation and communication by utilizing concurrency • Pipelining the updates and computation operations
[Figure: the sequential schedule $b_L$, ⋯, $b_2$, $b_1$ followed by $\{o_l, i_l\}_{l=1}^{L}$ is rescheduled so that $o_l$ and $i_l$ are issued right after $b_l$ finishes and run concurrently with the remaining backward operations]
Zhang et al., 2015, Zhang et al. 2017
WFBP: Wait-free backpropagation • Idea: overlap computation and communication by utilizing concurrency • Communication overhead is hidden under computation • Results: more computations in unit time
[Figure: without pipelining, $C_t$, $O_t$, $I_t$, $C_{t+1}$, $O_{t+1}$, $I_{t+1}$ run one after another; with pipelining, $O_t$ and $I_t$ overlap with computation, so more iterations ($C_t$, …, $C_{t+3}$) complete in the same time]
Zhang et al., 2015, Zhang et al. 2017
WFBP: Distributed Wait-free backpropagation • How does WFBP perform? • Using Caffe as an engine: 50% comms bottleneck reduction (figure) Zhang et al. 2017
• Using TensorFlow as an engine (figure)
WFBP: Distributed Wait-free backpropagation • Observation: why would distributed WFBP be effective? • More statistics of modern CNNs
[Figure: Params/FLOP distribution of modern CNNs]
• 90% of the computation happens at the bottom layers • 90% of the communication happens at the top layers • WFBP overlaps the two 90% portions
Zhang et al., 2015, Zhang et al. 2017
WFBP: Wait-free Backpropagation • Does overlapping communication and computation solve all the problems? • When communication time is longer than computation time, no (see the figure below) • Say communication and computation are perfectly overlapped: how much scalability can we achieve?
[Figure: a single node runs $C_t$; in the distributed setting $C_t$ overlaps with $O_t$ and $I_t$, but a gap remains when communication outlasts computation]
Zhang et al., 2015, Zhang et al. 2017
Address the communication bottleneck • Note a simple fact: • Communication time may be reduced, but cannot be eliminated (of course)
• Therefore, possible ideas to address the communication bottleneck • Hide the communication time by overlapping it with the computation time, which we have described before • Reduce the size of messages that need to be communicated • without compromising statistical convergence
Introducing Sufficient Factor Broadcasting • Matrix-parametrized models (MPMs), whose parameter matrix dimensions are:
• Multiclass Logistic Regression: feature dim. × #classes
• Distance Metric Learning: feature dim. × latent dim.
• Sparse Coding: feature dim. × dictionary size
• Neural Network: #neurons in layer l − 1 × #neurons in layer l
Distributed Learning of MPMs • Learning MPMs by communicating parameter matrices between server and workers • Dean and Ghemawat, 2008; Dean et al, 2012; Sindhwani and Ghoting, 2012; Gopal and Yang, 2013; Chilimbi et al, 2014, Li et al, 2015
• High communication cost and large synchronization delays
[Figure: Multiclass Logistic Regression with feature dim. = 20K and #classes = 325K: 26G; Neural Network (AlexNet) between layer fc6 (4096 neurons) and layer fc7 (4096 neurons): 200M]
Sufficient Factor (SF) Updates • The full parameter matrix update ΔW can be computed as the outer product of two vectors, $\Delta W = uv^\top$ (u and v are called sufficient factors)
• Example: primal stochastic gradient descent (SGD)
$$\min_{W} \frac{1}{N}\sum_{i=1}^{N} f(Wa_i; b_i) + h(W), \qquad \Delta W = uv^\top, \quad u = \frac{\partial f(Wa_i; b_i)}{\partial (Wa_i)}, \quad v = a_i$$
• Example: stochastic dual coordinate ascent (SDCA)
$$\min_{Z} \frac{1}{N}\sum_{i=1}^{N} f_i^{*}(-z_i) + h^{*}\Big(\frac{1}{N} Z A^\top\Big), \qquad \Delta W = uv^\top, \quad u = \Delta z_i, \quad v = a_i$$
Send lightweight SF updates (u, v) instead of expensive full-matrix ΔW updates!
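As a concrete check of $\Delta W = uv^\top$, consider multiclass logistic regression with softmax cross-entropy as f (this choice of f is an illustrative assumption): for one sample, $u = \partial f / \partial(Wa_i) = \mathrm{softmax}(Wa_i) - \mathrm{onehot}(b_i)$ and $v = a_i$. The sketch computes (u, v), reconstructs the full gradient as their outer product, and compares message sizes; an SGD step would then apply $-\eta\, uv^\top$.

```python
import math


def softmax(z):
    m = max(z)                         # subtract the max for numerical stability
    e = [math.exp(zi - m) for zi in z]
    s = sum(e)
    return [ei / s for ei in e]


def sufficient_factors(W, a, label):
    """Sufficient factors (u, v) of the softmax cross-entropy gradient for one sample."""
    p = softmax([sum(wkj * aj for wkj, aj in zip(row, a)) for row in W])
    u = [pk - (1.0 if k == label else 0.0) for k, pk in enumerate(p)]  # d f / d(W a)
    v = list(a)
    return u, v


def outer(u, v):
    # The receiver reconstructs the full update matrix as u v^T
    return [[ui * vj for vj in v] for ui in u]


W = [[0.1, -0.2, 0.3], [0.0, 0.5, -0.1]]   # 2 classes x 3 features
a, label = [1.0, 2.0, -1.0], 1
u, v = sufficient_factors(W, a, label)
dW = outer(u, v)
print(len(u) + len(v), "floats sent instead of", len(dW) * len(dW[0]))  # prints: 5 floats sent instead of 6
```

For a J × D matrix the saving is J + D versus J·D floats per sample, which is why the idea pays off for very large matrices like the 20K × 325K example above.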
Sufficient Factor Broadcasting: P2P Topology + SF Updates
Xie et al., 2015
A computing & communication tradeoff
• Full update: training examples → individual update matrices → sum → aggregated update matrix
• Pre-update: training examples → sufficient vectors (SVs) $(u_1, v_1), (u_2, v_2), (u_3, v_3), (u_4, v_4)$, which cannot be aggregated
• Stochastic algorithms with a mini-batch of C samples (writing the parameter matrix as M × N): the matrix representation costs $O(MN)$ regardless of C, whereas the SV representation costs $O(C(M+N))$
Synchronization of Parameter Replicas • Transfer SVs instead of ΔW
[Figure: parameter server vs. P2P SV transfer]
• A cost comparison for P machines and an M × N parameter matrix:

| | Size of one message | Number of messages | Network traffic |
|---|---|---|---|
| P2P SV-Transfer | $O(M+N)$ | $O(P^2)$ | $O((M+N)P^2)$ |
| Parameter Server | $O(MN)$ | $O(P)$ | $O(MNP)$ |
Convergence Speedup
[Figure: convergence curves for Multiclass Logistic Regression (MLR), Distance Metric Learning (DML), and Sparse Coding (SC)]
• 3 benchmark ML programs • Big parameter matrices with 6.5–8.6B entries (30+ GB), running on 12- and 28-machine clusters
• 28-machine SFB finished in 2–7 hours • Up to 5.6× faster than 28-machine PS, 12.3× faster than 28-machine Spark
• PS cannot support SF communication, which requires decentralized storage Xie et al., 2015
Convergence Guarantee • Assumptions • Bridging model • Stale Synchronous Parallel (SSP) with staleness parameter s • Bulk Synchronous Parallel (BSP) is a special case of SSP when s = 0
• Communication methods • Partial broadcast (PB): sending messages to a subset of Q (Q < P − 1) machines • Full broadcast is a special case of PB when Q = P − 1
• Additional assumptions
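The SSP bridging model can be illustrated with a toy clock table (a sketch of the rule under the convention stated in the comments; `SSPClock` is a hypothetical name): a worker may start a new iteration only if it has completed at most s more iterations than the slowest worker, so s = 0 recovers BSP's barrier.

```python
class SSPClock:
    """Toy Stale Synchronous Parallel (SSP) bookkeeping for P workers.

    clocks[p] counts the iterations worker p has completed. A worker may start
    a new iteration only if it is at most s completed iterations ahead of the
    slowest worker; staleness s = 0 gives Bulk Synchronous Parallel (BSP).
    """

    def __init__(self, num_workers, staleness):
        self.clocks = [0] * num_workers
        self.s = staleness

    def can_start_next(self, p):
        return self.clocks[p] - min(self.clocks) <= self.s

    def run_iteration(self, p):
        # Block (here: raise) if worker p would run too far ahead of the stragglers
        if not self.can_start_next(p):
            raise RuntimeError(f"worker {p} must wait for slower workers")
        self.clocks[p] += 1


ssp = SSPClock(num_workers=2, staleness=2)
for _ in range(3):
    ssp.run_iteration(0)      # worker 0 runs ahead while worker 1 is slow
print(ssp.clocks, ssp.can_start_next(0))   # prints: [3, 0] False
```

In SFB this bound is what limits how stale the sufficient vectors a worker has received can be, which is the delay the convergence result tolerates.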
Convergence Guarantee • Results
Convergence Guarantee • Take-home message: • Under full broadcasting, given a properly chosen learning rate, all local worker parameters $W_p$ eventually converge to stationary points (e.g. local minima) of the objective function, even though SV transmission can be delayed by up to s iterations. • Under partial broadcasting, the algorithm converges only to a neighbourhood of a stationary point, whose size is $O(P - Q)$ up to problem-dependent constants, as the number of iterations goes to infinity.
Parameter Storage and Communication Paradigms
[Figure: centralized storage, where workers send changes ΔW to a server and the server sends W itself back; decentralized storage, where workers send changes ΔW to each other]
• Centralized: send parameter W itself from server to worker • Advantage: allows a compact comms topology, e.g. bipartite
• Decentralized: always send changes ΔW between workers • Advantage: more robust, homogeneous code, low communication (?)
Topologies: Master-Slave versus P2P?
Master-slave • Used with the centralized storage paradigm • Disadvantage: need to code/manage clients and servers separately • Advantage: bipartite topology is comms-efficient • Popular for Parameter Servers: Yahoo LDA, Google DistBelief, Petuum PS, Project Adam, Li&Smola PS, …
P2P • Used with decentralized storage • Disadvantage (?): high comms volume for a large # of workers • Advantage: same code for all workers; no single point of failure, high elasticity to resource adjustment • Less well-explored due to the perception of high communication overhead?
Hybrid Updates: PS + SFB • Hybrid communications: Parameter Server + Sufficient Factor Broadcasting • Parameter Server: Master-Slave topology • Sufficient factor broadcasting: P2P topology
• For problems with a mix of large and small matrices: • Send small matrices via PS • Send large matrices via SFB
Zhang et al., 2015, Zhang et al. 2017
Hybrid example: CNN
Hao Zhang, Zhiting Hu, Jinliang Wei, Pengtao Xie, Gunhee Kim, Qirong Ho, Eric P. Xing. Poseidon: A System Architecture for Efficient GPU-based Deep Learning on Multiple Machines. USENIX ATC 2016.
• Example: AlexNet CNN model • Final layers = 4096 × 30000 matrix (120M parameters) • Use SFB to communicate • 1. Decouple into two vectors u, v (lengths 4096 and 30000, matching the matrix dimensions) • 2. Transmit the two vectors • 3. Reconstruct the gradient matrix
Figure from Krizhevsky et al. 2012
Zhang et al., 2015, Zhang et al. 2017
Hybrid example: CNN
Hao Zhang, Zhiting Hu, Jinliang Wei, Pengtao Xie, Gunhee Kim, Qirong Ho, Eric P. Xing. Poseidon: A System Architecture for Efficient GPU-based Deep Learning on Multiple Machines. USENIX ATC 2016.
• Example: AlexNet CNN model • Convolutional layers = e.g. 11 × 11 matrix (121 parameters) • Use full-matrix updates to communicate • 1. Send/receive using the Master-Slave PS topology
Figure from Krizhevsky et al. 2012
Zhang et al., 2015, Zhang et al. 2017
Hybrid Communication • Idea • Sync FC layers using SFB • Sync Conv layers using PS
• Effectiveness • It directly reduces the size of messages in many situations
• Is SFB always optimal? • No: its communication load increases quadratically with the number of machines • The right strategy: choose PS whenever it results in less communication
Hybrid Communication • A best-of-both-worlds strategy • For example, AlexNet parameters between FC6 and FC7 • Tradeoff between PS and SFB communication
Zhang et al., 2015
Hybrid Communication • How to choose? Where is the threshold? • Determine the best strategy depending on • Layer type: CONV or FC? • Layer size • Batch size • # of cluster nodes
Zhang et al., 2015, Zhang et al. 2017
Hybrid Communication • Hybrid communication algorithm • Determine the best strategy depending on • Layer type: CONV or FC? • Layer size: M, N • Batch size: K • # of cluster nodes: $P_1$, $P_2$
Zhang et al., 2015, Zhang et al. 2017
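The decision can be sketched with a deliberately simplified traffic model. This is an assumption for illustration, not Poseidon's published cost formula, and it uses a single worker count P instead of the separate $P_1$, $P_2$ above. SFB traffic grows with the batch size K and with the number of peers, while PS traffic depends only on the matrix size, so the best choice flips as the cluster grows.

```python
def best_scheme(layer_type, M, N, K, P):
    """Choose "PS" or "SFB" for one M x N layer (simplified cost model, assumed).

    Per-worker traffic per iteration, counted in floats:
      SFB: send K sufficient-factor pairs (K * (M + N) floats) to each of the
           other P - 1 workers and receive the same amount back.
      PS:  push one M x N gradient to the server and pull one M x N matrix back.
    Only FC layers have sufficient factors, so other layers always use PS.
    """
    if layer_type != "FC":
        return "PS"
    sfb_traffic = 2 * K * (M + N) * (P - 1)
    ps_traffic = 2 * M * N
    return "SFB" if sfb_traffic <= ps_traffic else "PS"


# AlexNet-like FC6 -> FC7 (4096 x 4096), batch size 256
for P in (4, 8, 16, 32):
    print(P, best_scheme("FC", 4096, 4096, 256, P))   # SFB for P = 4, 8; PS for P = 16, 32
```

Summed over all P workers, the SFB traffic in this model scales as $P(P-1)$, the quadratic growth that makes PS the better choice beyond some cluster size.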
Hybrid Communication • Results: achieve linear scalability across different models/data with 40GbE bandwidth • Using Caffe as an engine (figure) • Using TensorFlow as an engine: improves over WFBP (figure)
Zhang et al., 2015, Zhang et al. 2017
Hybrid Communication • Linear scalability on throughput, even with limited bandwidth! • Makes distributed deep learning affordable
[Figure: throughput scaling for models with 5M, 143M, and 229M parameters]
Zhang et al., 2015, Zhang et al. 2017
Hybrid Communication • Discussion: utilizing SFs is actually not a new idea • Microsoft Adam uses the third strategy (c)
[Figure: three strategies: PS; SFB; (c) push SFs to the server, pull full matrices back]
Hybrid Communication • Adam's strategy leads to a communication bottleneck
• Pushing SFs to the server is fine • Pulling full matrices back creates a bottleneck on the server node
• Hybrid communication yields communication load balancing • which is important for addressing the problem of bursty communication
Introducing Poseidon • Poseidon: an efficient communication architecture • A distributed platform to amplify existing DL toolkits
[Figure: DL toolkits running on top of the Poseidon platform]
Poseidon’s position • Design principles • Efficient distributed platform for amplifying any DL toolkit • Preserve the programming interface of any high-level toolkit • i.e. distribute the DL program without changing any line of code
• Easy deployment, easy adoption
Poseidon System Architecture
[Figure: Poseidon components on GPU and CPU (KV Store, Syncer, SFB, Coordinator, Stream Pool, Thread Pool), connected by data flow, allocation, and instruction arrows]
Zhang et al., 2015, Zhang et al. 2017
Poseidon APIs • KV Store, Syncer and Coordinator • Standard APIs similar to a parameter server • Push/Pull API for parameter synchronization • BestScheme method to return the best communication method
Zhang et al., 2015, Zhang et al. 2017
Amplify DL toolboxes Using Poseidon • For developers: plug the Poseidon API into the backpropagation code; all you need to do is: • Back propagate through layer l • Sync the parameters of layer l • Wait for synchronization to finish
• Amplifying Google TensorFlow • 250 lines of code
• Amplifying Caffe • 150 lines of code
Zhang et al., 2015, Zhang et al. 2017
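The three-step integration pattern can be sketched with Python's `concurrent.futures`. Here `backward_layer` and `sync_layer` are hypothetical placeholders that only sleep; they are not Poseidon's actual Push/Pull API. Each layer's synchronization is submitted to a thread pool right after its backward step, so it overlaps with the backward computation of the lower layers, and the caller waits only once at the end.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def backward_layer(l):
    # Hypothetical stand-in for b_l: "compute" the gradients of layer l
    time.sleep(0.01)
    return f"grad_{l}"


def sync_layer(l, grad):
    # Hypothetical stand-in for pushing grad_l and pulling the new parameters of layer l
    time.sleep(0.01)
    return f"params_{l}"


def wait_free_backward(num_layers, pool):
    pending = []
    for l in range(num_layers, 0, -1):                    # b_L, ..., b_1
        grad = backward_layer(l)                          # 1. back propagate through layer l
        pending.append(pool.submit(sync_layer, l, grad))  # 2. sync layer l without blocking
    return [f.result() for f in pending]                  # 3. wait for all syncs to finish


with ThreadPoolExecutor(max_workers=4) as pool:
    synced = wait_free_backward(8, pool)
```

Because only the final wait blocks, the total time is roughly the backward time plus the last layer's sync, rather than the sum of all backward and sync steps.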
Using Poseidon • Poseidon: an efficient communication architecture • Preserve the programming interface of any high-level toolkit • i.e. distribute the DL program without changing any line of application code
[Figure: DL toolkits running on top of the Poseidon platform]
What is the Issue • Memory • GPUs have dedicated memory • For a DL training program to be efficient, its data must be placed in GPU memory • GPU memory is limited compared to CPU memory, e.g. 12 GB at most • Memcpy between CPU and GPU is expensive: a memcpy takes as long as launching a GPU computation kernel
• Problems to be answered • How to avoid memcpy overhead between CPU and GPU? • How to proceed with training a gigantic network with very limited available memory?
A Machine w/o GPU
[Figure: CPU cores, network NIC, local storage, and DRAM (CPU memory)]
A Machine w/ GPU
[Figure: CPU cores, network NIC, local storage, DRAM (CPU memory), and a GPU device with GPU cores and GPU memory (a few GB)]
• Small GPU memory • Expensive to copy between GPU and CPU memory
Machine Learning on GPU
[Figure: the input data file (training data) resides in CPU memory; each mini-batch of training data passes through staging memory into GPU memory, which holds the input data, intermediate data, and parameter data]
Deep Learning on GPU
[Figure: a training batch flows through the network to class probabilities (Eagle, Vulture, Osprey, Accipiter); GPU memory holds the parameters and the intermediate states]
Numbers
[Figure: training batch, parameters, and intermediate states in GPU memory; max available GPU memory: 12 GB]

| Network | Batch size | Input size | Parameters + grads | Intermediate states |
|---|---|---|---|---|
| AlexNet | 256 | 150 MB | <500 MB | 4.5 GB |
| GoogLeNet | 64 | 19 MB | <40 MB | 10 GB |
| VGG19 | 16 | 10 MB | <1.2 GB | 10.8 GB |
Why is Memory an Issue? • Intermediate states occupy 90% of the GPU memory • The size of the intermediate states is proportional to the input batch size
• However, • If you want high throughput, you must have a large batch size (because of the SIMD nature of GPUs) • If you have a large batch size, your GPU memory will be occupied by intermediate states, which limits your model size/depth
Saving Memory: A Simple Trick • Basic idea
• The fact: intermediate states are proportional to the batch size K • Idea: achieve a large batch size by accumulating gradients generated by smaller batches that fit in GPU memory
• Solution: • Partition K into M parts, each with K/M samples • For iter = 1:M • Train with mini-batch size K/M • Accumulate the gradient on the GPU without updating the model parameters
• Update the model parameters all together when all M parts have finished
• Drawbacks
• What if the GPU still cannot hold the intermediate states even if K = 1? • A small batch size usually leads to insufficient use of GPUs' computational capability
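The accumulation trick can be sketched in plain Python with a toy linear model and MSE loss (illustrative assumptions). Averaging the gradients of M equally sized micro-batches of size K/M reproduces the full-batch gradient, while only one micro-batch's intermediate state needs to exist at a time.

```python
def grad_mse(theta, batch):
    # Mean gradient of (w*x + b - y)^2 over a batch
    w, b = theta
    n = len(batch)
    return (sum(2 * (w * x + b - y) * x for x, y in batch) / n,
            sum(2 * (w * x + b - y) for x, y in batch) / n)


def accumulated_grad(theta, batch, M):
    """Gradient of a batch of size K computed as M micro-batches of size K/M."""
    K = len(batch)
    assert K % M == 0, "this sketch assumes M divides K"
    size = K // M
    acc = [0.0, 0.0]
    for m in range(M):                               # for iter = 1:M
        micro = batch[m * size:(m + 1) * size]       # only K/M samples in memory at once
        g = grad_mse(theta, micro)
        acc = [s + gi / M for s, gi in zip(acc, g)]  # accumulate; parameters untouched
    return tuple(acc)                                # apply once, after all M parts


batch = [(i / 10, 1.5 * (i / 10) - 0.5 + 0.1 * (-1) ** i) for i in range(16)]
theta = (0.3, 0.2)
full = grad_mse(theta, batch)
accumulated = accumulated_grad(theta, batch, M=4)
```

The equivalence holds because the loss is a mean over samples; for layers whose output depends on batch statistics (e.g. batch normalization) the micro-batch result would differ.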
Memory Management using CPU Memory • Core ideas • If memory is limited, trade something for memory • Trade extra computation for memory • Trade other costs (e.g. memory exchange) for more available memory
• If memory is limited, get more • model parallelism • CPU memory
Memory Management using CPU Memory
[Figure: training images flow through the network layers up to class probabilities]
• For each iteration (mini-batch) • A forward pass • Then a backward pass
• Each time, only the data of two layers are used
Cui et al., 2016
Memory Management using CPU Memory Class probabilities
• For each iteration (minibatch) • A forward pass • Then a backward pass
• Each time only data of two layers are used Training images
Cui et al., 2016
© Petuum,Inc. 270
Memory Management using CPU Memory Class probabilities
• For each iteration (minibatch) • A forward pass • Then a backward pass
• Each time only data of two layers are used Training images
Cui et al., 2016
© Petuum,Inc. 271
Memory Management using CPU Memory Class probabilities
• For each iteration (minibatch) • A forward pass • Then a backward pass
• Each time only data of two layers are used Training images
Cui et al., 2016
© Petuum,Inc. 272
Memory Management using CPU Memory Class probabilities
• For each iteration (minibatch) • A forward pass • Then a backward pass
• Each time only data of two layers are used Training images
Cui et al., 2016
© Petuum,Inc. 273
Memory Management using CPU Memory
[Figure: a network mapping training images to class probabilities]
• For each iteration (minibatch)
  • A forward pass
  • Then a backward pass
• At any time, only the data of two layers are in use
• The idea
  • Use GPU memory as a cache to keep actively used data
  • Store the remaining data in CPU memory
Cui et al., 2016
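The cache idea can be illustrated with a small simulation. Here GPU memory is modeled as a fixed number of block slots, each block is one layer's activations, and a least-recently-used (LRU) policy moves evicted blocks back to CPU memory. The block granularity, the LRU policy, and the access pattern (each step touches the activations of two layers) are simplifying assumptions for illustration, not details of Cui et al.

```python
from collections import OrderedDict

class GpuCache:
    """GPU memory as a cache of `capacity` blocks; evicted blocks go back to CPU memory."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.resident = OrderedDict()   # least recently used block first
        self.copies_in = 0              # CPU -> GPU transfers
        self.copies_out = 0             # GPU -> CPU transfers (every eviction counted)

    def access(self, block):
        if block in self.resident:
            self.resident.move_to_end(block)      # hit: now most recently used
            return
        if len(self.resident) >= self.capacity:
            self.resident.popitem(last=False)     # evict the LRU block to CPU memory
            self.copies_out += 1
        self.resident[block] = True               # copy the block in from CPU memory
        self.copies_in += 1

def run_iteration(cache, num_layers):
    """Forward then backward; each step touches the activations of two layers."""
    for i in range(num_layers):
        cache.access(f"act{i}")
        cache.access(f"act{i + 1}")
    for i in reversed(range(num_layers)):
        cache.access(f"act{i + 1}")
        cache.access(f"act{i}")

small, large = GpuCache(capacity=2), GpuCache(capacity=10)
run_iteration(small, num_layers=4)
run_iteration(large, num_layers=4)
print(small.copies_in, small.copies_out)   # 8 6
print(large.copies_in, large.copies_out)   # 5 0
```

Two slots are enough for this access pattern, at the price of moving data on most steps; with ten slots everything stays resident and nothing is swapped.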
Memory Management using CPU Memory
[Figure (Cui et al., 2016): CPU memory and GPU memory, holding the input data file (training data), a staging memory for the input data batch, input data, intermediate data, and parameters, connected by CPU/GPU data transfers]
• CPU/GPU data transfer is very expensive, sometimes more expensive than the computation itself
Memory Management using CPU Memory
[Figure (Cui et al., 2016): the same CPU memory/GPU memory layout with its CPU/GPU data transfers]
• A controller/scheduler can alleviate or hide this transfer overhead
Memory Management using CPU Memory
• Controller
  • The fact: the memory access order is deterministic and can be known exactly from a single forward and backward pass
  • Idea:
    • Obtain the memory access order by a virtual iteration
    • Pre-fetch memory blocks from CPU to GPU
    • Overlap the memory swap overhead with computation
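A minimal sketch of the controller idea, assuming a plain chain of layers: a "virtual iteration" records which blocks each step reads and writes without doing any arithmetic, and the recorded trace then tells us, while step t runs, which blocks step t + 1 will need that are not already on the GPU. The block names (act, param, grad) and the read/write sets are hypothetical simplifications, and eviction is omitted.

```python
from typing import List, Tuple

Step = Tuple[str, List[str], List[str]]   # (step name, blocks read, blocks written)

def virtual_iteration(num_layers: int) -> List[Step]:
    """Record the memory access order of one forward + backward pass (no computation)."""
    trace: List[Step] = []
    for i in range(num_layers):
        trace.append((f"fwd{i}", [f"act{i}", f"param{i}"], [f"act{i + 1}"]))
    trace.append(("loss", [f"act{num_layers}"], [f"grad{num_layers}"]))
    for i in reversed(range(num_layers)):
        trace.append((f"bwd{i}", [f"act{i}", f"param{i}", f"grad{i + 1}"], [f"grad{i}"]))
    return trace

def prefetch_schedule(trace: List[Step]) -> List[Tuple[str, List[str]]]:
    """While step t computes, prefetch what step t+1 reads that step t does not already hold."""
    schedule = []
    for t, (name, reads, writes) in enumerate(trace):
        on_gpu = set(reads) | set(writes)
        next_reads = trace[t + 1][1] if t + 1 < len(trace) else []
        schedule.append((name, [b for b in next_reads if b not in on_gpu]))
    return schedule

for step, fetch in prefetch_schedule(virtual_iteration(3)):
    print(step, "prefetch:", fetch)
```

The first step's own inputs (act0, param0) still have to be loaded before the iteration starts. Because the access order is the same every iteration, the trace only needs to be recorded once.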
Memory Management using CPU Memory
• What is the best we can do with this strategy?
  • We need only 3 memory blocks on the GPU (each of peak size) for:
    • Input, Parameters, Output
  • The whole training can proceed with ONLY these three blocks by
    • Scheduling memcpy between CPU and GPU so that it overlaps with computation
    • Moving data in and out for each layer's computation as training proceeds
[Figure: per-layer memory usage, with the peak marked]
Cui et al., 2016
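The buffer-pool size implied by this strategy is easy to compute: the peak is the largest per-layer sum of input, parameter, and output sizes, and double buffering (so the next layer's data can be copied in while the current layer computes) doubles it. The sketch below uses a hypothetical 20-layer chain in which every layer has a 100 MB input, 20 MB of parameters, and a 100 MB output, ignores gradients and optimizer state, and compares the pool with keeping every activation and parameter on the GPU at once.

```python
def peak_block_mb(layers):
    """Largest input + parameters + output over all layers (MB)."""
    return max(inp + par + out for inp, par, out in layers)

def pool_mb(layers, double_buffer=True):
    """GPU buffer pool: the peak, or twice the peak with double buffering."""
    peak = peak_block_mb(layers)
    return 2 * peak if double_buffer else peak

def everything_on_gpu_mb(layers):
    """All activations and parameters resident (layer i's output is layer i+1's input)."""
    return layers[0][0] + sum(par + out for _, par, out in layers)

layers = [(100, 20, 100)] * 20   # hypothetical (input, params, output) sizes in MB
print(peak_block_mb(layers), pool_mb(layers), everything_on_gpu_mb(layers))  # 220 440 2500
```

With these made-up sizes the double-buffered pool needs 440 MB against 2500 MB for keeping everything resident; the real ratio depends entirely on the network.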
Throughput vs. memory budget
[Figure: throughput as the GPU memory budget shrinks from all data in GPU memory to only the buffer pool in GPU memory (twice the peak size, for double buffering)]
• Only a 27% reduction in throughput while using 35% of the memory
• Can therefore handle problems about 3x bigger (1/0.35 ≈ 2.9) with little overhead
Cui et al., 2016
Larger models
• Models up to 20 GB
Cui et al., 2016
Summary
• Deep learning as dataflow graphs
  • Many auto-differentiation libraries have been developed to train NNs
  • They differ in adoption, advantages, and disadvantages
  • DyNet is a new framework for the next wave of dynamic NNs
• Difficulties arise when scaling up DL on distributed GPUs
  • Communication bottleneck
  • Memory limit
• Poseidon as a platform to support and amplify different kinds of DL toolboxes
Elements of Modern AI
[Figure: a layered stack, from Data and Task at the top down to Platform and Hardware]
Model
• Graphical Models • Large-Margin • Deep Learning • Sparse Coding • Nonparametric Bayesian Models • Regularized Bayesian Methods • Spectral/Matrix Methods • Sparse Structured I/O Regression
Algorithm
• Stochastic Gradient Descent / Back propagation • Coordinate Descent • L-BFGS • Gibbs Sampling • Metropolis-Hastings
Implementation
• Mahout (MapReduce) • MLlib (BSP) • CNTK • MXNet • TensorFlow (Async) • …
System
• Hadoop • Spark • MPI • RPC • GraphLab • …
Platform and Hardware
• Network switches • Infiniband • Network attached storage • Flash storage • Server machines • Desktops/Laptops • ARM-powered devices • Mobile devices • GPUs • RAM • Flash • SSD • IoT device networks • Cloud compute (e.g. Amazon EC2) • Virtual machines
Sys-Alg Co-design Inside!
[Figure: the same stack (Data, Task, Model, Algorithm, Implementation, System, Platform and Hardware), with an added layer labeled “Our ‘VML’ Software Layer”]
Better Performance
• Fast and Real-Time
  • Orders of magnitude faster than Spark and TensorFlow
  • As fast as hand-crafted systems
• Any Scale
  • Perfect straight-line speedup with more computing devices
  • Spark and TensorFlow can slow down with more devices
• Low Resource
  • Turning a regular cluster into a supercomputer:
    • Achieve AI results with much more data, but using fewer computing devices
    • Google Brain uses ~1000 machines whereas Petuum uses ~10 for the same job
[Figure: speedup and time taken (minutes) vs. number of GPU computers, comparing Spark, a hand-crafted system, and PetuumOS; annotated “Up to 20x faster deep learning vs TensorFlow” and “Up to 200x faster on some ML algorithms”]
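The "straight-line speedup" claim can be made precise with the usual definitions: speedup S(n) = T(1)/T(n) for n devices, and parallel efficiency E(n) = S(n)/n, so perfect straight-line scaling means E(n) = 1. The timings below are hypothetical and only illustrate the arithmetic; they are not Petuum's measurements.

```python
def speedup(t1, tn):
    """S(n) = T(1) / T(n)."""
    return t1 / tn

def efficiency(t1, tn, n):
    """E(n) = S(n) / n; 1.0 means perfect straight-line (linear) scaling."""
    return speedup(t1, tn) / n

# Hypothetical wall-clock times in minutes, keyed by number of devices
times = {1: 800.0, 2: 400.0, 4: 200.0, 8: 125.0}
for n, tn in times.items():
    print(f"{n} devices: speedup {speedup(times[1], tn):.1f}, efficiency {efficiency(times[1], tn, n):.2f}")
```

In this made-up example scaling is perfect up to 4 devices (efficiency 1.00) and drops to 0.80 at 8; a system that "slows down with more devices" would show T(n) rising, i.e., speedup falling, as n grows.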
A Petuum Vision
[Figure: the same layered stack (Data, Task, Model, Algorithm, Implementation, System, Platform and Hardware), annotated with Omni-Source (Any Data), Omni-Lingual (Any Programming Language), and Omni-Mount (Any Hardware)]
Model
• Graphical Models • Nonparametric Bayesian Models • Large-Margin • Deep Learning • Sparse Coding • Regularized Bayesian Methods • Spectral/Matrix Methods • Sparse Structured I/O Regression
Algorithm
• Stochastic Gradient Descent / Back propagation • Coordinate Descent • L-BFGS • Gibbs Sampling • Metropolis-Hastings
Implementation
• Mahout (MapReduce) • MLlib (BSP) • CNTK • MXNet • TensorFlow (Async) • …
System
• Hadoop • Spark • MPI • RPC • GraphLab • …
Platform and Hardware
• Network switches • Infiniband • Network attached storage • Flash storage • Server machines • Desktops/Laptops • ARM-powered devices • Mobile devices • GPUs • RAM • Flash • SSD • IoT device networks • Cloud compute (e.g. Amazon EC2) • Virtual machines