
Date posted: 25-Jul-2020

• Gaussian variational approximation with structured covariance matrices

David Nott

Department of Statistics and Applied Probability National University of Singapore

Collaborators: Linda Tan, Victor Ong, Michael Smith, Matias Quiroz, Robert Kohn

David Nott, NUS Gaussian variational approximation 1 / 39

• Bayesian computation as usual

Data y to be observed, unknowns θ. Construct a model for (y, θ):

$$p(y,\theta) = \underbrace{p(\theta)}_{\text{prior}}\;\underbrace{p(y\mid\theta)}_{\text{likelihood}}$$

Condition on the observed y:

$$\underbrace{p(\theta\mid y)}_{\text{posterior}} \propto p(\theta)\,p(y\mid\theta)$$

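The posterior is then summarized using algorithms such as Markov chain Monte Carlo (MCMC) or sequential Monte Carlo.

These standard Monte Carlo algorithms are exact in principle: their approximation error can be made arbitrarily small by increasing the computational effort.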
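A minimal sketch of such a Monte Carlo summary, assuming a toy conjugate model chosen only for illustration (yi ∼ N(θ, 1), θ ∼ N(0, 4), simulated data): a random-walk Metropolis sampler whose draws average to the exact posterior mean as the run length grows.

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(1.5, 1.0, size=20)  # simulated data: y_i ~ N(theta, 1)
tau2 = 4.0                         # hypothetical prior variance: theta ~ N(0, tau2)


def log_post_unnorm(theta):
    """log p(theta) + log p(y | theta), up to an additive constant."""
    return -0.5 * theta**2 / tau2 - 0.5 * np.sum((y - theta) ** 2)


def random_walk_metropolis(log_target, theta0, n_iter, step, rng):
    """Random-walk Metropolis with N(0, step^2) proposals."""
    draws = np.empty(n_iter)
    theta, lp = theta0, log_target(theta0)
    for i in range(n_iter):
        prop = theta + step * rng.standard_normal()
        lp_prop = log_target(prop)
        if np.log(rng.uniform()) < lp_prop - lp:  # accept with probability min(1, ratio)
            theta, lp = prop, lp_prop
        draws[i] = theta
    return draws


draws = random_walk_metropolis(log_post_unnorm, 0.0, 50_000, 0.5, rng)
post_prec = 1.0 / tau2 + len(y)  # conjugate posterior precision
kept = draws[5_000:]             # discard burn-in
print(kept.mean(), y.sum() / post_prec)  # MCMC estimate vs exact posterior mean
```

Longer runs shrink the gap between the two printed numbers, which is the "exact in principle" property; variational methods give this up in exchange for speed.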


• Variational approximations

There is increasing interest in algorithms that are not exact in principle.

Why approximate inference? It makes use of the scalability of optimization-based approaches to computation. Approximate inference can be enough to understand why certain models are unsuitable. Approximate methods may also perform as well as exact methods for predictive inference.

This talk concerns a popular framework for approximate inference: variational approximation.


• Variational inference basics (Blei et al., 2017; Ormerod and Wand, 2012)

Variational approximation reformulates Bayesian computation as an optimization problem.

Choose an approximating family indexed by some parameters (a Gaussian family, for example). Then define some measure of "closeness" of an approximation to the true posterior, and optimize that measure over the approximating family.

The variational parameters to be optimized are denoted λ.


• Gaussian variational approximation, toy example


• Gaussian variational approximation

In this talk we consider a multivariate normal approximation denoted qλ(θ) = N(µ,Σ) with λ = (µ,Σ) to be optimized.

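The best normal approximation is taken to be the one minimizing the Kullback-Leibler (KL) divergence from qλ(θ) to the posterior,

$$\mathrm{KL}\big(q_\lambda(\theta)\,\|\,p(\theta\mid y)\big) = \log p(y) - \int \log\frac{p(\theta)\,p(y\mid\theta)}{q_\lambda(\theta)}\,q_\lambda(\theta)\,d\theta = \log p(y) - \mathcal{L}(\lambda),$$

where p(y) = ∫ p(θ)p(y∣θ) dθ is the marginal likelihood and L(λ) is called the variational lower bound (since the KL divergence is non-negative, L(λ) ≤ log p(y)).

Because log p(y) does not depend on λ, minimizing the KL divergence is equivalent to maximizing L(λ).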
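The identity KL = log p(y) − L(λ) can be checked numerically. The sketch below assumes a hypothetical conjugate model (yi ∼ N(θ, 1), θ ∼ N(0, 4), simulated data) where both the posterior and log p(y) are available in closed form. When qλ is the exact posterior, every Monte Carlo term log h(θ) − log qλ(θ) equals log p(y), so the estimate of L(λ) matches log p(y) to floating-point precision; any other qλ gives a smaller value.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
n, tau2 = 20, 4.0                  # hypothetical data size and prior variance
y = rng.normal(1.5, 1.0, size=n)   # simulated data: y_i ~ N(theta, 1)


def log_h(theta):
    """log p(theta) + log p(y | theta), vectorised over an array of theta draws."""
    log_prior = stats.norm.logpdf(theta, 0.0, np.sqrt(tau2))
    log_lik = stats.norm.logpdf(y[None, :], theta[:, None], 1.0).sum(axis=1)
    return log_prior + log_lik


def elbo_mc(mu, sigma, n_draws=10_000, rng=rng):
    """Monte Carlo estimate of L(lambda) = E_q[log h(theta) - log q(theta)], q = N(mu, sigma^2)."""
    theta = rng.normal(mu, sigma, size=n_draws)
    return np.mean(log_h(theta) - stats.norm.logpdf(theta, mu, sigma))


# Exact posterior and log marginal likelihood: y ~ N(0, I + tau2 * 1 1^T).
post_prec = 1.0 / tau2 + n
post_mu, post_sd = y.sum() / post_prec, np.sqrt(1.0 / post_prec)
log_py = stats.multivariate_normal.logpdf(y, mean=np.zeros(n), cov=np.eye(n) + tau2 * np.ones((n, n)))

print(elbo_mc(post_mu, post_sd), log_py)         # agree to floating-point precision
print(elbo_mc(post_mu + 1.0, post_sd) < log_py)  # a worse q gives a smaller bound
```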


• Gaussian variational approximation

The optimization of L(λ) is challenging for a normal family when the dimension d of θ is large.

Key difficulty: With no restriction on Σ in qλ(θ) = N(µ,Σ), there are d + d(d + 1)/2 parameters to optimize.

For high-dimensional θ we need reduced parametrizations of Σ. Two options are to exploit conditional independence structure to motivate sparsity in Σ−1 (Tan and Nott, 2017), or to use factor models.
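To make the quadratic growth concrete, here is a small count of parameters for an unrestricted N(µ, Σ) against a factor parametrization. The factor form Σ = BB⊺ + D², with B a d × p loading matrix and D diagonal, is an assumption for illustration; the slide only names "factor models".

```python
def n_params_dense(d):
    """Mean (d) plus the free entries of a dense covariance or Cholesky factor (d(d+1)/2)."""
    return d + d * (d + 1) // 2


def n_params_factor(d, p):
    """Hypothetical factor form Sigma = B B^T + D^2: mean (d), loadings B (d * p), diagonal D (d)."""
    return d + d * p + d


for d in (10, 100, 1000):
    print(d, n_params_dense(d), n_params_factor(d, p=3))
```

The dense count grows like d²/2, while the factor count grows linearly in d for fixed p.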


• How should we optimize? Stochastic gradient ascent (Robbins and Monro, 1951)

We will use stochastic gradient ascent methods for optimizing the lower bound.

Suppose that ∇λL(λ) is the gradient of L(λ) and that ∇̂λL(λ) is an unbiased estimate of it.

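Initialize λ(0).

For t = 0, 1, ... until some stopping rule is satisfied, set

$$\lambda^{(t+1)} = \lambda^{(t)} + \rho_t\,\widehat{\nabla_\lambda \mathcal{L}}(\lambda^{(t)}).$$

Typically the learning-rate sequence ρt, t ≥ 0, satisfies the Robbins-Monro conditions ∑t ρt = ∞ and ∑t ρt² < ∞.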
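A minimal sketch of the iteration above, assuming step sizes ρt = a/(t + τ)^κ (a common choice, not necessarily the one used in the talk) and a toy concave objective L(λ) = −‖λ − λ*‖²/2 whose exact gradient is corrupted with zero-mean noise to make it an unbiased estimate.

```python
import numpy as np


def sga(grad_hat, lam0, n_iter=20_000, a=1.0, tau=1.0, kappa=0.7):
    """Stochastic gradient ascent lam <- lam + rho_t * grad_hat(lam).

    rho_t = a / (t + tau)**kappa satisfies sum rho_t = inf and
    sum rho_t**2 < inf whenever 0.5 < kappa <= 1.
    """
    lam = np.asarray(lam0, dtype=float)
    for t in range(n_iter):
        rho = a / (t + tau) ** kappa
        lam = lam + rho * grad_hat(lam)
    return lam


# Toy objective L(lam) = -0.5 * ||lam - target||^2, whose gradient is target - lam.
rng = np.random.default_rng(1)
target = np.array([3.0, -1.0])


def noisy_grad(lam):
    """Unbiased gradient estimate: exact gradient plus N(0, I) noise."""
    return (target - lam) + rng.normal(size=lam.shape)


lam_hat = sga(noisy_grad, lam0=[0.0, 0.0])
print(lam_hat)  # close to target
```

The decreasing steps average out the gradient noise; with a constant step the iterates would keep fluctuating around the optimum.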

• Reparametrization gradients (Kingma and Welling, 2013; Rezende et al., 2014)

Low-variance gradient estimates are crucial for the stability and fast convergence of stochastic optimization. They can be obtained using the reparametrization trick.

Writing h(θ) = p(θ)p(y ∣θ), the lower bound is

$$\mathcal{L}(\lambda) = \int \{\log h(\theta) - \log q_\lambda(\theta)\}\,q_\lambda(\theta)\,d\theta.$$

Suppose that for the variational family qλ(θ) we can write θ ∼ qλ(θ) as θ = t(z, λ), where z ∼ f(z) and f(⋅) does not depend on λ. Then

$$\mathcal{L}(\lambda) = E_f\{\log h(t(z,\lambda)) - \log q_\lambda(t(z,\lambda))\}.$$

Differentiating under the integral sign, ∇λL(λ) is an expectation with respect to f (⋅): simulation from f (⋅) gives unbiased estimates.
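To illustrate just the gradient-of-an-expectation step, here is a minimal sketch assuming a hypothetical log-normal family θ = t(z, λ) = exp(µ + σz), z ∼ N(0, 1), λ = (µ, σ), with the test function g(θ) = θ standing in for log h − log qλ. Differentiating t(z, λ) inside the expectation gives an unbiased estimator, compared here with the exact gradient of E(θ) = exp(µ + σ²/2).

```python
import numpy as np


def reparam_grad(mu, sigma, n_draws, rng):
    """Unbiased estimate of d/d(mu, sigma) of E[theta], theta = exp(mu + sigma * z), z ~ N(0, 1)."""
    z = rng.standard_normal(n_draws)  # draws from f, which does not depend on (mu, sigma)
    theta = np.exp(mu + sigma * z)    # theta = t(z, lambda)
    # Chain rule through t: d theta / d mu = theta, d theta / d sigma = z * theta.
    return np.array([theta.mean(), (z * theta).mean()])


rng = np.random.default_rng(2)
mu, sigma = 0.3, 0.5
est = reparam_grad(mu, sigma, 200_000, rng)
exact = np.exp(mu + sigma**2 / 2) * np.array([1.0, sigma])  # gradient of exp(mu + sigma^2 / 2)
print(est, exact)
```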


• Reparametrization gradients for the Gaussian family (Titsias and Lázaro-Gredilla, 2014; Kucukelbir et al., 2016)

Suppose qλ(θ) = N(θ; µ, Σ) with Σ = CC⊺, where µ is the mean vector, Σ is the covariance matrix and C is the lower-triangular Cholesky factor of Σ.

The variational parameter is λ = (µ, C). We can write θ ∼ qλ(θ) as

θ = µ + Cz, z ∼ f(z) = N(0, I).

This shows that the Gaussian family has the structure required for the reparametrization trick.
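A minimal sketch (not the authors' code) of the resulting gradient estimates inside stochastic gradient ascent. With θ = µ + Cz, the data term gives ∇µ ≈ average of ∇θ log h(θ) and ∇C ≈ lower triangle of the average of ∇θ log h(θ) z⊺; the entropy of qλ is ∑ log Cii plus a constant, contributing diag(1/Cii). The Gaussian target N(m, S), the batch size and the step-size schedule are hypothetical choices, picked so the answer is known: the optimal qλ is N(m, S) itself.

```python
import numpy as np


def gaussian_elbo_grad(mu, C, grad_log_h, rng, n_draws=100):
    """Reparametrization estimate of the gradient of L(mu, C) for q = N(mu, C C^T).

    theta = mu + C z with z ~ N(0, I). The entropy of q is sum(log diag(C)) plus a
    constant, so it contributes diag(1 / C_ii) to the gradient with respect to C.
    """
    Z = rng.standard_normal((n_draws, mu.size))  # each row is one draw z
    G = grad_log_h(mu + Z @ C.T)                  # rows: grad log h at theta = mu + C z
    g_mu = G.mean(axis=0)
    g_C = np.tril(G.T @ Z / n_draws) + np.diag(1.0 / np.diag(C))  # C stays lower triangular
    return g_mu, g_C


# Hypothetical Gaussian target h(theta) proportional to N(m, S).
m = np.array([1.0, -2.0])
S = np.array([[1.0, 0.5], [0.5, 2.0]])
S_inv = np.linalg.inv(S)


def grad_log_h(theta):
    """Row-wise -S^{-1}(theta - m); S_inv is symmetric."""
    return -(theta - m) @ S_inv


rng = np.random.default_rng(3)
mu, C = np.zeros(2), np.eye(2)
for t in range(5000):
    rho = 0.05 / (1.0 + t / 200.0) ** 0.7  # decreasing step sizes
    g_mu, g_C = gaussian_elbo_grad(mu, C, grad_log_h, rng)
    mu, C = mu + rho * g_mu, C + rho * g_C
print(mu, C @ C.T)  # approximately m and S
```

Note that C here is dense lower triangular, so it carries d(d + 1)/2 free entries; this is the cost that the structured parametrizations below aim to reduce.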


• Gaussian variational approximation using conditional independence structure

The low-variance gradient estimates provided by the reparametrization trick are crucial to efficient and stable stochastic variational optimization. However, learning a Gaussian variational approximation is still hard if we parametrize Σ with a dense Cholesky factor: the number of parameters in the covariance matrix grows quadratically with the parameter dimension.

We need to use the structure of the model to obtain parsimonious, structured parametrizations of covariance matrices suitable for Gaussian variational approximations in high dimensions.


• Gaussian variational approximation

What exploitable structure is available? Conditional independence structure.

Can we match the true conditional independence structure in the Gaussian approximation to make such approximations practical in high dimensions?

For a Gaussian random vector with covariance matrix Σ and precision matrix Ω = Σ−1, Ωij = 0 if and only if variables i and j are conditionally independent given the rest.
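A quick numerical check of this property, using a hypothetical 4 × 4 precision matrix with Ω02 = 0. Variables 0 and 2 are marginally correlated (Σ02 ≠ 0), but their conditional covariance given the other two variables, computed as a Schur complement of Σ, is diagonal; it also equals (ΩAA)⁻¹ for A = {0, 2}.

```python
import numpy as np

# Hypothetical precision matrix with Omega[0, 2] = 0 (diagonally dominant, hence positive definite).
Omega = np.array([[2.0, 0.6, 0.0, 0.3],
                  [0.6, 2.0, 0.5, 0.0],
                  [0.0, 0.5, 2.0, 0.4],
                  [0.3, 0.0, 0.4, 2.0]])
Sigma = np.linalg.inv(Omega)


def conditional_cov(Sigma, A):
    """Covariance of x_A given the remaining coordinates (Schur complement of Sigma_BB)."""
    B = [k for k in range(Sigma.shape[0]) if k not in A]
    S_AA = Sigma[np.ix_(A, A)]
    S_AB = Sigma[np.ix_(A, B)]
    S_BB = Sigma[np.ix_(B, B)]
    return S_AA - S_AB @ np.linalg.solve(S_BB, S_AB.T)


print(Sigma[0, 2])                     # marginal covariance: non-zero
print(conditional_cov(Sigma, [0, 2]))  # off-diagonal is zero up to rounding
```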


• Motivating example: longitudinal generalized linear mixed model

Observations y = (y1, ..., yn), yi = (yi1, ..., yi,ni)⊺.

Observation-specific random effects b = (b1, ..., bn)⊺, with, say, bi ∼ N(0, G). The yi are conditionally independent given b, with likelihood

$$\prod_{i=1}^{n} p(y_i\mid b_i,\eta),$$

where η denotes the fixed effects and variance parameters. The joint posterior for θ = (b⊺, η⊺)⊺ is

$$p(\theta\mid y) \propto p(\eta)\prod_{i=1}^{n} p(b_i\mid\eta)\,p(y_i\mid b_i,\eta).$$

In the joint posterior, bi and bj (i ≠ j) are conditionally independent given η.


• Motivating example: longitudinal generalized linear mixed model

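Consider a sparse precision matrix Ω = Σ−1 in the Gaussian variational approximation, of the form (rows and columns ordered as b1, b2, ..., bn, η):

$$\Omega = \begin{pmatrix}
\Omega_{11} & 0 & \cdots & 0 & \Omega_{1,n+1} \\
0 & \Omega_{22} & \cdots & 0 & \Omega_{2,n+1} \\
\vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & \cdots & \Omega_{nn} & \Omega_{n,n+1} \\
\Omega_{n+1,1} & \Omega_{n+1,2} & \cdots & \Omega_{n+1,n} & \Omega_{n+1,n+1}
\end{pmatrix}$$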
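A minimal sketch of how much this pattern saves, assuming hypothetical sizes: n = 500 random-effect blocks of dimension r = 2 and g = 5 global parameters. It builds the boolean non-zero pattern of Ω (block diagonal in the bi, plus a dense row and column block for η) and counts free entries in its lower triangle against a dense symmetric matrix of the same size.

```python
import numpy as np


def arrowhead_mask(n, r, g):
    """Non-zero pattern of Omega for n random-effect blocks of size r followed by g global parameters."""
    d = n * r + g
    mask = np.zeros((d, d), dtype=bool)
    for i in range(n):
        s = slice(i * r, (i + 1) * r)
        mask[s, s] = True   # diagonal blocks Omega_ii
    mask[n * r:, :] = True  # eta rows: Omega_{n+1, j}
    mask[:, n * r:] = True  # eta columns: Omega_{i, n+1}
    return mask


mask = arrowhead_mask(n=500, r=2, g=5)
d = mask.shape[0]
free_sparse = np.count_nonzero(np.tril(mask))  # free entries of the symmetric sparse Omega
free_dense = d * (d + 1) // 2                   # free entries of a dense symmetric matrix
print(d, free_sparse, free_dense)
```

For these sizes the sparse pattern has 6,515 free entries against 505,515 for a dense matrix, and the sparse count grows only linearly in n.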


• Motivating example: longitudinal generalized linear mixed model

It will be convenient later to parametrize Ω in terms of its Cholesky factor, say Ω = TT⊺ with T lower triangular. Imposing sparse structure on T imposes sparse structure on Ω: the leftmost non-zero entry in each row of T is in the same position as the leftmost non-zero entry in the corresponding row of Ω. Choose T of the form

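$$T = \begin{pmatrix}
T_{11} & 0 & \cdots & 0 & 0 \\
0 & T_{22} & \cdots & 0 & 0 \\
\vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & \cdots & T_{nn} & 0 \\
T_{n+1,1} & T_{n+1,2} & \cdots & T_{n+1,n} & T_{n+1,n+1}
\end{pmatrix}$$

with rows and columns ordered as b1, b2, ..., bn, η.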
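A minimal sketch, assuming hypothetical sizes (6 scalar random effects, 2 global parameters) and random entries: it fills T with the pattern above, checks that the lower triangle of Ω = TT⊺ has exactly the same zeros as T (no fill-in), and draws θ ∼ N(µ, Ω⁻¹) as θ = µ + T⁻⊺z, which needs only a triangular solve. Using T⁻⊺z for sampling is one natural choice assumed here.

```python
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(4)
n, g = 6, 2  # hypothetical: 6 scalar random effects b_i, 2 global parameters eta
d = n + g

# Lower-triangular T: diagonal in the b_i rows, dense lower triangle in the eta rows.
T = np.zeros((d, d))
T[np.arange(n), np.arange(n)] = np.exp(rng.normal(size=n))        # positive diagonal for b_i
T[n:, :] = np.tril(rng.normal(size=(g, d)), k=n)                   # eta rows up to the diagonal
T[np.arange(n, d), np.arange(n, d)] = np.exp(rng.normal(size=g))  # positive diagonal for eta

Omega = T @ T.T
# The lower triangle of Omega has the same non-zero pattern as T.
print(np.array_equal(np.tril(Omega) != 0, T != 0))

# Draw theta ~ N(mu, Omega^{-1}): solve T^T x = z, so Cov(x) = T^{-T} T^{-1} = Omega^{-1}.
mu = np.zeros(d)
z = rng.standard_normal(d)
theta = mu + solve_triangular(T, z, lower=True, trans='T')
```

Because T has the same sparsity as Ω, both storing the variational parameters and drawing from qλ scale with the number of non-zeros rather than with d².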


• More general framework

Observations y = (y1, ..., yn), observation-specific latent variables b1, ..., bn and global parameters η. The joint model p(y, θ) for (y, θ), θ = (b⊺, η⊺)⊺, has the form

$$p(\eta)\Big\{\prod_{i=1}^{n} p(y_i\mid b_i,\eta)\Big\}\,p(b_1,\dots,b_k\mid\eta)\Big\{\prod_{i>k} p(b_i\mid b_{i-1},\dots,b_{i-k},\eta)\Big\}$$

for some 0 ≤ k ≤ n. Under this model, bi is conditionally independent of the other latent variables in p(θ∣y) given η and the k neighbouring values on either side (bi−k, ..., bi−1 and bi+1, ..., bi+k). Our previous random effects example fits this structure with k = 0; a state space model for a time series, where the bi are the states, fits it with k = 1. For a state space model, use Ω of the form


• More general framework

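$$\Omega = \begin{pmatrix}
\Omega_{11} & \Omega_{21}^\top & 0 & \cdots & 0 & 0 & \Omega_{n+1,1}^\top \\
\Omega_{21} & \Omega_{22} & \Omega_{32}^\top & \cdots & 0 & 0 & \Omega_{n+1,2}^\top \\
0 & \Omega_{32} & \Omega_{33} & \cdots & 0 & 0 & \Omega_{n+1,3}^\top \\
\vdots & \vdots & \vdots & \ddots & \vdots & \vdots & \vdots \\
0 & 0 & 0 & \cdots & \Omega_{n-1,n-1} & \Omega_{n,n-1}^\top & \Omega_{n+1,n-1}^\top \\
0 & 0 & 0 & \cdots & \Omega_{n,n-1} & \Omega_{nn} & \Omega_{n+1,n}^\top \\
\Omega_{n+1,1} & \Omega_{n+1,2} & \Omega_{n+1,3} & \cdots & \Omega_{n+1,n-1} & \Omega_{n+1,n} & \Omega_{n+1,n+1}
\end{pmatrix}$$

with rows and columns ordered as b1, b2, b3, ..., bn−1, bn, η.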