
Advanced Statistical Computing
Week 5: EM Algorithm

Aad van der Vaart

Fall 2012


Contents


EM Algorithm

Mixtures

Hidden Markov models


EM Algorithm


EM-algorithm


SETTING: observation $X$, likelihood $\theta \mapsto p_\theta(X)$ hard to maximize to find the MLE $\hat\theta$.

$X$ can be viewed as the first coordinate of $(X, Y)$ with density $(x, y) \mapsto p_\theta(x, y)$:

$$p_\theta(x) = \int p_\theta(x, y)\, d\mu(y).$$

EM-ALGORITHM: GIVEN $\theta_0$, REPEAT
• E-step: compute $\theta \mapsto E_{\theta_i}\big(\log p_\theta(X, Y)\,|\,X\big)$.
• M-step: $\theta_{i+1} :=$ point of maximum of this function.

The sequence $\theta_0, \theta_1, \ldots$ often tends to the MLE, but it may fail to converge, converge slowly, or converge to a local maximum.

[$Y$ may be missing data, or augmented data invented for convenience.]
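[As an illustration of the scheme (not from the slides; the function and data names are ours): exponential lifetimes $T_i \sim \text{Exp}(\theta)$ observed with right-censoring, where the augmented data $Y$ are the full lifetimes.]

## Minimal sketch: EM for T_i ~ Exp(theta) observed with right-censoring.
## We observe z_i = min(T_i, cens) and d_i = 1{T_i <= cens}; the augmented
## data are the uncensored lifetimes T_i.
em_cens_exp <- function(z, d, cens, theta = 1, tol = 1e-8) {
  n <- length(z)
  repeat {
    ## E-step: E_theta[T_i | data] = z_i if uncensored,
    ## cens + 1/theta if censored (lack of memory of the exponential)
    ET <- ifelse(d == 1, z, cens + 1 / theta)
    ## M-step: maximize n*log(theta) - theta*sum(ET), so theta = n/sum(ET)
    thetanew <- n / sum(ET)
    if (abs(thetanew - theta) < tol) return(thetanew)
    theta <- thetanew
  }
}

With z <- pmin(rexp(100, 2), 1); d <- as.numeric(z < 1), the iterates of em_cens_exp(z, d, cens = 1) converge to the closed-form MLE sum(d)/sum(z), illustrating the monotone increase of the likelihood shown next.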


EM-Algorithm — increases target


LEMMA The sequence $\theta_0, \theta_1, \ldots$ generated by the EM algorithm satisfies $p_{\theta_0}(X) \le p_{\theta_1}(X) \le \cdots$.

PROOF Write $p_\theta(x, y) = p_\theta(y|x)\, p_\theta(x)$, so that

$$E_{\theta_i}\big(\log p_\theta(X, Y)\,|\,X\big) = E_{\theta_i}\big(\log p_\theta(Y|X)\,|\,X\big) + \log p_\theta(X).$$

Because $\theta_{i+1}$ maximizes the left side over $\theta$, it suffices to show

$$E_{\theta_i}\big(\log p_{\theta_{i+1}}(Y|X)\,|\,X\big) \le E_{\theta_i}\big(\log p_{\theta_i}(Y|X)\,|\,X\big).$$

Equivalently, it suffices that $-K(p; q) := E_p \log(q/p)(Y) \le 0$ for $p = p_{\theta_i}$, $q = p_{\theta_{i+1}}$, conditioned on $X$. Now the Kullback-Leibler divergence $K(p; q)$ is nonnegative for any $p, q$.
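For completeness: the nonnegativity follows from Jensen's inequality applied to the concave function $\log$,

$$-K(p; q) = E_p \log\frac{q}{p}(Y) \le \log E_p \frac{q}{p}(Y) = \log \int q\, d\mu = \log 1 = 0.$$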

This does not prove that θi converges to the MLE!


EM-Algorithm — linear convergence


The speed of the EM algorithm is linear, with slow convergence if the augmented model is statistically much more informative than the data model.
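[A rough quantitative version, a standard heuristic stated here in one dimension and not part of the slides: near a stable fixed point $\hat\theta$,

$$\theta_{i+1} - \hat\theta \approx \Big(1 - \frac{I_{\text{obs}}(\hat\theta)}{I_{\text{aug}}(\hat\theta)}\Big)\,(\theta_i - \hat\theta),$$

with $I_{\text{obs}}$ and $I_{\text{aug}}$ the Fisher informations in the data model and the augmented model. The contraction factor is the 'fraction of missing information'; it is close to 1, giving slow convergence, when the augmentation carries much more information than the data.]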


Mixtures


Mixtures


SETTING Observations: a random sample $X_1, \ldots, X_n$ from the density

$$p_\theta(x) = \sum_{j=1}^k p_j f(x; \eta_j), \qquad \theta = (p_1, \ldots, p_k, \eta_1, \ldots, \eta_k).$$

AUGMENTED DATA

$$P(Y_i = j) = p_j, \qquad X_i \,|\, Y_i = j \sim f(\cdot\,; \eta_j), \qquad i = 1, \ldots, n.$$

Full likelihood

$$p_\theta(X_1, \ldots, X_n, Y_1, \ldots, Y_n) = \prod_{i=1}^n \prod_{j=1}^k \big(p_j f(X_i; \eta_j)\big)^{1\{Y_i = j\}}.$$


Mixtures — E-step, M-step


E-step: given $(p, \eta)$:

$$E_{p,\eta}\Big(\log \prod_{i=1}^n \prod_{j=1}^k \big(p_j f(X_i; \eta_j)\big)^{1\{Y_i = j\}} \,\Big|\, X_1, \ldots, X_n\Big) = \sum_{i=1}^n \sum_{j=1}^k \log\big(p_j f(X_i; \eta_j)\big)\, \alpha_{i,j},$$

$$\alpha_{i,j} := P_{p,\eta}\big(Y_i = j \,|\, X_i\big) = \frac{p_j f(X_i; \eta_j)}{\sum_c p_c f(X_i; \eta_c)}.$$

The criterion separates into

$$\Big[\sum_{j=1}^k \log p_j \Big(\sum_{i=1}^n \alpha_{i,j}\Big)\Big] + \sum_{j=1}^k \Big[\sum_{i=1}^n \log f(X_i; \eta_j)\, \alpha_{i,j}\Big].$$

M-step: for $j = 1, \ldots, k$:

$$p_j^{\text{new}} = \frac{1}{n} \sum_{i=1}^n \alpha_{i,j}, \qquad \eta_j^{\text{new}} = \operatorname*{argmax}_\eta \sum_{i=1}^n \log f(X_i; \eta)\, \alpha_{i,j}.$$

[If the $f(\cdot\,; \eta)$ have a common parameter, then the computation of the $\eta_j$ does not separate as it does here.]


Mixtures — Example


EXAMPLE If $f(\cdot\,; \eta)$ is the $N(\eta, 1)$ density, then

$$\sum_{i=1}^n \log f(X_i; \eta)\, \alpha_{i,j} = -\frac{1}{2} \sum_{i=1}^n \alpha_{i,j} (X_i - \eta)^2 + \text{Const},$$

$$\eta_j^{\text{new}} = \frac{\sum_{i=1}^n \alpha_{i,j} X_i}{\sum_{i=1}^n \alpha_{i,j}}.$$
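[A minimal R sketch of the resulting algorithm for the normal case (our own function names; the hand-coded and package-based gamma versions follow on the next slides):]

## Minimal sketch: EM for a k-component N(eta_j, 1) mixture, using the
## closed-form weight and mean updates derived above.
em_normal_mix <- function(x, k, eta, p = rep(1/k, k), tol = 1e-6) {
  repeat {
    ## E-step: a[i, j] = alpha_{i,j} = P(Y_i = j | X_i) under current (p, eta)
    a <- sapply(1:k, function(j) p[j] * dnorm(x, mean = eta[j]))
    a <- a / rowSums(a)
    ## M-step: closed-form updates for weights and component means
    pnew   <- colMeans(a)
    etanew <- colSums(a * x) / colSums(a)
    if (sum(abs(pnew - p)) + sum(abs(etanew - eta)) < tol) break
    p <- pnew; eta <- etanew
  }
  list(p = pnew, eta = etanew)
}

For instance, em_normal_mix(c(rnorm(50), rnorm(50, 4)), k = 2, eta = c(0, 1)) should recover weights near (1/2, 1/2) and means near (0, 4).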

EXAMPLE If $f(\cdot\,; \eta)$ is the $\Gamma(r, \eta)$ density (shape $r$ known, rate $\eta$), then

$$\sum_{i=1}^n \log f(X_i; \eta)\, \alpha_{i,j} = \sum_{i=1}^n (r \log \eta - \eta X_i)\, \alpha_{i,j} + \text{Const},$$

$$\eta_j^{\text{new}} = \frac{r \sum_{i=1}^n \alpha_{i,j}}{\sum_{i=1}^n \alpha_{i,j} X_i}.$$
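As a quick check of the last display: setting the $\eta$-derivative to zero gives

$$\frac{\partial}{\partial \eta} \sum_{i=1}^n (r \log \eta - \eta X_i)\, \alpha_{i,j} = \frac{r}{\eta} \sum_{i=1}^n \alpha_{i,j} - \sum_{i=1}^n \alpha_{i,j} X_i = 0,$$

which is solved by the stated $\eta_j^{\text{new}}$.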


R



> n=100
> shape=c(2,2,2); eta=c(1,6,.2); prob=c(1/4,1/8,5/8)
> component=sample(c(1,2,3),n,replace=TRUE,prob=prob)
> x=rgamma(n,shape=shape[component],rate=eta[component])
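[A sample of size $n = 100$ from a three-component gamma mixture: all shape parameters equal to 2, rates $\eta = (1, 6, 0.2)$, and mixing weights $(1/4, 1/8, 5/8)$.]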


R — EM, known shape


> k=3; a=matrix(0,n,k); p=c(1/3,1/3,1/3); eta=c(1,2,3); change=1
> while (change>0.0001){
+ for (j in 1:k) a[,j]=p[j]*dgamma(x,2,eta[j])
+ a=diag(1/apply(a,1,sum))%*%a
+ etanew=2*apply(a,2,sum)/matrix(x,1,n)%*%a
+ pnew=apply(a,2,mean)
+ change=sum(abs(etanew-eta)+abs(pnew-p))
+ print(rbind(pnew,etanew))
+ eta=etanew; p=pnew}
[--- output deleted ---]

            [,1]      [,2]       [,3]
pnew   0.6259239 0.3161804 0.05789564
etanew 0.2157931 1.7430514 7.57683781
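[Each pass of the loop is one EM iteration: the two lines after while compute and row-normalize the weights $\alpha_{i,j}$ (E-step); etanew and pnew then apply the known-shape updates $\eta_j^{\text{new}} = 2\sum_i \alpha_{i,j} / \sum_i \alpha_{i,j} X_i$ and $p_j^{\text{new}} = n^{-1}\sum_i \alpha_{i,j}$ for $r = 2$ (M-step). Up to label switching, the output roughly recovers the true weights $(5/8, 1/4, 1/8)$ and rates $(0.2, 1, 6)$.]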



R — packages


> library(mixtools)
> mod=gammamixEM(x,k=3)
number of iterations= 323
> summary(mod)
Error in summary.mixEM(mod) : Unknown mixEM object of type gammamixEM
> mod[[2]]; mod[[3]]
[1] 0.37441469 0.57523322 0.05035209

         comp.1   comp.2     comp.3
alpha 1.6203475 2.092346 20.9880430
beta  0.6184701 4.126267  0.7926715


[Besides the package mixtools, there is also flexmix, and ... (?)]


Mixtures — warnings


Not all mixtures are identifiable from the data: multiple parameter vectors may give the same mixture.

Maximum likelihood may work only if the parameter set is restricted. (Notable example: location-scale mixtures; if a scale parameter approaches zero, the likelihood may tend to infinity.)

EM tends to be slow for large data sets, and might get stuck in local maxima (?)


Hidden Markov models


Hidden Markov model


[Diagram: hidden Markov chain $Y_1 \to Y_2 \to \cdots \to Y_n$; each state $Y_i$ emits an observed output $X_i$.]

Markov chain of hidden states $Y_1, Y_2, \ldots$; only the outputs $X_1, X_2, \ldots$ are observed.

Xi given Yi conditionally independent of all other variables.

EXAMPLES
• speech recognition: states abstract, outputs Fourier coding of sounds.
• genomics: states are introns/exons, outputs nucleotides.
• genomics: states are # chromosomal duplicates, outputs noisy measurements.
• genetics: states inheritance vectors, outputs measured markers.
• cell biology: states of ion channels, outputs current or no current.
• economics: state of the economy, output # firms in default.


Parameters
• density $\pi$ of $Y_1$,
• transition density $p(y_i\,|\,y_{i-1})$ of the Markov chain,
• output density $q(x_i\,|\,y_i)$.

Full likelihood

$$\pi(y_1)\, p(y_2\,|\,y_1) \times \cdots \times p(y_n\,|\,y_{n-1})\; q(x_1\,|\,y_1) \times \cdots \times q(x_n\,|\,y_n).$$
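[For concreteness, a small R helper (ours, assuming a finite state space) that evaluates the log of this full likelihood for a given state path:]

## Minimal sketch: complete-data log-likelihood of an HMM path, matching the
## display above. pi0: initial distribution over states; P: transition matrix;
## logq: function(x, y) returning log q(x | y).
hmm_full_loglik <- function(y, x, pi0, P, logq) {
  n <- length(y)
  log(pi0[y[1]]) +
    sum(log(P[cbind(y[-n], y[-1])])) +   # sum of log p(y_i | y_{i-1})
    sum(mapply(logq, x, y))              # sum of log q(x_i | y_i)
}

With logq <- function(x, y) dbinom(x, 5, c(0.3, 0.8)[y], log = TRUE) this matches the two-state binomial-output example in the R demonstration below.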


HMM — E and M-step


E-step:

$$E_{\pi,p,q}\Big(\log \pi(Y_1) \prod_{i=2}^n p(Y_i\,|\,Y_{i-1}) \prod_{i=1}^n q(X_i\,|\,Y_i) \,\Big|\, X_1, \ldots, X_n\Big)$$

$$= E_{\pi,p,q}\big(\log \pi(Y_1)\,|\,X_1, \ldots, X_n\big) + \sum_{i=2}^n E_{\pi,p,q}\big(\log p(Y_i\,|\,Y_{i-1})\,|\,X_1, \ldots, X_n\big) + \sum_{i=1}^n E_{\pi,p,q}\big(\log q(X_i\,|\,Y_i)\,|\,X_1, \ldots, X_n\big).$$

M-step:
• depends on the specification of models for $\pi$, $p$, $q$;
• if the state space is finite, $p$ is typically left free;
• only the current estimate of the law of $(Y_{i-1}, Y_i)$ given $X_1, \ldots, X_n$ is needed, which is computed using the forward and backward algorithms.


Baum-Welch


The EM algorithm for the HMM with finite state space, and completely unspecified distributions $\pi$, $p$, $q$, is called the Baum-Welch algorithm.

If $\pi$ and $p$ are left free:

$$\pi^{\text{new}}(y) = p^{\,Y_1 | X_1, \ldots, X_n}_{\pi,p,q}(y),$$

$$p^{\text{new}}(v\,|\,u) = \frac{\sum_{i=2}^n p^{\,Y_{i-1}, Y_i | X_1, \ldots, X_n}_{\pi,p,q}(u, v)}{\sum_{i=2}^n p^{\,Y_{i-1} | X_1, \ldots, X_n}_{\pi,p,q}(u)}.$$

If $q$ is also left free (possible for a finite output space, but not often the case):

$$q^{\text{new}}(x\,|\,y) = \frac{\sum_{i:\,X_i = x} p^{\,Y_i | X_1, \ldots, X_{i-1}, X_i = x, X_{i+1}, \ldots, X_n}_{\pi,p,q}(y)}{\sum_{x \in \mathcal{X}} \sum_{i:\,X_i = x} p^{\,Y_i | X_1, \ldots, X_{i-1}, X_i = x, X_{i+1}, \ldots, X_n}_{\pi,p,q}(y)}.$$

[To compute these expressions we need the density of $(Y_{i-1}, Y_i)$ given $X_1, \ldots, X_n$. This is computed using the forward and backward algorithms.]
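[A minimal R sketch of these recursions (our own function, not the slides' code), rescaled at each step to avoid underflow. It returns the smoothing probabilities $P(Y_i = y\,|\,X_1, \ldots, X_n)$; the pairwise law of $(Y_{i-1}, Y_i)$ needed above is proportional to alpha[i-1, u] * Pi[u, v] * emis[i, v] * beta[i, v] and comes from the same two arrays.]

## Minimal sketch: forward-backward recursions for a finite-state HMM.
## Pi: k x k transition matrix; delta: initial distribution;
## emis: n x k matrix with emis[i, y] = q(X_i | y).
forward_backward <- function(Pi, delta, emis) {
  n <- nrow(emis); k <- ncol(emis)
  alpha <- matrix(0, n, k); beta <- matrix(0, n, k)
  ## forward pass: alpha[i, ] proportional to P(X_1..X_i, Y_i = .)
  alpha[1, ] <- delta * emis[1, ]
  alpha[1, ] <- alpha[1, ] / sum(alpha[1, ])
  for (i in 2:n) {
    alpha[i, ] <- (alpha[i - 1, ] %*% Pi) * emis[i, ]
    alpha[i, ] <- alpha[i, ] / sum(alpha[i, ])
  }
  ## backward pass: beta[i, ] proportional to P(X_{i+1}..X_n | Y_i = .)
  beta[n, ] <- 1
  for (i in (n - 1):1) {
    beta[i, ] <- Pi %*% (emis[i + 1, ] * beta[i + 1, ])
    beta[i, ] <- beta[i, ] / sum(beta[i, ])
  }
  post <- alpha * beta
  post / rowSums(post)   # row i: P(Y_i = . | X_1, ..., X_n)
}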


Viterbi


[Diagram: hidden Markov chain $Y_1 \to Y_2 \to \cdots \to Y_n$; each state $Y_i$ emits an observed output $X_i$.]

The Viterbi algorithm computes the most likely state path given the outcomes:

$$\operatorname*{argmax}_{y_1, \ldots, y_n} P\big(Y_1 = y_1, \ldots, Y_n = y_n \,|\, X_1, \ldots, X_n\big).$$
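This maximization is carried out by dynamic programming (a standard sketch, not spelled out on the slide). With $V_1(y) = \pi(y)\, q(x_1\,|\,y)$ and

$$V_i(y) = \max_u\, V_{i-1}(u)\, p(y\,|\,u)\, q(x_i\,|\,y), \qquad i = 2, \ldots, n,$$

$V_i(y)$ is the maximal joint density of the first $i$ states and outputs over paths ending in state $y$; recording the maximizing $u$ at each step and backtracking from $\operatorname*{argmax}_y V_n(y)$ yields the most likely path.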


R



> library(HiddenMarkov)
> Pi=matrix(c(0.7,0.3,0.2,0.8),2,2,byrow=TRUE); delta=c(0.3,0.7)
> n=100; pn=list(size=rep(5,n)); pm=list(prob=c(0.3,0.8))
> myhmm=dthmm(NULL,Pi=Pi,delta=delta,distn="binom",pn=pn,pm=pm)
> x=simulate(myhmm,n)
>
> plot(1:n,x$x,type="s",xlab="",ylab="")
> lines(1:n,x$y-1,col=2,type="s")

[Markov chain with two states, transition matrix $\Pi$, initial distribution $\delta$. Outputs are from the binomial$(5, p)$ distribution, with $p = 0.3$ from state 1 and $p = 0.8$ from state 2. Red: states; black: outputs.]


R



> mod=BaumWelch(x); mod$Pi; mod$pm
[--- output deleted ---]

          [,1]      [,2]
[1,] 0.6287149 0.3712851
[2,] 0.2637289 0.7362711
$prob
[1] 0.3173456 0.8313127

[Markov chain with two states, transition matrix $\Pi = \begin{pmatrix} 0.7 & 0.3 \\ 0.2 & 0.8 \end{pmatrix}$, initial distribution $\delta = (0.3, 0.7)$. Outputs are from the binomial$(5, p)$ distribution, with $p = 0.3$ from state 1 and $p = 0.8$ from state 2. The Baum-Welch estimates of $\Pi$ and of the success probabilities are close to these true values.]


R



> Viterbi(x)
 [1] 2 2 2 2 2 2 2 1 1 2 2 2 2 2 1 1 1 1 1 2 2 2 2 2 2 2 1 1 1
[36] 2 2 2 2 2 2 2 2 2 1 1 1 2 1 1 2 2 2 2 2 2 2 1 1 1 1 2 2 1
[71] 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 1 1

> lines(1:n,Viterbi(x)-1,col=3,lwd=2)

[Red: true states; black: outputs; green: reconstructed states.]