# Probabilistic programming and optimization


Probabilistic programming and optimization

Arto Klami

March 29, 2018

Arto Klami Probabilistic programming and optimization March 29, 2018 1 / 23

Bayesian inference

Making predictions under uncertainty:

$$p(\tilde{x} \mid D) = \int p(\tilde{x} \mid \theta)\, p(\theta \mid D)\, d\theta$$

Markov chain Monte Carlo approximates this with

$$p(\tilde{x} \mid D) \approx \frac{1}{M} \sum_{m=1}^{M} p(\tilde{x} \mid \theta_m)$$

for θm drawn from p(θ|D) using some algorithm that hopefully produces good enough samples
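A minimal sketch of this estimator, assuming a toy coin-flip model that is not from the slides (Bernoulli likelihood, uniform Beta(1, 1) prior, 7 heads in 10 flips) and using random-walk Metropolis as "some algorithm". Because the model is conjugate, the exact predictive p(x̃ = 1 | D) = 8/12 is available as a check; the step size, burn-in and sample count are arbitrary choices.

```python
import math
import random


def log_posterior(theta: float, heads: int, n: int) -> float:
    """Unnormalized log p(theta | D): Bernoulli likelihood, uniform Beta(1, 1) prior."""
    if not 0.0 < theta < 1.0:
        return -math.inf
    return heads * math.log(theta) + (n - heads) * math.log(1.0 - theta)


def metropolis(heads: int, n: int, num_samples: int = 20000,
               step: float = 0.1, burn_in: int = 2000, seed: int = 0) -> list[float]:
    """Random-walk Metropolis: (correlated) samples theta_m from p(theta | D)."""
    rng = random.Random(seed)
    theta = 0.5
    current = log_posterior(theta, heads, n)
    samples = []
    for i in range(burn_in + num_samples):
        proposal = theta + rng.gauss(0.0, step)
        candidate = log_posterior(proposal, heads, n)
        # Accept with probability min(1, p(proposal|D) / p(theta|D)); 1 - random() lies in (0, 1]
        if math.log(1.0 - rng.random()) < candidate - current:
            theta, current = proposal, candidate
        if i >= burn_in:
            samples.append(theta)
    return samples


def predictive_prob_heads(samples: list[float]) -> float:
    """p(x_new = 1 | D) ≈ (1/M) Σ_m p(x_new = 1 | theta_m), where p(x_new = 1 | theta) = theta."""
    return sum(samples) / len(samples)


if __name__ == "__main__":
    samples = metropolis(heads=7, n=10)
    print(predictive_prob_heads(samples), "vs exact", 8 / 12)
```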


Bayesian inference using optimization

We can convert the search for the posterior distribution into an optimization problem as well:

1 Choose some parametric family of distributions q(θ|λ)

2 Find the q(θ|λ) that is close to p(θ|D) by minimizing some dissimilarity measure with respect to λ


Bayesian inference using optimization

Why? Because we can...

...or because we care about

computational efficiency

deterministic solution

interpretability

...


Variational approximation

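A prototypical example is variational approximation, where we use the Kullback-Leibler divergence between q(θ|λ) and p(θ|D) as the loss function

Equivalent formulation: maximize a lower bound L(λ) for the log marginal likelihood log p(D):

$$\log p(D) = \mathbb{E}_{q(\theta|\lambda)}\left[\log \frac{p(D,\theta)}{q(\theta|\lambda)}\right] + \mathrm{KL}\big(q(\theta|\lambda) \,\|\, p(\theta|D)\big) \ge \mathbb{E}_{q(\theta|\lambda)}\left[\log \frac{p(D,\theta)}{q(\theta|\lambda)}\right] = L(\lambda)$$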
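The decomposition follows in one line from p(D) = p(D,θ)/p(θ|D), which holds for every θ: average its logarithm over q(θ|λ), then multiply and divide by q(θ|λ) inside the log. The inequality is KL ≥ 0, and since log p(D) does not depend on λ, maximizing L(λ) is equivalent to minimizing the KL divergence.

```latex
\log p(D)
  = \mathbb{E}_{q(\theta|\lambda)}\left[\log \frac{p(D,\theta)}{p(\theta|D)}\right]
  = \underbrace{\mathbb{E}_{q(\theta|\lambda)}\left[\log \frac{p(D,\theta)}{q(\theta|\lambda)}\right]}_{L(\lambda)}
  + \underbrace{\mathbb{E}_{q(\theta|\lambda)}\left[\log \frac{q(\theta|\lambda)}{p(\theta|D)}\right]}_{\mathrm{KL}(q(\theta|\lambda)\,\|\,p(\theta|D)) \;\ge\; 0}
```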


Variational approximation [pre 2014]

Mean-field approximation $q(\theta) = \prod_{d=1}^{D} q_d(\theta_d)$ leads to a very elegant coordinate ascent algorithm with

$$q_d(\theta_d) \propto \exp\left(\mathbb{E}_{q_{-d}(\theta)}[\log p(D,\theta)]\right)$$

If using conjugate models, we can compute the expectations analytically! See e.g. Blei et al. (2016)

...but it is often quite tedious and error-prone, and extending beyond conjugate models is difficult
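For a concrete instance of these updates, here is a sketch for the textbook normal model with unknown mean μ and precision τ (not from the slides): μ | τ ∼ N(μ0, (κ0 τ)^-1), τ ∼ Gamma(a0, b0), and q(μ, τ) = q(μ) q(τ). Each factor stays in its conjugate family, so coordinate ascent reduces to updating a few numbers; the function name and the vague default hyperparameters are illustrative assumptions.

```python
def cavi_normal_gamma(x, mu0=0.0, kappa0=1e-3, a0=1e-3, b0=1e-3, iters=50):
    """Mean-field CAVI for x_i ~ N(mu, 1/tau), mu | tau ~ N(mu0, 1/(kappa0 tau)), tau ~ Gamma(a0, b0).

    q(mu) = N(mu_n, 1/prec_n) and q(tau) = Gamma(a_n, b_n). Each update sets one factor to the
    normalized exp(E_{q_-d}[log p(D, mu, tau)]) while holding the other factor fixed.
    """
    n = len(x)
    xbar = sum(x) / n
    e_tau = 1.0                                        # initial guess for E_q[tau]
    mu_n = (kappa0 * mu0 + n * xbar) / (kappa0 + n)    # mean of q(mu) does not depend on q(tau)
    a_n = a0 + (n + 1) / 2.0                           # shape of q(tau) does not depend on q(mu)
    for _ in range(iters):
        # Update q(mu): Gaussian with precision (kappa0 + n) E[tau]
        prec_n = (kappa0 + n) * e_tau
        # Update q(tau): Gamma, using E_q(mu)[(x_i - mu)^2] = (x_i - mu_n)^2 + 1/prec_n
        e_sq = sum((xi - mu_n) ** 2 for xi in x) + n / prec_n
        e_prior = kappa0 * ((mu_n - mu0) ** 2 + 1.0 / prec_n)
        b_n = b0 + 0.5 * (e_sq + e_prior)
        e_tau = a_n / b_n
    return mu_n, prec_n, a_n, b_n


if __name__ == "__main__":
    data = [2.1, 1.9, 2.5, 1.7, 2.2, 2.0, 1.8, 2.3]
    mu_n, prec_n, a_n, b_n = cavi_normal_gamma(data)
    print("E[mu] =", mu_n, " E[tau] =", a_n / b_n)
```

With nearly flat priors, the fixed point gives E[μ] close to the sample mean and E[τ] = a_n/b_n close to the inverse of the (biased) sample variance.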


Variational approximation in probabilistic programming

Can we do this for arbitrary models and approximations, without

(a) Making the mean-field assumption

(b) Assuming conjugate models

(c) ...or even knowledge of the model when writing the inference algorithm

Yes, and under rather mild conditions. Besides being able to evaluate log p(D,θ), we need to assume either of the following:

1 The approximation q(θ|λ) is differentiable with respect to λ

2 The model p(D,θ) is differentiable with respect to θ, and the approximation can be reparameterized

Think of Metropolis-Hastings vs Hamiltonian Monte Carlo: the former assumes only that log p(·) can be evaluated, the latter also needs its gradients


Variational approximation in probabilistic programming

Monte Carlo approximation for the bound itself is easy using

$$L \approx \frac{1}{M} \sum_{m=1}^{M} \log \frac{p(D,\theta_m)}{q(\theta_m|\lambda)},$$

where θm are drawn from the approximation

For optimization we would typically want approximations of the gradient, to be used for (stochastic) gradient ascent

$$\lambda \leftarrow \lambda + \alpha \nabla_\lambda L$$

Automatic differentiation can be used for both log q(·) and log p(·), but the challenge is in handling the expectation, since the distribution we average over itself depends on λ
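A sketch of the Monte Carlo estimate of the bound, assuming a toy conjugate model (θ ∼ N(0, 1), xi ∼ N(θ, 1)) and a Gaussian approximation q(θ|λ) = N(m, s²). Both are illustrative choices, made because the bound then also has a closed form to compare the estimate against.

```python
import math
import random

LOG_2PI = math.log(2.0 * math.pi)


def log_normal(x, mean, std):
    """log N(x | mean, std^2)."""
    return -0.5 * LOG_2PI - math.log(std) - 0.5 * ((x - mean) / std) ** 2


def log_joint(theta, data):
    """log p(D, theta) for the toy model theta ~ N(0, 1), x_i ~ N(theta, 1)."""
    return log_normal(theta, 0.0, 1.0) + sum(log_normal(x, theta, 1.0) for x in data)


def elbo_monte_carlo(m, s, data, num_samples, rng):
    """L ≈ (1/M) Σ_m [log p(D, theta_m) - log q(theta_m | lambda)] with theta_m ~ q = N(m, s^2)."""
    total = 0.0
    for _ in range(num_samples):
        theta = rng.gauss(m, s)
        total += log_joint(theta, data) - log_normal(theta, m, s)
    return total / num_samples


def elbo_closed_form(m, s, data):
    """Exact L for the same toy model: E_q[log p(D, theta)] plus the entropy of N(m, s^2)."""
    n = len(data)
    expected_log_joint = (-(n + 1) / 2.0 * LOG_2PI - 0.5 * (m ** 2 + s ** 2)
                          - 0.5 * sum((x - m) ** 2 + s ** 2 for x in data))
    entropy = 0.5 * (LOG_2PI + 1.0) + math.log(s)
    return expected_log_joint + entropy


if __name__ == "__main__":
    data = [1.2, 0.8, 1.5, 0.9, 1.1]
    rng = random.Random(1)
    print(elbo_monte_carlo(0.5, 0.6, data, 20000, rng), elbo_closed_form(0.5, 0.6, data))
```

At the exact posterior N(Σxi/(n+1), 1/(n+1)) the integrand log p(D,θ) − log q(θ|λ) is constant and equal to log p(D), so the bound is tight there.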


Black-box VI with score function estimators [Ranganath et al. AISTATS’14]

Since ∇q = q∇ log q, simple algebraic manipulation gives the score function estimator

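$$\nabla_\lambda \mathbb{E}_{q(\theta|\lambda)}[\log p(D,\theta)] = \mathbb{E}_{q(\theta|\lambda)}[\nabla_\lambda \log q(\theta|\lambda)\, \log p(D,\theta)] \approx \frac{1}{M} \sum_{m=1}^{M} \nabla_\lambda \log q(\theta_m|\lambda)\, \log p(D,\theta_m),$$

where θm are again drawn from q(θ|λ)

Practical problem: very high variance. Control variates and other variance reduction techniques help, but only to a degree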
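A sketch of the score function estimator for ∂/∂μ, assuming a toy stand-in log p(D,θ) = log N(θ | 3, 1) and q(θ|λ) = N(μ, σ²) (illustrative choices, not from the slides), so that the exact value 3 − μ is known. The optional constant baseline is the simplest control variate: since E_q[∇λ log q(θ|λ)] = 0, subtracting a constant from log p(D,θ) keeps the estimator unbiased, and here it also reduces the variance.

```python
import math
import random

LOG_2PI = math.log(2.0 * math.pi)


def log_target(theta):
    """Toy stand-in for log p(D, theta): log N(theta | 3, 1)."""
    return -0.5 * LOG_2PI - 0.5 * (theta - 3.0) ** 2


def score_mu(theta, mu, sigma):
    """∇_mu log q(theta | mu, sigma) for q = N(mu, sigma^2)."""
    return (theta - mu) / sigma ** 2


def score_function_grad_mu(mu, sigma, num_samples, rng, baseline=0.0):
    """Score function estimate of ∇_mu E_q[log p(D, theta)] and its per-sample variance."""
    terms = []
    for _ in range(num_samples):
        theta = rng.gauss(mu, sigma)                      # theta_m ~ q(theta | lambda)
        terms.append(score_mu(theta, mu, sigma) * (log_target(theta) - baseline))
    mean = sum(terms) / num_samples
    var = sum((t - mean) ** 2 for t in terms) / (num_samples - 1)
    return mean, var


if __name__ == "__main__":
    rng = random.Random(0)
    raw, raw_var = score_function_grad_mu(0.0, 1.0, 100000, rng)
    # Baseline from an independent pilot batch, so it is a constant for the main batch
    pilot = [log_target(rng.gauss(0.0, 1.0)) for _ in range(1000)]
    b = sum(pilot) / len(pilot)
    cv, cv_var = score_function_grad_mu(0.0, 1.0, 100000, rng, baseline=b)
    print("exact 3.0 | raw", raw, raw_var, "| with baseline", cv, cv_var)
```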


Reparameterization [Kingma et al. ICLR’13; Titsias & Lazaro-Gredilla ICML’14]

Assume q(θ|λ) can be written as a system

$$z \sim \phi(z), \qquad \theta = f(z,\lambda)$$

where φ(z) is some simple distribution that does not depend on λ

Then $L = \mathbb{E}_q[\log p(D,\theta)] = \mathbb{E}_\phi[\log p(D, f(z,\lambda))]$ leads to the reparameterization estimator

$$\nabla_\lambda L = \mathbb{E}_\phi[\nabla_\lambda \log p(D, f(z,\lambda))] \approx \frac{1}{M} \sum_{m=1}^{M} \nabla_\lambda \log p(D, f(z_m,\lambda))$$

for zm ∼ φ(z)

Much lower variance, but requires propagating derivatives through log p(·) and f(·)
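Comparing both estimators on the same draws makes the variance difference concrete. This sketch assumes a toy target log p(D,θ) = log N(θ | 3, 1) and q = N(μ, σ²); both estimators target ∂/∂μ E_q[log p(D,θ)] = 3 − μ. For μ = 0, σ = 1 the reparameterization terms are 3 − z with variance exactly 1, while the score function terms have a variance of roughly 67.

```python
import math
import random

LOG_2PI = math.log(2.0 * math.pi)


def log_target(theta):
    """Toy stand-in for log p(D, theta): log N(theta | 3, 1)."""
    return -0.5 * LOG_2PI - 0.5 * (theta - 3.0) ** 2


def grad_log_target(theta):
    """d/dtheta log p(D, theta) for the same toy target."""
    return 3.0 - theta


def mean_var(values):
    """Sample mean and unbiased sample variance."""
    m = sum(values) / len(values)
    return m, sum((v - m) ** 2 for v in values) / (len(values) - 1)


def compare_estimators(mu, sigma, num_samples, seed=0):
    """Per-sample terms of two unbiased estimators of d/dmu E_q[log p(D, theta)], q = N(mu, sigma^2)."""
    rng = random.Random(seed)
    reparam, score = [], []
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)        # z ~ phi(z) = N(0, 1), independent of lambda
        theta = mu + sigma * z         # theta = f(z, lambda)
        # Reparameterization: chain rule with d theta / d mu = 1
        reparam.append(grad_log_target(theta) * 1.0)
        # Score function: d/dmu log q(theta | mu, sigma) times log p(D, theta)
        score.append((theta - mu) / sigma ** 2 * log_target(theta))
    return mean_var(reparam), mean_var(score)


if __name__ == "__main__":
    (r_mean, r_var), (s_mean, s_var) = compare_estimators(0.0, 1.0, 50000)
    print("reparameterization:", r_mean, r_var)
    print("score function:    ", s_mean, s_var)
```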


Reparameterization

Simplest example:

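$$q(\theta|\mu,\sigma) = \mathcal{N}(\mu,\sigma^2) \;\equiv\; z \sim \mathcal{N}(0,1),\; \theta = \mu + \sigma z$$

The gradient becomes

$$\frac{\partial L}{\partial \mu} = \frac{\partial L}{\partial \theta}\frac{\partial f(z,\lambda)}{\partial \mu} = \frac{\partial L}{\partial \theta}, \qquad \frac{\partial L}{\partial \sigma} = \frac{\partial L}{\partial \theta}\frac{\partial f(z,\lambda)}{\partial \sigma} = \frac{\partial L}{\partial \theta}\, z$$

and when computing the expectation we also need to take into account the change of volume under the transformation: since $q(\theta|\mu,\sigma) = \phi(z)/\sigma$,

$$\log\left|\frac{\partial f(z,\lambda)}{\partial z}\right| = \log\sigma,$$

so the entropy term $-\mathbb{E}_q[\log q(\theta|\lambda)] = -\mathbb{E}_\phi[\log\phi(z)] + \log\sigma$ contributes 0 to ∂L/∂µ and 1/σ to ∂L/∂σ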
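Putting the pieces together, here is a sketch of stochastic gradient ascent on L with the Gaussian reparameterization. It assumes a toy conjugate model θ ∼ N(0, 1), xi ∼ N(θ, 1) (not from the slides), whose exact posterior N(Σxi/(n+1), 1/(n+1)) is itself Gaussian, so the optimum should recover it. Storing log σ instead of σ (to keep σ positive) and averaging the later iterates are design choices of this sketch.

```python
import math
import random

DATA = [1.2, 0.8, 1.5, 0.9, 1.1]  # toy observations (illustrative assumption)


def grad_log_joint(theta, data):
    """dL/dtheta part: ∇_theta log p(D, theta) for theta ~ N(0, 1), x_i ~ N(theta, 1)."""
    return -theta + sum(x - theta for x in data)


def fit_gaussian_vi(data, steps=3000, num_samples=10, lr=0.05, seed=0):
    """Stochastic gradient ascent on L(mu, sigma) with theta = mu + sigma * z, z ~ N(0, 1).

    sigma is stored as omega = log(sigma). The entropy term of L is log(sigma) + const,
    contributing 0 to dL/dmu and 1/sigma to dL/dsigma. Iterates of the second half are averaged.
    """
    rng = random.Random(seed)
    mu, omega = 0.0, 0.0
    mu_sum, omega_sum, count = 0.0, 0.0, 0
    for step in range(steps):
        sigma = math.exp(omega)
        g_mu, g_sigma = 0.0, 0.0
        for _ in range(num_samples):
            z = rng.gauss(0.0, 1.0)
            g = grad_log_joint(mu + sigma * z, data)   # evaluated at theta = f(z, lambda)
            g_mu += g                                  # df/dmu = 1
            g_sigma += g * z                           # df/dsigma = z
        g_mu /= num_samples
        g_sigma = g_sigma / num_samples + 1.0 / sigma  # add the entropy gradient
        mu += lr * g_mu
        omega += lr * g_sigma * sigma                  # chain rule: dsigma/domega = sigma
        if step >= steps // 2:
            mu_sum, omega_sum, count = mu_sum + mu, omega_sum + omega, count + 1
    return mu_sum / count, math.exp(omega_sum / count)


if __name__ == "__main__":
    mu, sigma = fit_gaussian_vi(DATA)
    print("fitted", mu, sigma, "| exact", sum(DATA) / 6, 1 / math.sqrt(6))
```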


Reparameterization

Reparameterization can be done for:

All one-liner samplers (normal, exponential, Cauchy, ...)

Rejection samplers (gamma, beta, Dirichlet, ...) [Naesseth et al. AISTATS’17]

Chains of transformations (log-normal, inverse-gamma, ...)

Example of a chain: z ∼ N(0, 1), t = µ + σz, θ = e^t

Stan does variational inference by chaining reparameterization for the normal distribution with pre-defined transformations for handling constrained supports [Kucukelbir et al. 2015]
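A sketch of differentiating through the log-normal chain above, using E_q[θ] as the quantity to differentiate. That choice is an assumption made for checkability: E_q[θ] = exp(µ + σ²/2) in closed form, so the pathwise estimates should approach exp(µ + σ²/2) for ∂/∂µ and σ exp(µ + σ²/2) for ∂/∂σ.

```python
import math
import random


def lognormal_pathwise_grads(mu, sigma, num_samples, seed=0):
    """Pathwise estimates of d/dmu and d/dsigma of E_q[theta] for the chain
    z ~ N(0, 1), t = mu + sigma * z, theta = exp(t)."""
    rng = random.Random(seed)
    g_mu, g_sigma = 0.0, 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)
        t = mu + sigma * z
        theta = math.exp(t)
        # Chain rule through both steps: dtheta/dt = exp(t) = theta
        g_mu += theta * 1.0     # dt/dmu = 1
        g_sigma += theta * z    # dt/dsigma = z
    return g_mu / num_samples, g_sigma / num_samples


if __name__ == "__main__":
    g_mu, g_sigma = lognormal_pathwise_grads(0.5, 0.3, 100000)
    exact = math.exp(0.5 + 0.3 ** 2 / 2)
    print(g_mu, exact, g_sigma, 0.3 * exact)
```

When such a chain is used inside L, log q(θ|λ) also needs the log-Jacobian of the last step, log|dθ/dt| = t.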


Importance sampling estimator [Sakaya & Klami, UAI’17]

The reparameterization estimate is computed using the chain rule as

$$\frac{1}{M} \sum_{m=1}^{M} \nabla_\lambda \log p(D, f(z_m,\lambda)) = \frac{1}{M} \sum_{m=1}^{M} \frac{\partial f(z_m,\lambda)}{\partial \lambda}\, \nabla_\theta \log p(D,\theta_m)$$

Typically ∇θ log p(D,θm) is computationally expensive, requiring a loop over the data samples and derivatives of a potentially complex probabilistic program, whereas ∂f(zm,λ)/∂λ is easy


Importance sampling estimator

Let’s take a look at a single step during optimization:

Draw samples zm ∼ φ(z)

Convert them to θm = f(zm,λ)

Evaluate am = ∇θ log p(D,θm) and Bm = ∂f(zm,λ)/∂λ

Update λ towards (1/M) ∑m Bm am

It seems wasteful to keep recomputing ∇θ log p(D,θm), which does not actually depend on λ
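The step above, sketched for a location-scale approximation f(z, (µ, σ)) = µ + σz and a toy model θ ∼ N(0, 1), xi ∼ N(θ, 1) (both assumptions for illustration). It makes the cost split explicit: am needs a pass over the data, while Bm = (1, zm) does not. As in the estimator above, only the log p(D,θ) part of L is differentiated.

```python
import random

DATA = [1.2, 0.8, 1.5, 0.9, 1.1]  # toy observations (illustrative assumption)


def grad_log_joint(theta, data):
    """a = ∇_theta log p(D, theta): loops over all data points, the expensive part."""
    g = -theta                  # prior term of theta ~ N(0, 1)
    for x in data:
        g += x - theta          # likelihood terms of x_i ~ N(theta, 1)
    return g


def jacobian_f(z):
    """B = ∂f(z, lambda)/∂lambda for f(z, (mu, sigma)) = mu + sigma * z: cheap, no data needed."""
    return (1.0, z)


def standard_step_gradient(mu, sigma, data, num_samples, rng):
    """Draw z_m, map to theta_m, evaluate a_m and B_m, and average B_m a_m."""
    g_mu, g_sigma = 0.0, 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)            # z_m ~ phi(z)
        theta = mu + sigma * z             # theta_m = f(z_m, lambda)
        a = grad_log_joint(theta, data)    # a_m
        b_mu, b_sigma = jacobian_f(z)      # B_m
        g_mu += b_mu * a
        g_sigma += b_sigma * a
    return g_mu / num_samples, g_sigma / num_samples


if __name__ == "__main__":
    rng = random.Random(0)
    # For this toy model the exact values are 5.5 - 6 mu and -6 sigma
    print(standard_step_gradient(0.5, 0.6, DATA, 100000, rng))
```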


Importance sampling estimator

Alternative optimization step:

Assume am = ∇θ log p(D,θm) exists for some θm

Solve for zm = f⁻¹(θm,λ)

Evaluate Bm = ∂f(zm,λ)/∂λ

Update λ towards (1/M) ∑m Bm am

A lot less computation, but we are no longer doing Monte Carlo integration over q(θ|λ)


Importance sampling estimator

Alternative optimization step:

Assume am = ∇θ log p(D,θm) exists for some θm

...and they were drawn from some old q(θ|λo) with z^o_m

Solve for zm = f⁻¹(θm,λ)

Evaluate Bm = ∂f(zm,λ)/∂λ

Compute the importance sampling weight wm = φ(zm)/φ(z^o_m)

Update λ towards (1/M) ∑m wm Bm am

Fixed! Still a lot less computation, and now we use importance sampling to compute the expectation over q(θ|λ) using θm drawn from a proposal distribution q(θ|λo)
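A sketch of the re-weighted step, assuming q(θ|λ) = N(µ, σ²) and the toy model θ ∼ N(0, 1), xi ∼ N(θ, 1) (illustrative assumptions). Cached pairs (θm, am) drawn under λo are reused at a nearby λ without new gradient evaluations. The sketch computes wm as the full density ratio q(θm|λ)/q(θm|λo); for this location-scale family that equals φ(zm)/φ(z^o_m) times the Jacobian ratio σo/σ, so the two coincide when the scale is unchanged.

```python
import math
import random

DATA = [1.2, 0.8, 1.5, 0.9, 1.1]  # toy observations (illustrative assumption)


def grad_log_joint(theta, data):
    """a = ∇_theta log p(D, theta) for theta ~ N(0, 1), x_i ~ N(theta, 1)."""
    return -theta + sum(x - theta for x in data)


def log_q(theta, mu, sigma):
    """log N(theta | mu, sigma^2)."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * ((theta - mu) / sigma) ** 2


def draw_and_store(mu_old, sigma_old, data, num_samples, rng):
    """Draw theta_m from the old q(theta | lambda_o) and cache (theta_m, a_m)."""
    cache = []
    for _ in range(num_samples):
        theta = rng.gauss(mu_old, sigma_old)
        cache.append((theta, grad_log_joint(theta, data)))
    return cache


def reused_gradient(mu, sigma, mu_old, sigma_old, cache):
    """(1/M) Σ_m w_m B_m a_m from cached a_m: no new evaluations of ∇ log p."""
    g_mu, g_sigma = 0.0, 0.0
    for theta, a in cache:
        z = (theta - mu) / sigma                                   # z_m = f^{-1}(theta_m, lambda)
        w = math.exp(log_q(theta, mu, sigma) - log_q(theta, mu_old, sigma_old))
        g_mu += w * 1.0 * a                                        # B_m = (1, z_m)
        g_sigma += w * z * a
    return g_mu / len(cache), g_sigma / len(cache)


if __name__ == "__main__":
    rng = random.Random(0)
    cache = draw_and_store(0.8, 0.45, DATA, 100000, rng)
    # Exact values at the new lambda = (0.85, 0.42): 5.5 - 6 * 0.85 and -6 * 0.42
    print(reused_gradient(0.85, 0.42, 0.8, 0.45, cache))
```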


Importance-sampled SGD [Sakaya & Klami, UAI’17]

We can use importance-sampled gradient estimates in any gradient-based optimization algorithm

Whenever we draw θm, store the gradient ∇θ log p(D,θm) and also zm

When computing gradients, re-use the old ones but weight them

When θm is too unlikely under the current q(θ|λ), the weight tends to zero: re-sample θm and re-compute its gradient ∇θ log p(D,θm) as well

Can also be used with stochastic average gradients, providing a variant that handles stale gradients properly

Works also with score function estimators, but the speedup is much smaller
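A sketch of importance-sampled SGD on the full bound L, assuming q(θ|λ) = N(µ, σ²) and the toy model θ ∼ N(0, 1), xi ∼ N(θ, 1) (illustrative, not from the paper). As a simplification of the per-sample re-sampling described above, this sketch redraws the whole cached batch when the effective sample size of the weights falls below half the batch size, so that all weights stay relative to a single proposal λo. The function also counts how many gradients of log p(D,θ) it evaluates.

```python
import math
import random

DATA = [1.2, 0.8, 1.5, 0.9, 1.1]  # toy observations (illustrative assumption)


def grad_log_joint(theta, data):
    """a = ∇_theta log p(D, theta) for theta ~ N(0, 1), x_i ~ N(theta, 1): the expensive part."""
    return -theta + sum(x - theta for x in data)


def log_q(theta, mu, sigma):
    """log N(theta | mu, sigma^2)."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * ((theta - mu) / sigma) ** 2


def is_sgd(data, num_samples=400, steps=800, lr=0.05, min_ess_frac=0.5, seed=0):
    """Importance-sampled SGD for q = N(mu, sigma^2) with sigma = exp(omega).

    Returns the averaged (mu, sigma) and the number of ∇ log p(D, theta) evaluations.
    """
    rng = random.Random(seed)
    mu, omega = 0.0, 0.0
    evaluations = 0

    def redraw(mu, sigma):
        nonlocal evaluations
        evaluations += num_samples
        thetas = [rng.gauss(mu, sigma) for _ in range(num_samples)]
        return thetas, [grad_log_joint(t, data) for t in thetas], (mu, sigma)

    thetas, grads, (mu_o, sigma_o) = redraw(mu, math.exp(omega))
    mu_sum = omega_sum = 0.0
    count = 0
    for step in range(steps):
        sigma = math.exp(omega)
        ws = [math.exp(log_q(t, mu, sigma) - log_q(t, mu_o, sigma_o)) for t in thetas]
        ess = sum(ws) ** 2 / sum(w * w for w in ws)
        if ess < min_ess_frac * num_samples:
            # Weights have degenerated: draw from the current q and recompute the gradients
            thetas, grads, (mu_o, sigma_o) = redraw(mu, sigma)
            ws = [1.0] * num_samples
        g_mu = g_sigma = 0.0
        for t, a, w in zip(thetas, grads, ws):
            z = (t - mu) / sigma            # z_m = f^{-1}(theta_m, lambda)
            g_mu += w * a                   # B_m = 1 for mu
            g_sigma += w * z * a            # B_m = z_m for sigma
        g_mu /= num_samples
        g_sigma = g_sigma / num_samples + 1.0 / sigma   # entropy term log(sigma)
        mu += lr * g_mu
        omega += lr * sigma * g_sigma                   # dsigma/domega = sigma
        if step >= steps // 2:
            mu_sum += mu
            omega_sum += omega
            count += 1
    return mu_sum / count, math.exp(omega_sum / count), evaluations


if __name__ == "__main__":
    mu, sigma, evals = is_sgd(DATA)
    print(mu, sigma, "gradient evaluations:", evals, "vs", 800 * 400)
```

Comparing the returned count with steps × num_samples shows how much gradient computation the reuse saves on this toy problem.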


Importance sampling estimator

An order of magnitude speedup compared to re-computing the gradients every time

...but only when the approximation factors into terms of sufficiently low dimensionality: in higher dimensions even small changes in λ drive the weights to zero. Note that log p(·) does not need to factorize, so we only make assumptions on the approximation.
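A quick numerical illustration of the dimensionality issue, assuming a factorized N(0, I) proposal and a target that shifts every mean by the same small amount δ = 0.1 (both assumptions). For this Gaussian case E[w²] = exp(Dδ²), so the effective sample size fraction behaves roughly like exp(−Dδ²) and collapses as the dimensionality D grows.

```python
import math
import random


def ess_fraction(dim, shift, num_samples=1000, seed=0):
    """Effective sample size / M of weights q(theta | lambda) / q(theta | lambda_o).

    Proposal q(theta | lambda_o) = N(0, I); target q(theta | lambda) = N(shift * 1, I), both
    factorized over dim coordinates, so the log-weight is a sum of per-coordinate terms.
    """
    rng = random.Random(seed)
    log_ws = []
    for _ in range(num_samples):
        log_w = 0.0
        for _ in range(dim):
            z = rng.gauss(0.0, 1.0)                  # one coordinate of theta ~ N(0, I)
            log_w += shift * z - 0.5 * shift ** 2    # log N(z | shift, 1) - log N(z | 0, 1)
        log_ws.append(log_w)
    top = max(log_ws)                                # subtract the max for numerical stability
    ws = [math.exp(lw - top) for lw in log_ws]
    return sum(ws) ** 2 / (num_samples * sum(w * w for w in ws))


if __name__ == "__main__":
    for dim in (1, 10, 100, 1000):
        print(dim, ess_fraction(dim, 0.1), math.exp(-dim * 0.1 ** 2))
```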


Relationship with deep learning

Variational autoencoders [Kingma et al. ICLR’13] are simple latent variable models that use reparameterization gradients

Model:

zn ∼ N(0, I)

εn ∼ N(0, σ²)

xn = f(zn, η) + εn

Approximation: q(zn|λ) = N(gµ(xn, λµ), gΣ(xn, λΣ))

Here f(·) and g(·) are neural networks

Amortized inference: q(z|λ) is parameterized as a function of the inputs
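A scalar sketch of amortized inference, with hypothetical linear functions standing in for the networks f(·) and g(·): one encoder parameter set serves every xn, and the bound for a data point is estimated with a single reparameterized sample. For this linear-Gaussian model the exact posterior is Gaussian with a mean linear in x, so an encoder exists for which every sample of log p(x, z) − log q(z|x) equals log p(x).

```python
import math
import random

LOG_2PI = math.log(2.0 * math.pi)


def log_normal(x, mean, std):
    """log N(x | mean, std^2)."""
    return -0.5 * LOG_2PI - math.log(std) - 0.5 * ((x - mean) / std) ** 2


def decoder_mean(z, w, b):
    """Hypothetical linear stand-in for the decoder network f(z, eta)."""
    return w * z + b


def encoder(x, enc):
    """Hypothetical linear stand-in for g_mu and g_Sigma; enc is shared by all data points."""
    slope, offset, log_std = enc
    return slope * x + offset, math.exp(log_std)


def elbo_single_sample(x, enc, w, b, s, rng):
    """One reparameterized sample of log p(x, z) - log q(z | x), with z = mean + std * eps."""
    mean, std = encoder(x, enc)
    z = mean + std * rng.gauss(0.0, 1.0)
    log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, decoder_mean(z, w, b), s)
    return log_joint - log_normal(z, mean, std)


def exact_encoder(w, b, s):
    """Encoder parameters reproducing the true posterior p(z | x) of the linear-Gaussian model."""
    precision = 1.0 + w ** 2 / s ** 2
    slope = (w / s ** 2) / precision
    return slope, -slope * b, -0.5 * math.log(precision)


def log_marginal(x, w, b, s):
    """Exact log p(x): x ~ N(b, w^2 + s^2)."""
    return log_normal(x, b, math.sqrt(w ** 2 + s ** 2))
```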


Relationship with deep learning

A VAE is hence a probabilistic program, but one where we search for point estimates of the neural network parameters

The variational approximation itself is rather naive and cannot model complex posteriors: even though the mean and covariance are nonlinear functions of the input, the distribution itself is still normal

Normalizing flows [Rezende et al., ICML’15] are a more serious attempt at building flexible approximations: use a neural network as a chain of reparameterization transformations
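A one-dimensional sketch of the idea, using the planar transformation f(z) = z + u tanh(wz + b) of Rezende & Mohamed as each step of the chain; the parameter values are arbitrary. Each step adds its log-Jacobian to the change of volume, so log q(θ) stays computable for every sample.

```python
import math
import random


def planar_1d(z, u, w, b):
    """1-D planar step f(z) = z + u * tanh(w z + b) and log |df/dz|.

    df/dz = 1 + u * w * (1 - tanh(w z + b)^2) stays positive when u * w > -1 (invertible).
    """
    t = math.tanh(w * z + b)
    return z + u * t, math.log(abs(1.0 + u * w * (1.0 - t * t)))


def flow_sample(params, rng):
    """Push z ~ N(0, 1) through a chain of planar steps; return theta and log q(theta)."""
    z = rng.gauss(0.0, 1.0)
    log_q = -0.5 * math.log(2.0 * math.pi) - 0.5 * z * z   # log phi(z)
    for u, w, b in params:
        z, log_det = planar_1d(z, u, w, b)
        log_q -= log_det                                     # change of volume of this step
    return z, log_q


if __name__ == "__main__":
    params = [(0.8, 1.5, 0.2), (-0.4, 2.0, -0.5), (1.2, 0.7, 0.0)]
    print(flow_sample(params, random.Random(0)))
```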


Take home messages

Variational inference over arbitrary probabilistic programs is possible, but besides the model we need to somehow specify the approximation as well

Can be done with the same basic tools as deep learning, since all we need is automatic differentiation, sampling from standard distributions and SGD

Consequently easy to merge with DL models as well, but typically we still take point estimates over the network parameters


References

1 Blei, Kucukelbir, McAuliffe. Variational inference: A review for statisticians, 2016.

2 Ranganath, Gerrish, Blei. Black box variational inference, AISTATS, 2014.

3 Kingma, Welling. Auto-encoding variational Bayes, ICLR, 2013.

4 Titsias, Lazaro-Gredilla. Doubly stochastic variational Bayes for non-conjugate inference, ICML, 2014.

5 Naesseth, Ruiz, Linderman, Blei. Reparameterization gradients through acceptance-rejection sampling algorithms, AISTATS, 2017.

6 Kucukelbir, Ranganath, Gelman, Blei. Automatic variational inference in Stan, 2015.

7 Sakaya, Klami. Importance sampled stochastic optimization for variational inference, UAI, 2017.

8 Rezende, Mohamed. Variational inference with normalizing flows, ICML, 2015.
