Markov Chain Monte Carlo Inference
Melih Kandemir
Heidelberg Grad Days 2019
Lecture 3


Monte Carlo Integration

The big question: evaluate

E_{p(z)}[f(z)] = ∫ f(z) p(z) dz.

Examples:
- Bayesian prediction:

  p(z_new | D) = ∫ p(z_new | θ) p(θ | D) dθ = E_{p(θ|D)}[p(z_new | θ)]

- Difficult variational updates: log q(z_1) ← E_{p(z_2)}[log p(z_1, z_2)]
- Difficult E-step in EM: Q(θ, θ_old) = E_{p(z|D,θ_old)}[log p(z, D | θ)]


Approximating the integral by samples

E_{p(z)}[f(z)] = ∫ f(z) p(z) dz ≈ (1/L) ∑_{l=1}^{L} f(z^{(l)}),

where the z^{(l)} are samples drawn from p(z).

As long as i.i.d. samples are drawn from the true p(z), roughly 20 samples are sufficient for a good approximation.
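A minimal sketch of this estimator in Python (my illustration, not from the slides), taking f(z) = z² under a standard normal p(z), so the exact answer is E[z²] = 1:

```python
import numpy as np

rng = np.random.default_rng(0)

# Estimate E_{p(z)}[f(z)] for f(z) = z^2 under p(z) = N(0, 1);
# the exact answer is E[z^2] = 1.
L = 20
z = rng.standard_normal(L)      # i.i.d. samples z^(l) ~ p(z)
estimate = (z ** 2).mean()      # (1/L) * sum_l f(z^(l))
print(estimate)                 # already close to 1 with ~20 samples
```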


Sampling from the inverse CDF¹

Draw u ∼ Uniform(0, 1) and calculate y = h⁻¹(u), where h is the CDF of the target distribution.

This works because Pr(h⁻¹(u) ≤ y) = Pr(u ≤ h(y)) = h(y).

Problem: How do we compute h⁻¹(u) for an arbitrary distribution?

¹ Bishop, PRML, 2006
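When the CDF inverts in closed form, the recipe is two lines. A sketch (my example; the exponential distribution is an assumed target, not from the slides):

```python
import numpy as np

rng = np.random.default_rng(0)

# Inverse-CDF sampling for Exponential(lam): h(y) = 1 - exp(-lam * y),
# hence h^{-1}(u) = -log(1 - u) / lam.
lam = 2.0
u = rng.uniform(0.0, 1.0, size=10_000)   # u ~ Uniform(0, 1)
y = -np.log1p(-u) / lam                  # y = h^{-1}(u)
print(y.mean())                          # ~ 1 / lam = 0.5
```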


Rejection Sampling²

Target distribution p(z) and envelope distribution q(z), with constant k chosen so that kq(z) ≥ p(z) everywhere.

Procedure:
- z^{(t)} ∼ q(z)
- u^{(t)} ∼ Uniform(0, kq(z^{(t)}))
- Accept the sample if u^{(t)} ≤ p(z^{(t)})

p(accept) = ∫ (p(z) / (kq(z))) q(z) dz = (1/k) ∫ p(z) dz

² Bishop, PRML, 2006
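A sketch of the procedure (the mixture target, the Gaussian envelope, and the constant k = 8 are all illustrative choices, not from the slides):

```python
import numpy as np

rng = np.random.default_rng(0)

def p(z):
    # Unnormalized target: a two-component Gaussian mixture.
    return 0.6 * np.exp(-0.5 * (z - 1.0) ** 2) + 0.4 * np.exp(-0.5 * (z + 2.0) ** 2)

def q(z):
    # Envelope density: a broad zero-mean Gaussian with std 3.
    return np.exp(-0.5 * (z / 3.0) ** 2) / (3.0 * np.sqrt(2.0 * np.pi))

k = 8.0                                  # chosen so that k * q(z) >= p(z) everywhere
samples = []
while len(samples) < 1_000:
    z = rng.normal(0.0, 3.0)             # z^(t) ~ q(z)
    u = rng.uniform(0.0, k * q(z))       # u^(t) ~ Uniform(0, k q(z^(t)))
    if u <= p(z):                        # accept with probability p(z) / (k q(z))
        samples.append(z)
```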


Adaptive Rejection Sampling³

The envelope is built from piecewise exponential functions:

q(z) = k_i λ_i exp{−λ_i (z − z_{i−1})},   z_{i−1} < z ≤ z_i.

Each rejected sample is added as a grid point.

The acceptance rate decays exponentially with dimensionality!

³ Bishop, PRML, 2006


Importance Sampling (1)

E_{p(z)}[f(z)] = ∫ f(z) p(z) dz = ∫ f(z) (p(z)/q(z)) q(z) dz

Draw L samples from q(z). Then

E_{p(z)}[f(z)] ≈ (1/L) ∑_{l=1}^{L} f(z^{(l)}) · p(z^{(l)})/q(z^{(l)}),

where the ratio p(z^{(l)})/q(z^{(l)}) is called the importance weight.
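A sketch of the estimator (the target, the proposal, and f are my illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(0)

# Estimate E_{p(z)}[z^2] for p = N(0, 1) (exact value 1)
# using samples from a wider proposal q = N(0, 2^2).
L = 5_000
z = rng.normal(0.0, 2.0, size=L)                 # z^(l) ~ q(z)
log_p = -0.5 * z ** 2                            # shared constants cancel in the ratio
log_q = -0.5 * (z / 2.0) ** 2 - np.log(2.0)
w = np.exp(log_p - log_q)                        # importance weights p(z)/q(z)
print(np.mean(w * z ** 2))                       # ~ 1
```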


Importance Sampling (2)

- (+) All samples are retained.
- (−) Depends heavily on how similar q(z) is to p(z).
- (−) No diagnostic measures are available!


Markov Chain Monte Carlo

- Robust to high dimensionality.
- Samples form a Markov chain with a transition function T(z|z′).
- Samples are drawn from the target distribution p(z) if
  - p(z) is invariant wrt T(z|z′),

    p(z) = ∫ p(z′) T(z|z′) dz′,

  - the Markov chain governed by T(z|z′) is ergodic.
- Invariance: ensured by detailed balance,

  p(z) T(z′|z) = p(z′) T(z|z′).

- Ergodicity: more tricky; imposed by the sampling algorithms.


Metropolis-Hastings

Procedure:
- Propose the next state by Q(z′|z), e.g. N(z, σ²).
- Accept with probability min(1, p(z′)Q(z|z′) / (p(z)Q(z′|z))).
- Otherwise stay at the current state (add another copy of it to the sample list).

The proposal variance σ² is very influential (a sketch in code follows):
- It determines the step size.
- If too large: low acceptance rate.
- If too small: slow convergence.
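A sketch of the random-walk variant (the standard normal target and σ = 2.5 are illustrative choices, not from the slides):

```python
import numpy as np

rng = np.random.default_rng(0)

def log_p(z):
    # Unnormalized log target; a standard normal for illustration.
    return -0.5 * z ** 2

def metropolis_hastings(n_samples, sigma=1.0, z=0.0):
    samples = []
    for _ in range(n_samples):
        z_prop = rng.normal(z, sigma)    # proposal Q(z'|z) = N(z, sigma^2)
        # Q is symmetric, so the Hastings ratio reduces to p(z')/p(z).
        if np.log(rng.uniform()) < log_p(z_prop) - log_p(z):
            z = z_prop                   # accept
        samples.append(z)                # otherwise keep another copy of z
    return np.array(samples)

samples = metropolis_hastings(10_000, sigma=2.5)
print(samples.mean(), samples.std())     # ~ 0 and ~ 1
```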


Metropolis-Hastings (2)

Detailed balance holds:

p(z) T(z′|z) = p(z) Q(z′|z) min(1, p(z′)Q(z|z′) / (p(z)Q(z′|z)))
             = min(p(z)Q(z′|z), p(z′)Q(z|z′))
             = p(z′) Q(z|z′) min(p(z)Q(z′|z) / (p(z′)Q(z|z′)), 1)
             = p(z′) T(z|z′).


Metropolis-Hastings (3)⁴

1-D demo (figure omitted).

⁴ Murray, MLSS, 2009


Gibbs Sampling

Procedure:
- Initialize z₁^{(1)}, z₂^{(1)}, z₃^{(1)}.
- For l = 1 to L − 1 (a code sketch follows):
  - z₁^{(l+1)} ∼ p(z₁ | z₂^{(l)}, z₃^{(l)})
  - z₂^{(l+1)} ∼ p(z₂ | z₁^{(l+1)}, z₃^{(l)})
  - z₃^{(l+1)} ∼ p(z₃ | z₁^{(l+1)}, z₂^{(l+1)})
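A sketch with two variables instead of three (the bivariate Gaussian target and ρ = 0.8 are illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(0)

# Gibbs sampling for a bivariate Gaussian with unit variances and
# correlation rho, where z1 | z2 ~ N(rho * z2, 1 - rho^2) and vice versa.
rho, L = 0.8, 5_000
z1, z2 = 0.0, 0.0
samples = np.empty((L, 2))
for l in range(L):
    z1 = rng.normal(rho * z2, np.sqrt(1.0 - rho ** 2))   # z1 ~ p(z1 | z2)
    z2 = rng.normal(rho * z1, np.sqrt(1.0 - rho ** 2))   # z2 ~ p(z2 | z1)
    samples[l] = z1, z2

print(np.corrcoef(samples.T)[0, 1])                       # ~ 0.8
```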


Gibbs Sampling (2)

- Invariance: all conditioned variates stay constant by definition, and the remaining variable is sampled from its true conditional distribution.
- Ergodicity: guaranteed if all conditional probabilities are non-zero on their entire domain.
- Gibbs sampling is a special case of Metropolis-Hastings with q_k(z′|z) = p(z_k | z_{\k}); since the proposal leaves z_{\k} unchanged (z′_{\k} = z_{\k}),

  A(z′|z) = [p(z′_k | z′_{\k}) p(z′_{\k}) p(z_k | z′_{\k})] / [p(z_k | z_{\k}) p(z_{\k}) p(z′_k | z_{\k})] = 1.

Hence, all samples are accepted.


Gibbs Sampling (3)⁵

The step size is governed by the covariances of the conditional distributions.

Iterative conditional modes: instead of sampling, update wrt a point estimate (e.g. mean, mode).

⁵ Bishop, PRML, 2006


Collapsed Gibbs Sampling

Integrating out some of the variables may render others conditionally independent, which entails faster convergence.

Rao-Blackwell theorem: Let z and θ be dependent variables, and let f(z, θ) be some scalar function. Then

var_{z,θ}[f(z, θ)] ≥ var_z[E_θ[f(z, θ) | z]].


Example: Gaussian Mixture Model⁶

Employ conjugate priors on:
- cluster means
- cluster covariances
- mixture probabilities

Then integrate them out!

⁶ Murphy, Machine Learning: A Probabilistic Perspective, 2012


Implementation tricks

- Thinning: take every Kth sample, to decorrelate.
- Burn-in: discard the first portion (e.g. half) of the samples, which were drawn before the chain mixed.
- Multiple runs: neutralize the effect of initialization.

A sketch of the first two tricks follows.
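The burn-in fraction and thinning interval below are arbitrary defaults, not prescriptions from the slides:

```python
import numpy as np

def postprocess(chain, burn_in=0.5, thin=10):
    """Discard the first `burn_in` fraction, then keep every `thin`-th sample."""
    chain = np.asarray(chain)
    return chain[int(len(chain) * burn_in)::thin]

# e.g. kept = postprocess(samples, burn_in=0.5, thin=10)
```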


Diagnosing Convergence 1: Traceplots


Diagnosing Convergence 2: Running mean plots


Diagnosing Conv. 3: Gelman-Rubin Metric

- Calculate the within-chain variance W and the between-chain variance B.
- Calculate the estimated variance

  V̂ar(θ) = (1 − 1/n) W + (1/n) B.

- Calculate and monitor the Potential Scale Reduction Factor (PSRF),

  R = sqrt(V̂ar(θ) / W).

- R should shrink toward 1 as the chains converge (a code sketch follows).
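A sketch of the PSRF computation for scalar chains (the array layout is my assumption):

```python
import numpy as np

def gelman_rubin(chains):
    """PSRF for an (m, n) array: m parallel chains with n samples each."""
    chains = np.asarray(chains)
    m, n = chains.shape
    W = chains.var(axis=1, ddof=1).mean()        # within-chain variance
    B = n * chains.mean(axis=1).var(ddof=1)      # between-chain variance
    var_hat = (1.0 - 1.0 / n) * W + B / n        # estimated variance
    return np.sqrt(var_hat / W)                  # R; approaches 1 at convergence
```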


Diagnosing Convergence 4: Other metrics

- Geweke diagnostic: take the first x and last y samples of the chain and test whether they come from the same distribution.
- Raftery-Lewis diagnostic: calculate the number of iterations needed until a desired level of accuracy is reached for a posterior quantile.
- Heidelberg-Welch diagnostic: repeated significance testing of the null hypothesis that the chain is stationary.


Example: Bayesian logistic regression

p(f_i | w, x_i) = N(f_i | wᵀx_i, σ²),   i = 1, …, N
p(y_i | f_i) = 1 / (1 + e^{−f_i y_i}),   i = 1, …, N
p(w_d | α_d) = N(w_d | 0, α_d⁻¹),   d = 1, …, D
p(α_d) = G(α_d | a, b),   d = 1, …, D


Let's aim for a Gibbs sampler

We require the following conditional distributions:

p(w | f, α, X, y),   (1)
p(α | w, f, X, y),   (2)
p(f | w, α, X, y).   (3)


The log joint

log p(w, f, α, X, y) = ∑_{i=1}^{N} log p(f_i | w, x_i) + ∑_{i=1}^{N} log p(y_i | f_i)
                     + ∑_{d=1}^{D} log p(w_d | α_d) + ∑_{d=1}^{D} log p(α_d)

= −(1/2) log |σ²I| − (1/(2σ²)) (f − Xw)ᵀ(f − Xw)
  − ∑_{i=1}^{N} log(1 + e^{−y_i f_i}) + (1/2) ∑_{d=1}^{D} log α_d − (1/2) wᵀAw
  + ∑_{d=1}^{D} (a − 1) log α_d − ∑_{d=1}^{D} b α_d + const,

where A_{dd} = α_d and A_{ij} = 0 for i ≠ j.


The conditionals

p(α_d | α_{−d}, w, f, X, y) = G(α_d | a + 1/2, b + (1/2) w_d²)

p(w | f, α, X, y) = N(w | (XᵀX + A)⁻¹ Xᵀf, (XᵀX + A)⁻¹)

p(f | w, α, X, y): not available in closed form; sample with Metropolis using the proposal q(f_i) = N(f_i | wᵀx_i, σ²).

A sketch of one full Gibbs sweep follows.
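The sketch below takes the conditionals above at face value; the hyperparameter defaults are placeholders, not values from the slides:

```python
import numpy as np

rng = np.random.default_rng(0)

def gibbs_sweep(w, alpha, f, X, y, a=1.0, b=1.0, sigma2=1.0):
    """One sweep over the conditionals above; X is (N, D), y is in {-1, +1}."""
    N, D = X.shape
    # alpha_d | w_d ~ Gamma(a + 1/2, rate = b + w_d^2 / 2)
    alpha = rng.gamma(a + 0.5, 1.0 / (b + 0.5 * w ** 2))
    # w | f, alpha ~ N((X^T X + A)^{-1} X^T f, (X^T X + A)^{-1})
    S = np.linalg.inv(X.T @ X + np.diag(alpha))
    w = rng.multivariate_normal(S @ X.T @ f, S)
    # f_i | w, y_i: one Metropolis step with proposal N(w^T x_i, sigma2);
    # the Gaussian factors cancel, leaving the logistic likelihood ratio.
    f_prop = X @ w + np.sqrt(sigma2) * rng.standard_normal(N)
    log_ratio = np.log1p(np.exp(-y * f)) - np.log1p(np.exp(-y * f_prop))
    accept = np.log(rng.uniform(size=N)) < log_ratio
    return w, alpha, np.where(accept, f_prop, f)
```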


Problems with standard MCMCs

Metropolis and Metropolis-Hastings:
- The proposal distribution is agnostic of the target model.
- The proposals tend to perform a random walk, hence shoot blindly.
- The outcome is often an unacceptably low acceptance rate.

Gibbs:
- Requires the conditional distributions in closed form (not tractable even for logistic regression).

Remedy ⇒ incorporate model curvature into the sampling scheme.


Back to the past: Potential and Kinetic Energy

Figure. http://99daveva31893.blogspot.com.tr/2013/05/potential-and-kinetic-energy.html.jpg


For physicists

- System state: position of the roller coaster, θ.
- Potential energy: a score proportional to the height of the roller coaster, U(θ).
- Kinetic energy: a score proportional to the speed (actually, momentum) r of the roller coaster, K(r).
- Total energy: the rule that governs how potential and kinetic energies are related, H(θ, r) = K(r) + U(θ).

Here H(θ, r) is also called the Hamiltonian function.


For us machine learners

- System state: position in the explored space of latent variables, θ.
- Potential energy: a score proportional (i.e. no normalization constant) to the posterior we aim to sample from, U(θ).
- Kinetic energy: auxiliary variables r that animate the system state, with an auxiliary score K(r).
- Total energy: the rule H(θ, r) = K(r) + U(θ), assuring that the animated system is a Markov chain which has the posterior as its stationary distribution.


Hamiltonian dynamics

A physically inspired way to model the Markov chain dynamics. For the ith latent variable (particle, for physicists), we have

dθ_i/dt = ∂H/∂r_i,     (4)
dr_i/dt = −∂H/∂θ_i,    (5)

meaning:
- (4): a particle moves as fast as the change in the kinetic energy.
- (5): the momentum of the particle increases as fast as the decrease in its height.

We choose

K(r) = (1/2) rᵀM⁻¹r,
U(θ) = −log p(θ) − log p(D|θ) = −log p(θ) − ∑_{x∈D} log p(x|θ).


Understanding Hamiltonian dynamics: the hockey puck analogy

dθ_i/dt = ∂H/∂r_i,     dr_i/dt = −∂H/∂θ_i

- Assume a hockey puck (θ) placed on a rugged surface of frictionless ice (U(θ)).
- We let the puck move by pushing it in an arbitrary direction (K(r)).
- As the surface is frictionless, the puck will keep moving forever (so we can sample as much as we want!).
- On a flat surface, the puck keeps a constant speed.
- Under a positive slope ∂H/∂θ_i > 0 it climbs and loses speed, and vice versa.
- The puck swings between steep regions (modes) and keeps its speed on plateaus!


How to solve the Hamiltonian system of differential equations

No analytically tractable solution exists for interesting models. Approximate by finite differences.

Way 1: Euler's method:

r_{t+ε} ← r_t + ε dr_t/dt = r_t − ε ∇_θ U(θ_t),
θ_{t+ε} ← θ_t + ε dθ_t/dt = θ_t + ε M⁻¹ r_t.

- Advantage: ultimately trivial.
- Disadvantage: the finite-difference approximation drifts away from the exact dynamics at every time step.


How to solve the Hamiltonian system of differential equations

Way 2: Modified Euler's method:

r_{t+ε} ← r_t + ε dr_t/dt = r_t − ε ∇_θ U(θ_t),
θ_{t+ε} ← θ_t + ε dθ_t/dt = θ_t + ε M⁻¹ r_{t+ε}.

- Coupling the two updates brings about a charming improvement in accuracy, yet does not solve all the problems.
- The finite-difference approximation is still not sufficiently accurate.
- Stronger coupling is required.


How to solve the Hamiltonian system of differential equations

Way 3: The leapfrog method:

r_{t+ε/2} ← r_t − (ε/2) ∇_θ U(θ_t),
θ_{t+ε} ← θ_t + ε M⁻¹ r_{t+ε/2},
r_{t+ε} ← r_{t+ε/2} − (ε/2) ∇_θ U(θ_{t+ε}).

Here is how the leapfrog method takes one step ahead (a code sketch follows):
- Take half a step with the right leg.
- Take a full step with the left leg.
- Take half a step with the right leg.
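A sketch of the three-step recipe, assuming M = I for brevity:

```python
import numpy as np

def leapfrog(theta, r, grad_U, eps, n_steps):
    """Leapfrog integration of the Hamiltonian dynamics with M = I."""
    r = r - 0.5 * eps * grad_U(theta)        # half step with the right leg
    for _ in range(n_steps - 1):
        theta = theta + eps * r              # full step with the left leg
        r = r - eps * grad_U(theta)
    theta = theta + eps * r
    r = r - 0.5 * eps * grad_U(theta)        # final half step with the right leg
    return theta, r
```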


Euler’s method and Leapfrog

Figure. R. Neal, MCMC using Hamiltonian Dynamics, 2011


Hamiltonian Monte Carlo Sampling

Figure. T. Chen et al., Stochastic Gradient Hamiltonian Monte Carlo, ICML, 2014


The Metropolis Correction

The numerical inaccuracies resulting from the finite-difference approximation accumulate and cause the chain to diverge from the target distribution. We can avoid this by accessing the true posterior every now and then.

Define the Hamiltonian joint distribution

π(θ, r) ∝ exp(−U(θ) − (1/2) rᵀM⁻¹r) = exp(−H(θ, r)).

Applying the Metropolis criterion to a proposal (θ, r) reached from the current state (θ₀, r₀), we get the acceptance probability

min(1, π(θ, r)/π(θ₀, r₀)) = min(1, e^{H(θ₀, r₀) − H(θ, r)}).

Remark: accessing the posterior p(θ|D) is nice, but can be unacceptably expensive for some models, such as deep neural nets!

A sketch of one corrected HMC transition follows.
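The sketch below reuses the leapfrog routine above and again assumes M = I:

```python
import numpy as np

rng = np.random.default_rng(0)

def hmc_step(theta, log_post, grad_U, eps=0.1, n_steps=20):
    """One HMC transition with Metropolis correction (M = I assumed).

    log_post(theta) returns -U(theta); `leapfrog` is the sketch above.
    """
    r0 = rng.standard_normal(np.shape(theta))        # resample momentum
    theta_new, r_new = leapfrog(theta, r0, grad_U, eps, n_steps)
    H0 = -log_post(theta) + 0.5 * np.sum(r0 ** 2)    # H(theta_0, r_0)
    H_new = -log_post(theta_new) + 0.5 * np.sum(r_new ** 2)
    if np.log(rng.uniform()) < H0 - H_new:           # min(1, e^{H0 - H_new})
        return theta_new                             # accept
    return theta                                     # reject: stay put
```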


HMC in R

Figure. R. Neal, MCMC using Hamiltonian Dynamics, 2011


Random walk versus Hamiltonian dynamics

Figure. R. Neal, MCMC using Hamiltonian Dynamics, 2011


Random walk versus Hamiltonian dynamics

Figure. R. Neal, MCMC using Hamiltonian Dynamics, 2011


Random walk versus Hamiltonian dynamics

Figure. R. Neal, MCMC using Hamiltonian Dynamics, 2011


HMC in high dimensionalities

Figure. R. Neal, MCMC using Hamiltonian Dynamics, 2011


Stochastic Gradient HMC

- HMC is great for accurate posterior inference, but every jump requires a full pass over the data, which is no longer practical in the present age.
- The naive way out is to switch from the exact gradient to a stochastic gradient of the potential energy (i.e. of the model):

  ∇Ũ(θ) = −(|D|/|D̃|) ∑_{x∈D̃} ∇log p(x|θ) − ∇log p(θ),

  where D̃ is a random minibatch.
- The stochastic gradient trick is mostly used for very large data sets, which results in excessively many iterations. When iterated enough times, the Central Limit Theorem governs the stochastic gradient noise:

  ∇Ũ(θ) − ∇U(θ) ∼ N(0, V),

  where V is the noise covariance.


The notorious efficiency-accuracy tradeoff

Small minibatches yield efficient computation; large minibatches yield accurate gradients (figure omitted).


Naive Stochastic Gradient HMC

With slightly sloppy notation, let us write

∇Ũ(θ) − ∇U(θ) ∼ N(0, V) ⇒ ∇Ũ(θ) = ∇U(θ) + N(0, V).

With the added stochastic gradient noise, the ε-discretized momentum update turns into

Δr = −ε ∇Ũ(θ) = −ε(∇U(θ) + p),   p = Ls, s ∼ N(0, I)
   = −ε ∇U(θ) − εp,   with εp = εLs
   = −ε ∇U(θ) + N(0, ε²V),

where L is the Cholesky factor of V, i.e. V = LLᵀ.


Naive Stochastic Gradient HMC

Letting ε → 0, we attain the dynamical system

dθ = M⁻¹r dt,
dr = −∇U(θ) dt + N(0, 2B(θ) dt),

where B(θ) = (1/2) ε V(θ).

- The hockey puck on the ice surface is now under random wind!
- The Hamiltonian system preserves entropy under exact gradients.
- The extra entropy coming from the stochastic gradient breaks this balance and accumulates at every iteration.
- Consequently, the dynamical system above converges to a uniform distribution!


Stochastic Gradient HMC with Friction

A new term is added to the system to counter the stochastic-gradient entropy:

dθ = M⁻¹r dt,
dr = −∇U(θ) dt − B M⁻¹ r dt + N(0, 2B(θ) dt).

- The hockey puck is now on an ice surface, under random wind, and subject to some friction!
- With the friction term added, the system again preserves entropy, hence the equilibrium distribution is the posterior.
- Physicists know this dynamical system as second-order Langevin dynamics [Wang & Uhlenbeck, 1945].


Practical Issues

- All is well, but we do not know the noise model V, hence the matrix B in the friction term.
- Yet we can approximate it empirically, B ≈ B̂, where B̂ = (1/2) ε V̂ with V̂ the empirical Fisher information.
- For any C ⪰ B̂, the system below is equivalent to SGHMC with friction:

  dθ = M⁻¹r dt,
  dr = −∇U(θ) dt − C M⁻¹ r dt + N(0, 2(C − B̂) dt) + N(0, 2B(θ) dt).

- When C = B̂ = 0, SGHMC boils down to SGD with momentum.
- No need for a Metropolis correction any longer! Recall that this saves a lot of time (evaluation of the model on the entire training data).

A sketch of the discretized update follows.
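A sketch of one common discretization with M = I (the parameterization η = step size, α = ηC, β̂ = ηB̂ follows Chen et al., 2014; the default values are placeholders):

```python
import numpy as np

rng = np.random.default_rng(0)

def sghmc(theta, grad_U_stoch, n_iter, eta=1e-3, alpha=0.05, beta_hat=0.0):
    """Discretized SGHMC with friction, M = I.

    grad_U_stoch(theta) returns a minibatch estimate of grad U(theta).
    """
    v = np.zeros_like(theta)
    samples = []
    for _ in range(n_iter):
        noise = np.sqrt(2.0 * (alpha - beta_hat) * eta) * rng.standard_normal(theta.shape)
        v = v - eta * grad_U_stoch(theta) - alpha * v + noise   # momentum update
        theta = theta + v                                       # position update
        samples.append(theta.copy())
    return np.array(samples)
```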


Stochastic Gradient HMC

Figure. T. Chen et al., Stochastic Gradient Hamiltonian Monte Carlo, ICML, 2014.


Naive versus principled SGHMC

Figure. T. Chen et al., Stochastic Gradient Hamiltonian Monte Carlo, ICML, 2014.


Naive versus principled SGHMC

Figure. T. Chen et al., Stochastic Gradient Hamiltonian Monte Carlo, ICML, 2014.


SGHMC for deep learning

A Bayesian multilayer perceptron with 100 hidden neurons, evaluated on MNIST.

Figure. T. Chen et al., Stochastic Gradient Hamiltonian Monte Carlo, ICML, 2014.


Maximum A-Posteriori Estimation

For a Bayesian model p(θ) p(D|θ) with p(D|θ) = ∏_{x∈D} p(x|θ), the mode of the posterior is called the Maximum A-Posteriori (MAP) estimate. The MAP of a model can be found by

argmin_θ { −log p(θ) − ∑_{x∈D} log p(x|θ) }.

When a closed-form solution is not available, we do gradient descent:

θ_{t+1} ← θ_t − ε_t ( −∇log p(θ_t) − ∑_{x∈D} ∇log p(x|θ_t) ).

The second term demands a full pass over the training set, which is not feasible in many cases.


Stochastic Gradient Descent (SGD)

The Robbins-Monro theorem [Robbins & Monro, 1951] says that the approximate gradient

(|D|/|D̃|) ∑_{x∈D̃} ∇log p(x|θ),

obtained by randomly sampling a minibatch D̃ (hence assuming that the training data consists of |D|/|D̃| replications of D̃), will converge to the exact gradient after infinitely many iterations if the learning rate follows a series satisfying

∑_{t=1}^{∞} ε_t = ∞,   ∑_{t=1}^{∞} ε_t² < ∞.

Hence, the learning rate should never drop to zero too quickly (left condition), but should still keep decreasing over time (right condition); one such schedule is sketched below.
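Polynomial decay satisfies both conditions for γ ∈ (0.5, 1]; the constants below are illustrative choices:

```python
# eps_t = a * (b + t)^(-gamma): sum eps_t diverges, sum eps_t^2 converges.
def lr_schedule(t, a=0.1, b=10.0, gamma=0.55):
    return a * (b + t) ** (-gamma)
```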


(First-order) Langevin Dynamics

Δθ_t = (ε/2) ( ∇log p(θ_t) + ∑_{x∈D} ∇log p(x|θ_t) ) + η_t,   η_t ∼ N(0, ε).

- Just as HMC, introduced as a solver for stochastic differential equations.
- Converges to the posterior, but the discretization error should be resolved by a Metropolis correction.
- Unlike HMC, Gaussian noise is injected into the gradient to enhance stochasticity.
- The noise variance ε is proportional to the learning rate.
- Decreasing ε also decreases the discretization error. This will be useful soon!


Stochastic Gradient Langevin Dynamics (SGLD)⁷

Adapt Langevin dynamics to SGD (a code sketch follows below):

Δθ_t = (ε_t/2) ( ∇log p(θ_t) + (|D|/|D̃|) ∑_{x∈D̃} ∇log p(x|θ_t) ) + η_t,   η_t ∼ N(0, ε_t).

- To be eligible for Robbins-Monro, the learning rate ε_t has to decrease over time.
- The discretization error will then also decrease ⇒ the Metropolis correction is no longer required!

⁷ M. Welling, Y.W. Teh, Bayesian Learning via Stochastic Gradient Langevin Dynamics, ICML, 2011
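A sketch of the resulting update loop (the schedule constants and batch size are illustrative; `data` is assumed to be a NumPy array):

```python
import numpy as np

rng = np.random.default_rng(0)

def sgld(theta, grad_log_prior, grad_log_lik, data, n_iter,
         batch_size=32, a=0.1, b=10.0, gamma=0.55):
    """SGLD sketch; grad_log_lik(theta, batch) sums gradients over the batch."""
    N = len(data)
    samples, eps_trace = [], []
    for t in range(n_iter):
        eps = a * (b + t) ** (-gamma)                # decreasing step size
        batch = data[rng.choice(N, size=batch_size)]
        grad = grad_log_prior(theta) + (N / batch_size) * grad_log_lik(theta, batch)
        theta = theta + 0.5 * eps * grad + np.sqrt(eps) * rng.standard_normal(np.shape(theta))
        samples.append(np.copy(theta))
        eps_trace.append(eps)
    return np.array(samples), np.array(eps_trace)
```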


SGD ⇒ SGLD

- The algorithm has two phases; it starts with the first and switches to the second after some iterations:
  1. SGD: performs only SGD with extra stochasticity (η_t), to pass over local minima.
  2. Langevin dynamics: samples from the true posterior.
- The key question is when this switch takes place.
- The system transitions into phase two when

  ε_t ≈ (4α |D̃|/|D|) λ_min(I_F⁻¹),

  where I_F is the empirical Fisher information of the stochastic gradient error, λ_min(·) is the smallest eigenvalue of its argument, and α is a sampling threshold.


Calculating posterior expectations

- During training, we collect a set of samples {θ₁, θ₂, …, θ_T}.
- We will use them to approximate an expectation E[f(θ)], for instance the posterior predictive of a future observation x*:

  p(x*|D) = ∫ p(x*|θ) p(θ|D) dθ = E[p(x*|θ)].

- Simple sample averaging,

  E[p(x*|θ)] ≈ (1/T) ∑_{t=1}^{T} p(x*|θ_t),

  will over-emphasize non-minimal regions where ε_t (hence the discretization error) was high and the system was not yet sufficiently in the Langevin dynamics phase.
- Instead, use the step-size-weighted average (a code sketch follows):

  E[p(x*|θ)] ≈ ∑_{t=1}^{T} ε_t p(x*|θ_t) / ∑_{t=1}^{T} ε_t.
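A sketch of the weighted average, reusing the ε_t trace returned by the SGLD sketch above:

```python
import numpy as np

def weighted_expectation(f_vals, eps_trace):
    """Step-size weighted average: sum_t eps_t f(theta_t) / sum_t eps_t."""
    f_vals, eps_trace = np.asarray(f_vals), np.asarray(eps_trace)
    return np.sum(eps_trace * f_vals) / np.sum(eps_trace)

# e.g. with the SGLD sketch above:
# samples, eps_trace = sgld(...)
# pred = weighted_expectation([p_xstar(th) for th in samples], eps_trace)
```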


Bayesian neural net benchmarking on some UCI data sets

A Bayesian neural net with a single hidden layer of 50 units is used. Root mean square error on the test split is reported.

          HMC    Dropout  PBP    Varout  SGLD
boston    2.76   2.97     3.01   2.70    2.21
concrete  4.12   5.23     5.67   4.89    4.19
kin8nm    0.06   0.10     0.10   0.08    0.02
power     3.73   4.02     4.12   4.04    2.42
protein   3.91   4.36     4.73   4.13    1.07
red wine  0.63   0.62     0.64   0.63    0.21
yacht     0.56   1.11     1.02   0.71    1.32