Approximations in Bayesian Inference

Václav Šmídl

March 3, 2020

Previous models

[Graphical model: plate over observations xi, i = 1, 2, with parameter nodes m and s; hyperparameters µ = 0, α = 2, β = 3.]

p(s) = iG(α0, β0)
p(m|s) = N(µ, s)
p(xi|m, s) = N(m, s)

- Observations xi are sampled from a Gaussian with unknown mean and variance.
- We have some prior information about the mean and variance.

Previous models

[Graphical models: plate over observations xi, i = 1, 2, with parameter nodes m and s. Left: the previous model with hyperparameters µ = 0, α, β. Right: the modified model with hyperparameters µ, φ, α, β, i.e. the prior on m no longer depends on s.]

Inference:

p(m, s|x1, x2, µ, α, β, φ) ∝ ∏_{i=1,2} p(xi|m, s) · p(m|µ, φ) · p(s|α, β),

where p(d) is the normalization constant.

Joint likelihood

p(m, s|x1, x2, µ, α, β, φ) ∝ (1/s) · (1/s^(α0+1)) · exp( −(m − x1)²/(2s) − (m − x2)²/(2s) − (m − µ)²/(2φ) − β0/s )

Conditional of m:

p(m|s, x1, x2, µ, α, β, φ) ∝ exp( −(m − x1)²/(2s) − (m − x2)²/(2s) − (m − µ)²/(2φ) )
                           ∝ exp( −½[ m²(1/φ + 2/s) − 2m( µ/φ + (x1 + x2)/s ) ] )
                           = N( m; (1/φ + 2/s)^(−1)( µ/φ + (x1 + x2)/s ), (1/φ + 2/s)^(−1) )

Conditional of s:

p(s|m, x1, x2, µ, α, β, φ) ∝ (1/s^(α0+2)) exp( −(m − x1)²/(2s) − (m − x2)²/(2s) − β0/s )
                           = iG( α0 + 1, 0.5(m − x1)² + 0.5(m − x2)² + β0 )
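The "Numerical solution" figures that follow can be reproduced by brute force: evaluate the unnormalized joint on a grid and sum it out. A minimal sketch (my own illustration, not lecture code; the hyperparameter values follow the first figure as far as it can be read, and the grid ranges are arbitrary choices):

```python
# Evaluate the unnormalized joint p(m, s, x1, x2) on a grid and integrate it
# numerically to obtain the marginals p(m|x1,x2) and p(s|x1,x2).
import numpy as np

x1, x2 = 1.0, 2.0
mu, phi = 3.0, 1000.0            # prior on m: N(mu, phi)
alpha0, beta0 = 0.1, 0.1         # prior on s: iG(alpha0, beta0)

m = np.linspace(-5.0, 8.0, 400)
s = np.linspace(1e-2, 10.0, 400)
M, S = np.meshgrid(m, s, indexing="ij")

# log of the unnormalized joint: s^-(alpha0+2) * exp(...)
log_joint = (-(alpha0 + 2.0) * np.log(S)
             - 0.5 * (M - x1) ** 2 / S
             - 0.5 * (M - x2) ** 2 / S
             - 0.5 * (M - mu) ** 2 / phi
             - beta0 / S)

dm, ds = m[1] - m[0], s[1] - s[0]
joint = np.exp(log_joint - log_joint.max())   # subtract max to avoid overflow
joint /= joint.sum() * dm * ds                # normalize on the grid

p_m = joint.sum(axis=1) * ds                  # marginal p(m | x1, x2)
p_s = joint.sum(axis=0) * dm                  # marginal p(s | x1, x2)
```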

Numerical solution:

[Figure: contour plot of p(m, s | x1=1, x2=2, µ=3, α=0.1, β=0.1, φ=1000) over s (horizontal, 2–8) and m (vertical, 0–6), with the marginal p(s|x1,x2) and the marginal p(m|x1,x2) shown alongside.]

Numerical solution:

[Figure: contour plot of p(m, s | x1=1, x2=2, µ=5, α=0.1, β=0.1, φ=2) over s (horizontal, 2–8) and m (vertical, 0–6), with the marginal p(s|x1,x2) and the marginal p(m|x1,x2) shown alongside.]

What is the role of the prior?

1. Uninformative
   - needed to make a consistent answer
   - Jeffreys
2. Non-committal
   - makes only a minor adjustment
3. Informative
   - incorporates important information
   - do we know the exact shape of the distribution?
4. Structural, weakly informative

Non-informative (Jeffreys)

- The answer should be invariant to a change of coordinates of the parameter θ.
- Solution:

  p(θ) ∝ √det F(θ),

  where F is the Fisher information matrix

  F(θ) = −E[ ∂²/∂θ² log p(X; θ) | θ ].

- For the normal likelihood,

  log p(x|m, s) = −½ log s − (m − x)²/(2s) + c = . . .

  p(m, s) ∝ 1/s = iG(0, 0)

- A uniform prior on the scale is informative...


Conjugate prior

For the normal likelihood with parameters m, s,

log p(x|m, s) = −½ log s − (m − x)²/(2s) + c
             = −½ log s − (m² − 2mx + x²)/(2s) + c
             = [−½, −½, x, −½x²] · [log s, m²/s, m/s, 1/s]ᵀ

is a combination of the basis functions [log s, m²/s, m/s, 1/s]. It is advantageous to choose a prior with the same basis functions:

log p(m, s) = −½ log s − (m − µ)²/(2s) + c − (α0 + 1) log s − β0/s
           = [−½ − α0 − 1, −½, µ, −½µ² − β0] · [log s, m²/s, m/s, 1/s]ᵀ

The posterior is then of the same form as the prior.


Exponential family:

The likelihood of the data is of the form

p(x|θ) = h(x) exp( η(θ)ᵀ T(x) − A(θ) ),

where η(θ) is the natural parameter (sometimes written η = [η1, η2, . . .]) and T(x) is the sufficient statistic.

Use:

p(x1|θ) = h(x1) exp( η(θ)ᵀ T(x1) − A(θ) ),
p(x2|θ) = h(x2) exp( η(θ)ᵀ T(x2) − A(θ) ),
p(θ) = h0 exp( η(θ)ᵀ T0 − ν0 A(θ) ),

p(x1, . . . , xn|θ) = ∏_{i=1}^n h(xi) · exp( η(θ)ᵀ ∑_{i=1}^n T(xi) − n A(θ) ).


Exponential family of normal distribution

p(x|m, s) = (1/√(2π)) exp( −½ log s − (m − x)²/(2s) )
          = (1/√(2π)) exp( −½ log s − ½[ m²/s − 2(m/s)x + (1/s)x² ] )
          = h(x) exp( η ᵀ T(x) − A(θ) ),

with

h(x) = 1/√(2π),   η = [m/s, −1/(2s)],   T(x) = [x, x²],   −A(θ) = −½ log s − m²/(2s).

Or

h(x) = 1/√(2π),   η = [m/s, −1/(2s), −m²/(2s)],   T(x) = [x, x², 1],   −A(θ) = −½ log s.

Non-committal prior: any conjugate prior whose prior statistics have minimal impact on the sufficient statistics.
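As a small illustration (my own sketch, not lecture code) of why this form is convenient: combining the conjugate prior with n likelihood terms only adds sufficient statistics, T0 + Σi T(xi) and ν0 + n. Here T(x) = [x, x²] as in the two-component representation above; the prior hyperparameters and the data are arbitrary toy values.

```python
# Conjugate update in an exponential family: hyperparameters accumulate
# the sufficient statistics of the observed data.
import numpy as np

def T(x):
    """Sufficient statistic of the normal likelihood, T(x) = [x, x^2]."""
    return np.array([x, x * x])

def conjugate_update(T0, nu0, data):
    """Map prior hyperparameters (T0, nu0) to posterior hyperparameters."""
    Tn = T0 + sum(T(x) for x in data)   # T0 + sum_i T(x_i)
    nun = nu0 + len(data)               # nu0 + n
    return Tn, nun

# toy data x1 = 1, x2 = 2 from the slides; illustrative prior hyperparameters
Tn, nun = conjugate_update(T0=np.zeros(2), nu0=0.0, data=[1.0, 2.0])
print(Tn, nun)                          # [3. 5.] 2.0
```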

Prior

The prior is typically in the hands of the modeller:
- for a sufficient amount of data, use Jeffreys
- a conjugate prior is beneficial for analytical tractability
- care is needed for tail behaviour

[Figure: densities p(yt) of Normal (N), Gamma (G), inverse-Gamma (iG) and log-Normal (lN) distributions matched to mean(yt) = 1; left panel std(yt) = 0.1, right panel std(yt) = 0.3, illustrating different tail behaviour.]
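The tail comparison in the figure can be reproduced by moment matching, assuming (my reading of the legend) that N, G, iG and lN denote the Normal, Gamma, inverse-Gamma and log-Normal families. A sketch using scipy.stats parameterizations; the evaluation point is an arbitrary choice:

```python
# Match Normal, Gamma, inverse-Gamma and log-Normal to a given mean and
# standard deviation, then compare densities: equal first two moments,
# very different tails.
import numpy as np
from scipy import stats

def matched(mean, std):
    var = std ** 2
    a_g = mean ** 2 / var                  # Gamma shape, scale = var/mean
    a_ig = mean ** 2 / var + 2.0           # inverse-Gamma shape (finite variance needs a > 2)
    sig2 = np.log(1.0 + var / mean ** 2)   # log-Normal log-scale variance
    return {
        "N":  stats.norm(loc=mean, scale=std),
        "G":  stats.gamma(a_g, scale=var / mean),
        "iG": stats.invgamma(a_ig, scale=mean * (a_ig - 1.0)),
        "lN": stats.lognorm(s=np.sqrt(sig2), scale=mean * np.exp(-sig2 / 2.0)),
    }

for name, dist in matched(1.0, 0.3).items():
    print(name, dist.pdf(2.0))             # density deep in the right tail
```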

Working out the non-conjugate prior example

p(m, s|x1, x2, µ, α, β, φ) ∝ (1/s) · (1/s^(α0+1)) · exp( −(m − x1)²/(2s) − (m − x2)²/(2s) − (m − µ)²/(2φ) − β0/s )

The conditional distributions are available!

- EM algorithm: provides the maximum of one marginal distribution
- Variational Bayes approximation: provides a conditionally independent posterior
- Gibbs sampler: provides samples from the posterior (efficient); a sketch follows below
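A minimal Gibbs-sampler sketch for the toy problem (my own illustration, not lecture code): alternate draws from the two conditionals derived on the "Joint likelihood" slide. The data, hyperparameters, chain length and initialization are arbitrary choices.

```python
# Gibbs sampler: alternate p(m | s, x1, x2) = N(...) and p(s | m, x1, x2) = iG(...).
import numpy as np

rng = np.random.default_rng(0)
x1, x2 = 1.0, 2.0
mu, phi = 5.0, 2.0              # prior on m: N(mu, phi)
alpha0, beta0 = 0.1, 0.1        # prior on s: iG(alpha0, beta0)

n_samples = 5000
m, s = 0.0, 1.0                 # arbitrary initialization
samples = np.empty((n_samples, 2))

for k in range(n_samples):
    # p(m | s, ...) = N( (1/phi + 2/s)^-1 (mu/phi + (x1+x2)/s), (1/phi + 2/s)^-1 )
    var_m = 1.0 / (1.0 / phi + 2.0 / s)
    m = rng.normal(var_m * (mu / phi + (x1 + x2) / s), np.sqrt(var_m))

    # p(s | m, ...) = iG( alpha0 + 1, 0.5(m - x1)^2 + 0.5(m - x2)^2 + beta0 )
    beta = 0.5 * (m - x1) ** 2 + 0.5 * (m - x2) ** 2 + beta0
    s = 1.0 / rng.gamma(alpha0 + 1.0, 1.0 / beta)   # iG(a, b) draw = 1 / Gamma(a, scale=1/b)

    samples[k] = (m, s)
```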

EM algorithm, classical form

Splits the unknown vector into two parts:
1. quantity of interest θ (e.g. mean m)
2. missing data z (e.g. variance s)

General EM algorithm [Dempster, Laird, Rubin, 1977]:

θ̂ = arg max_θ ∫ p(x|θ, z) p(z) dz,

can be (approximately) found by alternating:

E-step: Q(θ|θ^(i)) = ∫ log p(x|θ, z) p(z|θ^(i)) dz
M-step: θ^(i+1) = arg max_θ Q(θ|θ^(i))

E-step: θ = m, z = s

Proxy distribution:

Q(m|m^(i)) = ∫ log p(x1, x2, m, s) p(s|m^(i)) ds = E_{p(s|m^(i))}[ log p(x1, x2, m, s) ],

using (keeping only the terms that depend on m)

p(x1, x2, m|s) ∝ (1/s) exp( −(m − x1)²/(2s) − (m − x2)²/(2s) − (m − µ)²/(2φ) ),
log p(x1, x2, m|s) = c − log s − (m − x1)²/(2s) − (m − x2)²/(2s) − (m − µ)²/(2φ).

Hence

Q(m|m^(i)) = c − ½ E(1/s)[ (m − x1)² + (m − x2)² ] − (m − µ)²/(2φ).

With E(1/s) = α/β, i.e. ŝ = β/α,

Q(m|m^(i)) = c − ½[ (m − x1)²/ŝ + (m − x2)²/ŝ ] − (m − µ)²/(2φ).

M-step

New point estimate m^(i+1) = arg max_m Q(m|m^(i)):
1. set the first derivative equal to 0
2. complete the square

m^(i+1) = (1/φ + 2/ŝ)^(−1) ( µ/φ + (x1 + x2)/ŝ ),

based on the estimate ŝ = β/α from the conditional

p(s|m^(i)) = iG(α, β),  α = α0 + 1,  β = 0.5(m^(i) − x1)² + 0.5(m^(i) − x2)² + β0.

Repeat for p(s|m^(i+1)).
Note: the meaning of the variables can be swapped: E-step over m and M-step on s.
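A minimal sketch of this EM iteration (my own illustration; the homework below asks for a full solution, so treat this only as the skeleton, with arbitrary data, hyperparameters and iteration count):

```python
# EM for the toy problem with theta = m, z = s:
#   E-step: p(s | m^(i)) = iG(alpha, beta), giving s_hat = beta / alpha (so E[1/s] = 1/s_hat)
#   M-step: maximize Q(m | m^(i)) by completing the square.
x1, x2 = 1.0, 2.0
mu, phi = 5.0, 2.0
alpha0, beta0 = 0.1, 0.1

m = 0.0                         # arbitrary starting point
for _ in range(50):
    # E-step
    alpha = alpha0 + 1.0
    beta = 0.5 * (m - x1) ** 2 + 0.5 * (m - x2) ** 2 + beta0
    s_hat = beta / alpha

    # M-step
    m = (mu / phi + (x1 + x2) / s_hat) / (1.0 / phi + 2.0 / s_hat)

print(m, s_hat)
```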

Divergence minimization

We seek the best approximation of an intractable distribution p(x) in a chosen class of parametric functions q(x|θ), such that

θ* = arg min_θ D(p, q),

where D(p, q) is a statistical divergence.

Different results for different choices of: i) q(θ), and ii) D.

Variational Bayes: i) q(θ1, θ2) = q(θ1) q(θ2), and ii) Kullback-Leibler.


Variational Bayes

is a divergence minimization technique with

q* = arg min_q KL(q||p) = arg min_q E_q( log(q/p) ),
q(m, s) = q(m|x1, x2) q(s|x1, x2),

which allows free-form optimization.

Result:

q(m|x1, x2) ∝ exp( E_{q(s)}[ log p(x1, x2, m, s) ] )
q(s|x1, x2) ∝ exp( E_{q(m)}[ log p(x1, x2, m, s) ] )

which is a set of implicit functions.

- The proportionality above allows using p(x1, x2, m, s) in place of p(m, s|x1, x2).
- Variational EM algorithm (E-E algorithm).


E-step: m

Proxy distribution:

log q(m) ∝ E_{q(s)}[ log p(x1, x2, m, s) ],

using

p(x1, x2, m, s) ∝ (1/s) · (1/s^(α0+1)) · exp( −(m − x1)²/(2s) − (m − x2)²/(2s) − (m − µ)²/(2φ) − β0/s ),
log p(x1, x2, m, s) = c − (α0 + 2) log s − (m − x1)²/(2s) − (m − x2)²/(2s) − (m − µ)²/(2φ) − β0/s.

Keeping only the terms that depend on m,

E_{q(s)}[ log p(x1, x2, m, s) ] ∝ −½ E(1/s)[ (m − x1)² + (m − x2)² ] − (m − µ)²/(2φ)
                                ∝ −½[ (m − x1)²/ŝ + (m − x2)²/ŝ ] − (m − µ)²/(2φ),   with 1/ŝ = E_{q(s)}(1/s),

q(m) = N( m; (1/φ + 2/ŝ)^(−1)( µ/φ + (x1 + x2)/ŝ ), (1/φ + 2/ŝ)^(−1) ).

E-step: s

Proxy distribution:

log q(s) ∝ E_{q(m)}[ log p(x1, x2, m, s) ],

using

p(x1, x2, m, s) ∝ (1/s) · (1/s^(α0+1)) · exp( −(m − x1)²/(2s) − (m − x2)²/(2s) − (m − µ)²/(2φ) − β0/s ),
log p(x1, x2, m, s) = c − (α0 + 2) log s − (m − x1)²/(2s) − (m − x2)²/(2s) − (m − µ)²/(2φ) − β0/s.

E_{q(m)}[ log p(x1, x2, m, s) ] ∝ −(α0 + 2) log s − (1/(2s)) E_{q(m)}[ (m − x1)² + (m − x2)² ] − β0/s
                                = −(α0 + 2) log s − (1/(2s)) E_{q(m)}[ m² − 2m x1 + x1² + m² − 2m x2 + x2² ] − β0/s,

q(s) = iG(α, β),
α = α0 + 1,
β = E(m²) − E(m)(x1 + x2) + 0.5(x1² + x2²) + β0.

Toy: Variational Bayes

Factors:

q(m) = N( m; (1/φ + 2/ŝ)^(−1)( µ/φ + (x1 + x2)/ŝ ), (1/φ + 2/ŝ)^(−1) )
q(s) = iG( α0 + 1, E(m²) − E(m)(x1 + x2) + 0.5(x1² + x2²) + β0 )

with

ŝ = [ E(m²) − E(m)(x1 + x2) + 0.5(x1² + x2²) + β0 ] / (α0 + 1),
E(m) = (1/φ + 2/ŝ)^(−1)( µ/φ + (x1 + x2)/ŝ ),
E(m²) = E(m)² + (1/φ + 2/ŝ)^(−1),

which needs to be iterated.
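A minimal sketch of the fixed-point iteration above (my own illustration, not lecture code; the data, hyperparameters, initialization and iteration count are arbitrary):

```python
# Variational Bayes for the toy problem: iterate the coupled moments of
# q(m) = N(Em, var_m) and q(s) = iG(alpha0 + 1, beta) until convergence.
x1, x2 = 1.0, 2.0
mu, phi = 5.0, 2.0
alpha0, beta0 = 0.1, 0.1

Em, Em2 = 0.0, 1.0              # arbitrary initial moments of q(m)
for _ in range(100):
    # update q(s); only the moment E(1/s) = 1/s_hat is needed downstream
    s_hat = (Em2 - Em * (x1 + x2) + 0.5 * (x1 ** 2 + x2 ** 2) + beta0) / (alpha0 + 1.0)

    # update q(m)
    var_m = 1.0 / (1.0 / phi + 2.0 / s_hat)
    Em = var_m * (mu / phi + (x1 + x2) / s_hat)
    Em2 = Em ** 2 + var_m

print(Em, var_m, s_hat)
```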

Toy: Variational Bayes Iterations

[Figure: contour plot of p(m, s | x1=1, x2=2, µ=5, α=0.1, β=0.1, φ=2) over s (horizontal, 5–15) and m (vertical, 0–6), with the true marginals p(s|x1,x2), p(m|x1,x2) and the VB approximations q(s|x1,x2), q(m|x1,x2) shown alongside.]

Homework assignment

- Working code for EM estimation of the toy problem:
  1. m = arg max p(m|x1, x2) (5 points)
  2. s = arg max p(s|x1, x2) (5 points)
- Working code for Variational Bayes estimation of the toy problem (10 points).