
Deep Generative Models

STAT G8201: Deep Generative Models 1 / 34


Part III.2: adding (discrete) structure

STAT G8201: Deep Generative Models 2 / 34


Discrete structure

Second idea (from last time): hack reparameterization into a continuous relaxation:

I Generally, assume a parameterization z = g(φ, ε) with parameters φ and noise variable ε

I Continuous relaxation:

z_τ = softmax(g(φ, ε) / τ),   lim_{τ→0} z_τ, ...

I Apparently this is differentiable, has a tractable likelihood, etc.

Then we used Gumbel variables:

I Things became clean and (more) closed form

I The parameters π (or α) had interpretable meaning, though apparently if I’m a deep learning zealot I don’t care...

I But what did we sacrifice? Does that matter?

Today: relaxations that play with this tension (on more elaborate discrete objects)
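A minimal numpy sketch of this relaxation for a single categorical variable, assuming the standard Gumbel-softmax form with logits φ (function and variable names are illustrative, not from the lecture):

```python
import numpy as np

def gumbel_softmax_sample(logits, tau, rng=np.random.default_rng()):
    """Relaxed sample z_tau = softmax(g(phi, eps) / tau) with Gumbel noise eps."""
    u = rng.uniform(size=logits.shape)
    eps = -np.log(-np.log(u))             # Gumbel(0, 1) noise via the inverse CDF
    g = logits + eps                       # g(phi, eps) for the categorical case
    z = np.exp((g - g.max()) / tau)        # numerically stabilized softmax
    return z / z.sum()                     # point in the simplex; one-hot as tau -> 0

logits = np.log(np.array([0.2, 0.5, 0.3]))
for tau in (1.0, 0.1, 0.01):
    print(tau, gumbel_softmax_sample(logits, tau).round(3))
```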

STAT G8201: Deep Generative Models 3 / 34


Discrete structure

Haven’t we completed discrete structure?

I We learned how to use Gumbel to reparameterize to ∆^{n−1}

I Isn’t that complete?

I permutations of m items are discrete objects of size n = m!
I subsets of size k from m items are discrete objects of size n = (m choose k)
I So, yes... but naive approaches will get intractable quickly

Some discrete objects have special structure

I Special structure can imply a geometry more convenient to relaxation

I Example: permutations

I write any permutation as an m×m permutation matrix
I observe all row and column sums equal 1
I suggests an O(m²) (or really O(m)) relaxation, vs m!
I (but what do we lose?)

Today: relaxations that play with this special structure

STAT G8201: Deep Generative Models 4 / 34


Updates

DGM website

I Lecture slides

I Presentation tex template

I Reading

I Presenter schedule

I Also for discussion: upcoming topics and papers (e.g., NLP day?)

Projects

I This week: please identify key publications (literature review)

I Do this in journal.md or begin a .bib in doc

I If you are still entirely uncertain (or are waiting), let’s talk

STAT G8201: Deep Generative Models 5 / 34


Learning permutations: examples

Figure: Example 1: Input X̃_i, mapping P → Output X_i

Example 2: Input X̃_i = (1, 2, 3, 4, 5), mapping P → Output X_i = (3.2, 4.6, 1.2, 4.9, 1.8)

I Goal: Infer the permutation mapping P : X → X .

I Method: Permutation regression

X_i = P(X̃_i) + ε_i,

where P is the latent variable.

STAT G8201: Deep Generative Models 6 / 34


Infer Permutation mapping P

Figure: Example 1: Input X̃_i, mapping P → Output X_i

I The permutation mapping P has only finite # of possible values.

⇒ Discrete latent variable! (Each possible value is a category.)

I Finite # of possible values = N!

⇒ Too many categories! Cannot use Gumbel-softmax. :-(

STAT G8201: Deep Generative Models 7 / 34


Permutation mappings

Figure: Permutation mapping as matrices

I Each permutation mapping P is a matrix multiplication.

I The matrices have structure:

I Each entry is 0 or 1.
I Row sum = column sum = 1.

STAT G8201: Deep Generative Models 8 / 34


Relaxing permutation mappings

Figure: Permutation mapping as matrices

I The permutation matrix has structure:

I Each entry is 0 or 1.
I Row sum = column sum = 1.

I Continuous relaxation: the Birkhoff polytope of doubly stochastic matrices

I Now only an N²-dimensional space to worry about!

STAT G8201: Deep Generative Models 9 / 34


Sampling from the relaxed space BN

Figure: The Birkhoff polytope: relaxed space for permutation matrices

I How to reparametrize the permutation matrix?

→ How to map arbitrary matrices into the Birkhoff polytope?

Solution                            Ease   Rigor   Empirical
Stick-breaking (Linderman et al.)   A−     B       B
Rounding (Linderman et al.)         A      C       B+
Sinkhorn network (Mena et al.)      B      A       A

Table: Three solutions

STAT G8201: Deep Generative Models 10 / 34


Stick-breaking I (Linderman et al.)

I Idea: fill in the entries one by one; watch out for the constraints

I First row is easy: row sum = 1

x_{1n} = β_{1n} (1 − Σ_{k=1}^{n−1} x_{1k}),   n = 2, . . . , N − 1     (1)

x_{1N} = 1 − Σ_{n=1}^{N−1} x_{1n},     (2)

where β_{1n} ∈ [0, 1] (e.g. a Beta r.v.)

STAT G8201: Deep Generative Models 11 / 34


Stick-breaking II (Linderman et al.)

I From the second row on it is tricky: row sum = col sum = 1.

I We can’t sample too large: the remaining row and column entries must stay nonnegative.

x_{mn} ≤ 1 − Σ_{k=1}^{n−1} x_{mk},     x_{mn} ≤ 1 − Σ_{k=1}^{m−1} x_{kn},

I We can’t sample too small: the remaining row entries are constrained by their columns too!

1 − Σ_{k=1}^{n} x_{mk}   ≤   Σ_{j=n+1}^{N} (1 − Σ_{k=1}^{m−1} x_{kj})

(left-hand side: remaining stick;  right-hand side: remaining upper bounds from the columns)

STAT G8201: Deep Generative Models 12 / 34


Stick-breaking III (Linderman et al.)

I Constraints on each entry:

x_{mn} ≤ 1 − Σ_{k=1}^{n−1} x_{mk},

1 − Σ_{k=1}^{n} x_{mk}   ≤   Σ_{j=n+1}^{N} (1 − Σ_{k=1}^{m−1} x_{kj})

(left-hand side: remaining stick;  right-hand side: remaining upper bounds from the columns)

I Figure out the lower and upper bounds l_{mn} ≤ x_{mn} ≤ u_{mn}.
I We fill in x_{mn} = l_{mn} + β_{mn}(u_{mn} − l_{mn}), where β_{mn} ∈ [0, 1] (see the sketch below).
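A minimal numpy sketch of this fill-in procedure, using uniform β_{mn} draws purely for illustration (the paper transforms Gaussian noise; all names here are illustrative):

```python
import numpy as np

def stick_breaking_doubly_stochastic(N, rng=np.random.default_rng(0)):
    """Fill an N x N matrix entry by entry so every row and column sums to 1."""
    X = np.zeros((N, N))
    for m in range(N - 1):
        for n in range(N - 1):
            row_used = X[m, :n].sum()              # mass already used in this row
            col_used = X[:m, n].sum()              # mass already used in this column
            # Upper bound: cannot exceed what is left in the row or in the column.
            upper = min(1.0 - row_used, 1.0 - col_used)
            # Lower bound: the rest of the row must fit under the remaining
            # column budgets 1 - sum_k x[k, j] for j > n.
            remaining_cols = (1.0 - X[:m, n + 1:].sum(axis=0)).sum()
            lower = max(0.0, 1.0 - row_used - remaining_cols)
            beta = rng.uniform()                   # beta_{mn} in [0, 1]
            X[m, n] = lower + beta * (upper - lower)
        X[m, N - 1] = 1.0 - X[m, :N - 1].sum()     # close the row
    X[N - 1, :] = 1.0 - X[:N - 1, :].sum(axis=0)   # last row closes every column
    return X

X = stick_breaking_doubly_stochastic(4)
print(X.round(3), X.sum(0).round(3), X.sum(1).round(3))
```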

STAT G8201: Deep Generative Models 13 / 34


Stick-breaking IV (Linderman et al.)

I Grades: Ease A-, Rigor B, Empirical B

I Why?

I We are unconstrained in the first row but very constrained in the last row.

I But all entries are born equal!

STAT G8201: Deep Generative Models 14 / 34


Rounding I (Linderman et al.)

Figure: The rounding solution

I Idea: Start with real matrices

→ Map to doubly stochastic matrices

→ Add noise

→ Round to permutation matrix

I We only need the map to be differentiable.

STAT G8201: Deep Generative Models 15 / 34


Rounding II (Linderman et al.)

Figure: The rounding solution

I Step 1: real matrices M ∈ R_+^{N×N}

I Step 2: map M ∈ R_+^{N×N} to a doubly stochastic M̃

I Goal: row sum = col sum = 1
I Recursively normalize rows and columns
I This step is differentiable.

STAT G8201: Deep Generative Models 16 / 34


Rounding III (Linderman et al.)

Figure: The rounding solution

I Step 3: Add noise to M̃: Ψ = M̃ + V · Z, where Z is some Gaussian noise. (Location/scale shift; differentiable)

I Step 4: Round: find the nearest permutation matrix

I Can be done efficiently with the Hungarian algorithm.

round(Ψ) = argmax_{P ∈ B_N} ⟨P, Ψ⟩_F     (3)

where ⟨A, B⟩_F = tr(AᵀB). (differentiable a.e.)

STAT G8201: Deep Generative Models 17 / 34
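A minimal sketch of the rounding step using SciPy's Hungarian-algorithm solver (the tooling choice and the noise scale are illustrative assumptions):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def round_to_permutation(Psi):
    """Solve argmax_P <P, Psi>_F over permutation matrices (Hungarian algorithm)."""
    # linear_sum_assignment minimizes cost, so negate to maximize the inner product.
    rows, cols = linear_sum_assignment(-Psi)
    P = np.zeros_like(Psi)
    P[rows, cols] = 1.0
    return P

# Toy use: perturb a doubly stochastic matrix and round it back.
M_tilde = np.full((3, 3), 1.0 / 3.0)
Psi = M_tilde + 0.1 * np.random.default_rng(0).normal(size=(3, 3))
print(round_to_permutation(Psi))
```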


Rounding IV (Linderman et al.)

Figure: The rounding solution

I Grades: Ease A, Rigor C, Empirical B+

I Why?

I Rounding sounds hacky. It is not everywhere differentiable.

STAT G8201: Deep Generative Models 18 / 34


Sinkhorn Network I (Mena et al.)

Figure: The Sinkhorn network solution

I Idea: Start with real matrices

→ Map to doubly stochastic matrices

→ Smoothly round to permutation matrix

I Similar to rounding but less hacky.

I Mimics what we did with softmax.

STAT G8201: Deep Generative Models 19 / 34


Sinkhorn Network I (Mena et al.)

Figure: The Sinkhorn network solution

I Step 1: Map real matrices to doubly stochastic

I Sinkhorn operator: recursively normalize rows and columns

S^0(X) = exp(X),   S^l(X) = τ_c(τ_r(S^{l−1}(X))),   S(X) = lim_{l→∞} S^l(X)

I Cf. softmax (normalization + limit): lim_{τ→0} exp(x_i/τ) / Σ_j exp(x_j/τ)
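A minimal numpy sketch of a truncated Sinkhorn operator, where a fixed iteration count stands in for the l → ∞ limit (names are illustrative):

```python
import numpy as np

def sinkhorn(X, n_iters=20):
    """Approximate S(X): alternate row and column normalization of exp(X)."""
    S = np.exp(X)
    for _ in range(n_iters):
        S = S / S.sum(axis=1, keepdims=True)   # tau_r: normalize rows
        S = S / S.sum(axis=0, keepdims=True)   # tau_c: normalize columns
    return S

X = np.random.default_rng(0).normal(size=(4, 4))
S = sinkhorn(X / 0.1)   # a small temperature pushes S toward a permutation matrix
print(S.round(2))
```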

STAT G8201: Deep Generative Models 20 / 34


Sinkhorn Network II (Mena et al.)

Figure: The Sinkhorn network solution

I Step 2: Smoothly round X to permutation matrix

S(X/τ) = argmax_{P ∈ B_N} ⟨P, X⟩_F + τ · h(P),

where the entropy h(P) = −Σ_{i,j} P_{ij} log(P_{ij}) and ⟨A, B⟩_F = tr(AᵀB).

I Cf. softmax-to-category: v* = argmax_{v ∈ S} ⟨x, v⟩

STAT G8201: Deep Generative Models 21 / 34


Sinkhorn Network III (Mena et al.)

Figure: Permutation and discrete optimal transport

I Step 2: Smoothly round X to permutation matrix

S(X/τ) = argmax_{P ∈ B_N} ⟨P, X⟩_F + τ · h(P),

I Rigor: M(X) = lim_{τ→0+} S(X/τ); differentiable (cf. rounding)
I Entropy regularization speeds up computation
I Connections to discrete optimal transport

STAT G8201: Deep Generative Models 22 / 34


Sinkhorn Network III (Mena et al.)

Figure: The Sinkhorn network solution

I Grades: Ease B, Rigor A, Empirical A

I Why?

I We have a theorem.
I The connection to optimal transport is cool.

STAT G8201: Deep Generative Models 23 / 34


Empirical comparison

Figure: Empirical performance of the solutions

I Sinkhorn ≈ Rounding > Stick-breaking

STAT G8201: Deep Generative Models 24 / 34


Differentiable Subset Sampling (DSS)

I Goal: “Extending” Gumbel-max trick to subsets of {x1, . . . , xn} of size k.

I Issue: Naive use of the Gumbel-max trick results in a simplex with (n choose k) vertices. Not mentioned in the paper, but this is a key motivation for this work.

I Note: Sampling a permutation and discarding everything but the first k elements is equivalent to sampling a subset. This direction is not explored in the paper.

I Gumbel-max samples continuous noise and then does a discrete transformation (argmax). The trick is relaxing the discrete transformation with a “similar” continuous one (softmax).

I A subset can be represented as an n-dimensional binary vector with exactly k ones: this can be relaxed to a vector with entries in (0, 1) whose sum is k.

STAT G8201: Deep Generative Models 25 / 34


DSS: Plackett-Luce

A random subset S = {x_{i1}, . . . , x_{ik}} follows a Plackett-Luce (PL) distribution with parameter w = (w_1, . . . , w_n) ∈ R_+^n if:

P(S) = (w_{i1} / Z) · (w_{i2} / (Z − w_{i1})) · · · (w_{ik} / (Z − Σ_{j=1}^{k−1} w_{ij})),

where Z = Σ_{j=1}^{n} w_j. The following procedure samples from this distribution:

1. Sample u_1, . . . , u_n ∼ U(0, 1).

2. Compute r_i = u_i^{1/w_i}.

3. Output {x_{i1}, . . . , x_{ik}}, where i_1, . . . , i_k are the indices corresponding to the k largest values in (r_1, . . . , r_n).
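A minimal numpy sketch of this sampling procedure (names are illustrative):

```python
import numpy as np

def sample_pl_subset(w, k, rng=np.random.default_rng()):
    """Sample a size-k subset from the Plackett-Luce distribution with weights w."""
    w = np.asarray(w, dtype=float)
    u = rng.uniform(size=w.shape)
    r = u ** (1.0 / w)                 # r_i = u_i^(1 / w_i)
    return np.argsort(-r)[:k]          # indices of the k largest r_i

print(sample_pl_subset(w=[3.0, 1.0, 1.0, 5.0], k=2))
```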

Can we relax the last step?

STAT G8201: Deep Generative Models 26 / 34


DSS: Why does the sampling work?

The full proof is not conceptually complicated but is cumbersome (citation 11). Consider the very simplified case where n = 2 and k = 1:

P(r_1 ≤ r_2) = ∫_0^1 F_{r_1}(t) f_{r_2}(t) dt = ∫_0^1 t^{w_1} · w_2 t^{w_2−1} dt = w_2 ∫_0^1 t^{w_1+w_2−1} dt = w_2 / (w_1 + w_2)

The full proof proceeds similarly, using induction.
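A quick Monte Carlo check of this n = 2, k = 1 identity, for illustration:

```python
import numpy as np

# Empirically verify P(r1 <= r2) = w2 / (w1 + w2) for r_i = u_i^(1/w_i).
rng = np.random.default_rng(0)
w1, w2, n = 2.0, 3.0, 200_000
r1 = rng.uniform(size=n) ** (1.0 / w1)
r2 = rng.uniform(size=n) ** (1.0 / w2)
print((r1 <= r2).mean(), w2 / (w1 + w2))   # both should be close to 0.6
```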

STAT G8201: Deep Generative Models 27 / 34


DSS: Relaxing the last step

First, note that we only care about the ordering of (r_1, . . . , r_n) and −log(−log(·)) is an increasing transformation, so we can use r̂_i = −log(−log(r_i)) = −log(−log(u_i)) + log(w_i) instead of r_i, which is exactly Gumbel noise shifted by log(w_i).

STAT G8201: Deep Generative Models 28 / 34


DSS: Relaxing the last step

First, we go through a non-differentiable random procedure with a temperature parameter t. As t → 0+, the procedure becomes deterministic and outputs the right answer. The procedure has k steps, and at each step, one element is added to the subset.

I Input: r̂_i for i = 1, . . . , n (I think these should be shifted).

I Output: Binary vector a = (a_1, . . . , a_n) with exactly k ones.

1. Initialize α_i^1 := log(r̂_i) (paper typo; this is why I think they should be shifted).

2. For j = 1, . . . , k:

I Sample a one-hot vector a^j = (a^j_1, . . . , a^j_n) from a categorical distribution with parameter softmax(α^j / t) (this is the element that will be added to the subset).

I Update α^{j+1}_i = α^j_i + log(1 − a^j_i) for i = 1, . . . , n (this makes sure that we do not add the same element more than once).

3. Compute a = Σ_{j=1}^{k} a^j.

Finally, this procedure can be relaxed and made deterministic even for t > 0 by replacing a^j_i with its expectation, P(a^j_i = 1).
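A minimal numpy sketch of the relaxed, deterministic version, replacing each a^j by its expectation softmax(α^j / t) (it assumes positive r̂_i, per the shifting comment above; names are illustrative):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def relaxed_top_k(r_hat, k, t):
    """Deterministic relaxation: k rounds of softmax selection with soft removal."""
    alpha = np.log(r_hat)                 # alpha^1 (assumes positive r_hat, see above)
    a = np.zeros_like(alpha)
    for _ in range(k):
        p = softmax(alpha / t)            # expectation of the one-hot sample a^j
        a += p
        alpha = alpha + np.log(1.0 - p)   # soft analogue of removing the chosen element
    return a                              # entries in (0, 1), summing to k

r_hat = np.array([3.0, 1.0, 0.5, 5.0, 2.0])
print(relaxed_top_k(r_hat, k=2, t=0.1).round(3))  # concentrates on the two largest entries
```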

STAT G8201: Deep Generative Models 29 / 34


DSS: Experiments

I Toy experiment to verify that, for a given w, the relaxed top-k procedure works.

I Variable selection problem: max_{S ⊆ P_k} I(X_S, Y), which is relaxed to:

max_E I(X_S, Y)   s.t.   S ∼ E,

where E is a distribution over subsets (I believe the text has an error, E is not a subset). Since the mutual information cannot be computed, a lower bound is maximized instead.

I k “nearest” neighbors, where instead of using the k nearest neighbors, k neighbors are randomly chosen.

STAT G8201: Deep Generative Models 30 / 34


DSS: Some Issues

I Proposition 2 states that the relaxed procedure preserves the ordering of r̂ in a if t ≥ 1. While this is fine, the interesting behavior is when t → 0+. Also, this means that the largest k elements of r̂ and a match, but it does not mean that the k-hot vector with the coordinates of the k largest elements of r̂ is close to a.

I Figure 1 compares the subset recovered by the k largest elements of a for different values of t against the true distribution. By Proposition 2, when t ≥ 1, these distributions match. The only interesting thing to observe in the figure is that when t < 1, the distribution is not far from the correct one (this point is not made in the paper). Again, it would be interesting to compare a’s against the k-hot vectors corresponding to subsets from the true distribution.

STAT G8201: Deep Generative Models 31 / 34


Stochastic Optimization of Sorting Networks via Continuous Relaxations

I The objects of interest are sorting permutations, i.e. given a vector s, sort(s) is the permutation that orders s in descending order.

I Instead of using sort(s), the permutation matrix P_sort(s) is used (analogous to one-hot or k-hot representations).

I Permutation matrices are relaxed to unimodal row stochastic matrices, i.e., matrices U with non-negative elements, whose rows sum to one, and with the property that u_i = argmax_j U[i, j] forms a valid permutation u = (u_1, . . . , u_n).

I They show that, if A_s is the matrix containing the pairwise distances of the elements of s, then:

P_sort(s)[i, j] = 1 if j = argmax[(n + 1 − 2i) s − A_s 1_n], and 0 otherwise,

where 1_n denotes the all-ones vector.

I The above argmax is relaxed to a softmax (and a temperature is added). This results in unimodal row stochastic matrices.

I This allows differentiating through sorting. It also provides a way to sample from PL distributions: P_sort(log s + g), where g is Gumbel noise (s is equivalent to w here).
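A minimal numpy sketch of this softmax relaxation of P_sort(s), assuming the (n + 1 − 2i)s − A_s·1_n form above (names are illustrative):

```python
import numpy as np

def relaxed_sort_matrix(s, tau):
    """Unimodal row-stochastic relaxation of the sorting permutation matrix."""
    s = np.asarray(s, dtype=float)
    n = len(s)
    A = np.abs(s[:, None] - s[None, :])            # pairwise distances A_s
    i = np.arange(1, n + 1)[:, None]               # row index i = 1, ..., n
    scores = (n + 1 - 2 * i) * s[None, :] - A.sum(axis=1)[None, :]  # (n+1-2i)s - A_s 1
    e = np.exp((scores - scores.max(axis=1, keepdims=True)) / tau)
    return e / e.sum(axis=1, keepdims=True)        # row-wise softmax

s = np.array([1.4, -0.3, 2.2, 0.7])
P = relaxed_sort_matrix(s, tau=0.05)
print(P.round(2))            # near a permutation matrix
print((P @ s).round(2))      # approximately s in descending order as tau -> 0
```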

STAT G8201: Deep Generative Models 32 / 34


Final thoughts

I The reparametrization trick was first developed for continuous variables.

I By relaxing a discrete space into a continuous one in such a way that the discrete object can be easily recovered from the continuous one (e.g. vertices on the simplex), the reparametrization trick can still be used (e.g. Gumbel-max trick).

I If the discrete set is very large, the naive Gumbel-max trick might not be tractable, but if there is some structure (e.g. subsets and permutations), all hope is not lost.

I What about the discrete infinite case?

STAT G8201: Deep Generative Models 33 / 34


Final thoughts: A simple idea for the discrete infinite case

I β1, β2, . . . are noise samples in (0, 1).

I β̃_1, β̃_2, . . . are obtained from β by a stick-breaking procedure, so that Σ_{i=1}^{∞} β̃_i = 1.

I The resulting sample s is given by s = (softmax(β̃_{:N} / τ), 0, 0, . . .), where β̃_{:N} are the first N coordinates of β̃.

I Notes:

1. The simplex relaxation does not work here, as we cannot store infinitely many numbers. If we restrict ourselves to categorical distributions with finitely many non-zero probabilities it works though (not the same as the finite case).

2. Any sensible choice of N must involve β̃, but this can be ignored when differentiating; some ideas are:

I The smallest N s.t. Σ_{i=N+1}^{∞} β̃_i < ε.

I The smallest N s.t. the m-th largest value of β̃_{:N} is bigger than Σ_{i=N+1}^{∞} β̃_i (this would ensure that the m largest values of β̃ are in β̃_{:N}).

3. The stick-breaking part is needed so that N can be adequately chosen.
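A minimal numpy sketch of this idea with uniform β noise and the first choice of N above (the smallest N with leftover mass below ε; all choices here are illustrative):

```python
import numpy as np

def truncated_stick_sample(tau=0.1, eps=1e-3, rng=np.random.default_rng(0)):
    """Stick-break beta noise until the leftover mass < eps, then softmax the sticks."""
    sticks, remaining = [], 1.0
    while remaining >= eps:
        beta = rng.uniform()                  # noise sample beta_i in (0, 1)
        sticks.append(beta * remaining)       # beta_tilde_i
        remaining *= (1.0 - beta)             # mass left after i sticks
    b = np.array(sticks)                      # beta_tilde_{:N}; trailing entries are ~0
    z = np.exp(b / tau)
    return z / z.sum()                        # relaxed sample on a finite-dimensional simplex

s = truncated_stick_sample()
print(len(s), s.round(3))
```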

STAT G8201: Deep Generative Models 34 / 34