Deep Generative Models
STAT G8201: Deep Generative Models 1 / 34
Part III.2: adding (discrete) structure
Discrete structure
Second idea (from last time): hack reparameterization into a continuous relaxation:
- Generally, assume a parameterization z = g(φ, ε) with parameters φ and noise variable ε
- Continuous relaxation:
  z_τ = softmax(g(φ, ε)/τ),   lim_{τ→0} z_τ, ...
- Apparently this is differentiable, has a tractable likelihood, etc.
Then we used Gumbel variables:
- Things became clean and (more) closed form
- The parameters π (or α) had interpretable meaning (though apparently if I'm a deep learning zealot I don't care...)
- But what did we sacrifice? Does that matter?
Today: relaxations that play with this tension (on more elaborate discrete objects)
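The relaxation above can be sketched in a few lines of numpy (illustrative code; `gumbel_softmax` is a name chosen here, and g(φ, ε) is taken to be logits plus Gumbel noise, as in the Gumbel-softmax construction from last time):

```python
import numpy as np

def gumbel_softmax(logits, tau, rng):
    """z_tau = softmax(g(phi, eps) / tau), with g = logits + Gumbel noise.
    As tau -> 0 the sample approaches a one-hot vector."""
    eps = rng.uniform(size=logits.shape)
    g = logits - np.log(-np.log(eps))  # Gumbel(0, 1) noise added to logits
    e = np.exp((g - g.max()) / tau)    # temperature-scaled softmax
    return e / e.sum()
```

The output lies in the simplex and is differentiable in `logits`, which is exactly what lets reparameterization go through.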
Discrete structure
Haven’t we completed discrete structure?
- We learned how to use Gumbel to reparameterize to ∆^{n−1}
- Isn't that complete?
  - Permutations of m items are discrete objects of size n = m!
  - Subsets of m items of size k are discrete objects of size n = (m choose k)
- So, yes... but naive approaches will get intractable quickly
Some discrete objects have special structure
- Special structure can imply a geometry more convenient to relaxation
- Example: permutations
  - Write any permutation as an m×m permutation matrix
  - Observe all row and column sums equal 1
  - Suggests an O(m²) (or really O(m)) relaxation, vs m!
  - (But what do we lose?)
Today: relaxations that play with this special structure
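The permutation-matrix representation above can be checked in a couple of lines (illustrative numpy snippet; the permutation `perm` is an arbitrary example):

```python
import numpy as np

# A permutation of m = 4 items as a 4 x 4 permutation matrix: row i is the
# indicator of which item lands in position i, so P @ x reorders x, and
# every row and column sums to 1.
perm = [2, 0, 3, 1]
P = np.eye(4)[perm]
x = np.array([10.0, 20.0, 30.0, 40.0])
assert np.allclose(P @ x, x[perm])
assert np.allclose(P.sum(axis=0), 1.0)
assert np.allclose(P.sum(axis=1), 1.0)
```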
Updates
DGM website
- Lecture slides
- Presentation tex template
- Reading
- Presenter schedule
- Also for discussion: upcoming topics and papers (e.g., NLP day?)
Projects
- This week: please identify key publications (literature review)
- Do this in journal.md or begin a .bib in doc
- If you are still entirely uncertain (or are waiting), let's talk
Learning permutations: examples
Figure: Example 1: Input X̃_i → (mapping P) → Output X_i
Example 2: Input X̃_i = (1, 2, 3, 4, 5) → (mapping P) → Output X_i = (3.2, 4.6, 1.2, 4.9, 1.8)
- Goal: Infer the permutation mapping P : X → X.
- Method: Permutation regression
  X_i = P(X̃_i) + ε_i,
  where P is the latent variable.
Infer Permutation mapping P
Figure: Example 1: Input X̃_i → (mapping P) → Output X_i
- The permutation mapping P has only a finite # of possible values.
  ⇒ Discrete latent variable! (Each possible value is a category.)
- Finite # of possible values = N!
  ⇒ Too many categories! Cannot use Gumbel softmax. :-(
Permutation mappings
Figure: Permutation mapping as matrices
- Each permutation mapping P is a matrix multiplication.
- The matrices have structure:
  - Each entry is 0 or 1.
  - Row sum = column sum = 1.
Relaxing permutation mappings
Figure: Permutation mapping as matrices
- The permutation matrix has structure:
  - Each entry is 0 or 1.
  - Row sum = column sum = 1.
- Continuous relaxation (Birkhoff polytope; doubly stochastic matrices)
- Now only an N²-dimensional space to worry about!
Sampling from the relaxed space BN
Figure: The Birkhoff polytope: relaxed space for permutation matrices
- How to reparametrize the permutation matrix?
  → How to map arbitrary matrices into the Birkhoff polytope?
Solution                            Ease   Rigor   Empirical
Stick-breaking (Linderman et al.)   A−     B       B
Rounding (Linderman et al.)         A      C       B+
Sinkhorn network (Mena et al.)      B      A       A

Table: Three solutions
Stick-breaking I (Linderman et al.)
- Idea: fill in the entries one by one; watch out for the constraints
- First row is easy: row sum = 1
  x_{1n} = β_{1n} (1 − ∑_{k=1}^{n−1} x_{1k}),   n = 1, . . . , N − 1,   (1)
  x_{1N} = 1 − ∑_{n=1}^{N−1} x_{1n},   (2)
  where β_{1n} ∈ [0, 1] (e.g. a Beta r.v.)
Stick-breaking II (Linderman et al.)
- From the second row on it is tricky: row sum = col sum = 1.
- We can't sample too large: the remaining row and column entries must stay ≥ 0:
  x_{mn} ≤ 1 − ∑_{k=1}^{n−1} x_{mk},   x_{mn} ≤ 1 − ∑_{k=1}^{m−1} x_{kn}
- We can't sample too small: the remaining row entries are constrained by their columns too!
  1 − ∑_{k=1}^{n} x_{mk}  (remaining stick)  ≤  ∑_{j=n+1}^{N} (1 − ∑_{k=1}^{m−1} x_{kj})  (remaining upper bounds from columns)
Stick-breaking III (Linderman et al.)
- Constraints on each entry:
  x_{mn} ≤ 1 − ∑_{k=1}^{n−1} x_{mk},
  1 − ∑_{k=1}^{n} x_{mk}  (remaining stick)  ≤  ∑_{j=n+1}^{N} (1 − ∑_{k=1}^{m−1} x_{kj})  (remaining upper bounds from columns)
- Figure out the lower and upper bounds l_{mn} ≤ x_{mn} ≤ u_{mn}.
- We fill in x_{mn} = l_{mn} + β_{mn} (u_{mn} − l_{mn}), where β_{mn} ∈ [0, 1].
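The bounds above can be turned into a sampler directly (a sketch, not the paper's code; `stick_break_doubly_stochastic` is a name chosen here, and β is a given (N−1)×(N−1) array of [0, 1] stick-breaking variables):

```python
import numpy as np

def stick_break_doubly_stochastic(beta):
    """Fill an N x N matrix entry by entry (row-major) so that every row
    and column sums to 1, using beta[m, n] in [0, 1] for the free entries;
    the last column of each row and the whole last row are determined."""
    N = beta.shape[0] + 1
    X = np.zeros((N, N))
    for m in range(N - 1):
        for n in range(N - 1):
            # upper bound: remaining row budget and remaining column budget
            u = min(1 - X[m, :n].sum(), 1 - X[:m, n].sum())
            # lower bound: the rest of the row must fit under the remaining
            # column budgets sum_{j > n} (1 - sum_{k < m} x_{kj})
            l = max(0.0, 1 - X[m, :n].sum()
                          - (1 - X[:m, n + 1:].sum(axis=0)).sum())
            X[m, n] = l + beta[m, n] * (u - l)
        X[m, N - 1] = 1 - X[m, :N - 1].sum()      # last column entry
    X[N - 1, :] = 1 - X[:N - 1, :].sum(axis=0)    # last row is determined
    return X
```

Any β in [0, 1]^{(N−1)×(N−1)} maps to a point in the Birkhoff polytope, which is what makes the reparameterization work.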
Stick-breaking IV (Linderman et al.)
- Grades: Ease A−, Rigor B, Empirical B
- Why?
  - We are unconstrained in the first row but very constrained in the last row.
  - But all entries are born equal!
Rounding I (Linderman et al.)
Figure: The rounding solution
- Idea: Start with real matrices
  → Map to doubly stochastic matrices
  → Add noise
  → Round to a permutation matrix
- We only need the map to be differentiable.
Rounding II (Linderman et al.)
Figure: The rounding solution
- Step 1: Real matrices M ∈ R_+^{N×N}
- Step 2: Map M ∈ R_+^{N×N} to a doubly stochastic M̃
  - Goal: row sum = col sum = 1
  - Recursively normalize rows and columns
  - This step is differentiable.
Rounding III (Linderman et al.)
Figure: The rounding solution
- Step 3: Add noise to M̃: Ψ = M̃ + V · Z, where Z is some Gaussian noise. (Location/scale shift; differentiable)
- Step 4: Round: find the nearest permutation matrix
  - Can be done efficiently with the Hungarian algorithm.
  round(Ψ) = argmax_{P ∈ B_N} ⟨P, Ψ⟩_F,   (3)
  where ⟨A, B⟩_F = tr(AᵀB). (differentiable a.e.)
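Since (3) maximizes a linear function over B_N, the optimum is attained at a vertex, i.e. a permutation matrix. The paper uses the Hungarian algorithm; for small N a brute-force sketch (illustrative code, not the authors') makes the rounding step concrete:

```python
import itertools
import numpy as np

def round_to_permutation(psi):
    """Brute-force argmax_P <P, Psi>_F over permutation matrices.
    Fine for small N; the Hungarian algorithm does this in O(N^3)."""
    n = psi.shape[0]
    best_perm, best_val = None, -np.inf
    for perm in itertools.permutations(range(n)):
        val = psi[np.arange(n), perm].sum()  # <P, Psi>_F for this P
        if val > best_val:
            best_perm, best_val = perm, val
    P = np.zeros_like(psi)
    P[np.arange(n), best_perm] = 1.0
    return P
```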
Rounding IV (Linderman et al.)
Figure: The rounding solution
- Grades: Ease A, Rigor C, Empirical B+
- Why?
  - Rounding sounds hacky. It is not everywhere differentiable.
Sinkhorn Network I (Mena et al.)
Figure: The Sinkhorn network solution
- Idea: Start with real matrices
  → Map to doubly stochastic matrices
  → Smoothly round to a permutation matrix
- Similar to rounding but less hacky.
- Mimics what we did with softmax.
Sinkhorn Network I (Mena et al.)
Figure: The Sinkhorn network solution
- Step 1: Map real matrices to doubly stochastic
- Sinkhorn operator: recursively normalize rows and columns
  S⁰(X) = exp(X),   Sˡ(X) = τ_c(τ_r(Sˡ⁻¹(X))),   S(X) = lim_{l→∞} Sˡ(X)
- Cf. softmax (normalization + limit): lim_{τ→0} exp(x_i/τ) / ∑_j exp(x_j/τ)
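A truncated version of the Sinkhorn operator is a few lines of numpy (a sketch; in practice l is capped at a fixed iteration count rather than taking the limit):

```python
import numpy as np

def sinkhorn(X, n_iters=200):
    """Truncated Sinkhorn operator: S^0 = exp(X), then alternately
    normalize rows (tau_r) and columns (tau_c). The limit is a doubly
    stochastic matrix; each step is differentiable in X."""
    M = np.exp(X)
    for _ in range(n_iters):
        M = M / M.sum(axis=1, keepdims=True)  # tau_r: row normalization
        M = M / M.sum(axis=0, keepdims=True)  # tau_c: column normalization
    return M
```

Since every step is elementwise exponentiation and division, the whole map is differentiable, unlike the hard rounding of (3).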
Sinkhorn Network II (Mena et al.)
Figure: The Sinkhorn network solution
- Step 2: Smoothly round X to a permutation matrix
  S(X/τ) = argmax_{P ∈ B_N} ⟨P, X⟩_F + τ · h(P),
  where the entropy h(P) = −∑_{i,j} P_{ij} log(P_{ij}) and ⟨A, B⟩_F = tr(AᵀB).
- Cf. softmax-to-category: v* = argmax_{v ∈ S} ⟨x, v⟩
Sinkhorn Network III (Mena et al.)
Figure: Permutation and discrete optimal transport
- Step 2: Smoothly round X to a permutation matrix
  S(X/τ) = argmax_{P ∈ B_N} ⟨P, X⟩_F + τ · h(P)
- Rigor: M(X) = lim_{τ→0⁺} S(X/τ); differentiable (cf. rounding)
- Entropy regularization speeds up computation
- Connections to discrete optimal transport
Sinkhorn Network III (Mena et al.)
Figure: The Sinkhorn network solution
- Grades: Ease B, Rigor A, Empirical A
- Why?
  - We have a theorem.
  - The connection to optimal transport is cool.
Empirical comparison
Figure: Empirical performance of the solutions
- Sinkhorn ≈ Rounding > Stick-breaking
Differentiable Subset Sampling (DSS)
- Goal: "Extending" the Gumbel-max trick to subsets of {x_1, . . . , x_n} of size k.
- Issue: Naive use of the Gumbel-max trick results in a simplex with (n choose k) vertices. Not mentioned in the paper, but this is a key motivation for this work.
- Note: Sampling a permutation and discarding everything but the first k elements is equivalent to sampling a subset. This direction is not explored in the paper.
- Gumbel-max samples continuous noise and then applies a discrete transformation (argmax). The trick is relaxing the discrete transformation with a "similar" continuous one (softmax).
- A subset can be represented as an n-dimensional binary vector with exactly k ones: this can be relaxed to a vector with entries in (0, 1) whose sum is k.
DSS: Plackett-Luce
A random subset S = {x_{i_1}, . . . , x_{i_k}} follows a Plackett-Luce (PL) distribution with parameter w = (w_1, . . . , w_n) ∈ R_+^n if:

P(S) = (w_{i_1}/Z) · (w_{i_2}/(Z − w_{i_1})) · · · (w_{i_k}/(Z − ∑_{j=1}^{k−1} w_{i_j})),

where Z = ∑_{j=1}^{n} w_j. The following procedure samples from this distribution:

1. Sample u_1, . . . , u_n ∼ U(0, 1).
2. Compute r_i = u_i^{1/w_i}.
3. Output {x_{i_1}, . . . , x_{i_k}}, where i_1, . . . , i_k are the indices corresponding to the k largest values in (r_1, . . . , r_n).
Can we relax the last step?
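Steps 1-3 above can be sketched directly (illustrative code; `sample_pl_topk` is a name chosen here):

```python
import numpy as np

def sample_pl_topk(w, k, rng):
    """Sample a size-k subset from a Plackett-Luce distribution with
    weights w: compute r_i = u_i^(1/w_i) for u_i ~ U(0, 1) and keep
    the indices of the k largest r, in sampling order."""
    u = rng.uniform(size=len(w))
    r = u ** (1.0 / np.asarray(w, dtype=float))
    return np.argsort(-r)[:k]  # indices of the k largest r_i
```

For example, with w = (1, 9) and k = 1, item 2 should be chosen about 90% of the time.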
DSS: Why does the sampling work?
The full proof is not conceptually complicated but is cumbersome (citation 11). Consider the very simplified case where n = 2 and k = 1:

P(r_1 ≤ r_2) = ∫_0^1 F_{r_1}(t) f_{r_2}(t) dt = ∫_0^1 t^{w_1} · w_2 t^{w_2−1} dt = w_2 ∫_0^1 t^{w_1+w_2−1} dt = w_2 / (w_1 + w_2)
The full proof proceeds similarly, using induction.
DSS: Relaxing the last step
First, note that we only care about the ordering of (r_1, . . . , r_n), and − log(− log(·)) is an increasing transformation, so we can use r̂_i = − log(− log(r_i)) = − log(− log(u_i)) + log(w_i) instead of r_i, which is exactly Gumbel noise.
DSS: Relaxing the last step
First, we go through a non-differentiable random procedure with a temperature parameter t. As t → 0⁺, the procedure becomes deterministic and outputs the right answer. The procedure has k steps, and at each step one element is added to the subset.

- Input: r̂_i for i = 1, . . . , n (I think these should be shifted).
- Output: Binary vector a = (a_1, . . . , a_n) with exactly k ones.

1. Initialize α_i^1 := log(r̂_i) (paper typo, this is why I think they should be shifted).
2. For j = 1, . . . , k:
   - Sample a one-hot vector a^j = (a_1^j, . . . , a_n^j) from a categorical distribution with parameter softmax(α^j/t) (this is the element that will be added to the subset).
   - Update α_i^{j+1} = α_i^j + log(1 − a_i^j) for i = 1, . . . , n (this makes sure that we do not add the same element more than once).
3. Compute a = ∑_{j=1}^{k} a^j.

Finally, this procedure can be relaxed and made deterministic even for t > 0 by replacing a_i^j with its expectation, P(a_i^j = 1).
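The deterministic relaxation at the end (replacing a^j by its expectation) can be sketched as follows (illustrative code; following the reading above, α^1 is initialized to r̂ directly, and the clipping of p away from 1 is a numerical safeguard added here):

```python
import numpy as np

def relaxed_topk(r_hat, k, t):
    """Deterministic relaxed top-k: at each of k steps replace the one-hot
    sample a^j by its expectation softmax(alpha^j / t), then add
    log(1 - p) to the scores so chosen mass cannot be picked again."""
    alpha = np.asarray(r_hat, dtype=float).copy()
    a = np.zeros_like(alpha)
    for _ in range(k):
        e = np.exp((alpha - alpha.max()) / t)      # stable softmax(alpha/t)
        p = e / e.sum()
        a += p                                     # expectation of a^j
        alpha = alpha + np.log1p(-np.minimum(p, 1 - 1e-12))
    return a
```

The output always sums to k, and as t → 0⁺ it approaches the exact k-hot indicator of the top-k entries of r̂.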
DSS: Experiments
- Toy experiment to verify that, for a given w, the relaxed top-k procedure works.
- Variable selection problem: max_{S ∈ P_k} I(X_S, Y), which is relaxed to:
  max_E I(X_S, Y)   s.t.   S ∼ E,
  where E is a distribution over subsets (I believe the text has an error: E is not a subset). Since the mutual information cannot be computed, a lower bound is maximized instead.
- k "nearest" neighbors, where instead of using the k nearest neighbors, k neighbors are randomly chosen.
DSS: Some Issues
- Proposition 2 states that the relaxed procedure preserves the ordering of r̂ in a if t ≥ 1. While this is fine, the interesting behavior is when t → 0⁺. Also, this means that the largest k elements of r̂ and a match, but it does not mean that the k-hot vector with the coordinates of the k largest elements of r̂ is close to a.
- Figure 1 compares the subset recovered by the k largest elements of a for different values of t against the true distribution. By Proposition 2, when t ≥ 1, these distributions match. The only interesting thing to observe in the figure is that when t < 1, the distribution is not far from the correct one (this point is not made in the paper). Again, it would be interesting to compare a's against the k-hot vectors corresponding to subsets from the true distribution.
Stochastic Optimization of Sorting Networks via Continuous Relaxations
- The objects of interest are sorting permutations, i.e. given a vector s, sort(s) is the permutation that orders s in descending order.
- Instead of using sort(s), the permutation matrix P_sort(s) is used (analogous to one-hot or k-hot representations).
- Permutation matrices are relaxed to unimodal row-stochastic matrices, i.e., matrices U with non-negative elements, whose rows sum to one, and with the property that u_i = argmax_j U[i, j] forms a valid permutation u = (u_1, . . . , u_n).
- They show that, if A_s is a matrix containing the pairwise distances of the elements of s, then:
  P_sort(s)[i, j] = 1 if j = argmax[(n + 1 − 2i)s − A_s 1_n], and 0 otherwise,
  where 1_n is the all-ones vector.
- The above argmax is relaxed to a softmax (and a temperature is added). This results in unimodal row-stochastic matrices.
- This allows differentiating through sorting. It also provides a way to sample from PL distributions: P_sort(log s + g), where g is Gumbel noise (s is equivalent to w here).
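The relaxed operator can be sketched in numpy (illustrative code following the displayed formula, with A_s the pairwise-distance matrix and the row-wise argmax replaced by a tempered softmax):

```python
import numpy as np

def neural_sort(s, tau):
    """Relaxed sorting permutation: row i of P_sort(s) indicates the
    i-th largest entry of s; softmaxing ((n+1-2i)s - A_s @ 1) / tau
    row-wise yields a unimodal row-stochastic approximation."""
    s = np.asarray(s, dtype=float).reshape(-1)
    n = len(s)
    A = np.abs(s[:, None] - s[None, :])       # pairwise distances |s_i - s_j|
    c = n + 1 - 2 * np.arange(1, n + 1)       # coefficient n+1-2i, i = 1..n
    scores = c[:, None] * s[None, :] - A.sum(axis=1)[None, :]
    e = np.exp((scores - scores.max(axis=1, keepdims=True)) / tau)
    return e / e.sum(axis=1, keepdims=True)
```

At small τ, multiplying the result by s approximately sorts it in descending order, while staying differentiable in s.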
Final thoughts
- The reparametrization trick was first developed for continuous variables.
- By relaxing a discrete space into a continuous one in such a way that the discrete object can be easily recovered from the continuous one (e.g. vertices of the simplex), the reparametrization trick can still be used (e.g. Gumbel-max trick).
- If the discrete set is very large, the naive Gumbel-max trick might not be tractable, but if there is some structure (e.g. subsets and permutations), all hope is not lost.
- What about the discrete infinite case?
Final thoughts: A simple idea for the discrete infinite case
- β_1, β_2, . . . are noise samples in (0, 1).
- β̃_1, β̃_2, . . . are obtained from β by a stick-breaking procedure, so that ∑_{i=1}^∞ β̃_i = 1.
- The resulting sample s is given by s = (softmax(β̃_{:N}/τ), 0, 0, . . . ), where β̃_{:N} are the first N coordinates of β̃.
- Notes:
  1. The simplex relaxation does not work here, as we cannot store infinitely many numbers. It does work if we restrict ourselves to categorical distributions with finitely many non-zero probabilities (not the same as the finite case).
  2. Any sensible choice of N must involve β̃, but this can be ignored when differentiating. Some ideas are:
     - The smallest N s.t. ∑_{i=N+1}^∞ β̃_i < ε.
     - The smallest N s.t. the m-th largest value of β̃_{:N} is bigger than ∑_{i=N+1}^∞ β̃_i (this would ensure that the m largest values of β̃ are in β̃_{:N}).
  3. The stick-breaking part is needed so that N can be adequately chosen.
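The ε tail-bound truncation rule above can be sketched as follows (illustrative code; it assumes the finite noise array β is long enough that the cutoff N exists, and it returns only the first N softmax coordinates, with the trailing zeros left implicit):

```python
import numpy as np

def truncated_stick_softmax(beta, eps, tau):
    """Stick-break the noise beta into btilde (so the btilde sum to 1),
    truncate at the smallest N whose tail mass is below eps, and
    softmax the first N coordinates."""
    beta = np.asarray(beta, dtype=float)
    btilde = beta * np.concatenate(([1.0], np.cumprod(1.0 - beta)[:-1]))
    tail = 1.0 - np.cumsum(btilde)        # mass remaining after each index
    N = int(np.argmax(tail < eps)) + 1    # smallest N with tail mass < eps
    e = np.exp(btilde[:N] / tau)
    return e / e.sum(), N
```

For example, with constant β_i = 0.5 the stick weights are β̃_i = 0.5^i, so the tail after N coordinates is 0.5^N and ε = 0.01 gives N = 7.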