Probabilistic programming and optimization
Arto Klami
March 29, 2018
Arto Klami Probabilistic programming and optimization March 29, 2018 1 / 23
Bayesian inference
Making predictions under uncertainty:
$$p(\tilde{x} \mid \mathcal{D}) = \int_\theta p(\tilde{x} \mid \theta)\, p(\theta \mid \mathcal{D})\, d\theta$$

Markov chain Monte Carlo approximates this with

$$p(\tilde{x} \mid \mathcal{D}) \approx \frac{1}{M} \sum_{m=1}^{M} p(\tilde{x} \mid \theta_m)$$

for θ_m drawn from p(θ|D) using some algorithm that hopefully produces good enough samples
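As a minimal sketch (not from the slides), the Monte Carlo predictive average can be written out for a toy Gaussian model; `predictive_density` and the posterior samples below are hypothetical stand-ins for a real model and a real MCMC run.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical posterior samples theta_m ~ p(theta | D); in practice these
# would come from an MCMC algorithm. Here: a posterior over a Gaussian mean.
theta_samples = rng.normal(loc=1.0, scale=0.2, size=1000)

def predictive_density(x_new, theta):
    # p(x_new | theta): Gaussian likelihood with unit variance (assumed model)
    return np.exp(-0.5 * (x_new - theta) ** 2) / np.sqrt(2 * np.pi)

# Monte Carlo estimate: p(x_new | D) ~= (1/M) sum_m p(x_new | theta_m)
x_new = 1.0
p_hat = predictive_density(x_new, theta_samples).mean()
```

The quality of `p_hat` depends entirely on how well the samples represent p(θ|D), which is the "hopefully produces good enough samples" caveat above.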
Bayesian inference using optimization
We can convert the search for the posterior distribution into an optimization problem as well

1. Choose some parametric family of distributions q(θ|λ)
2. Find the q(θ|λ) that is close to p(θ|D), by minimizing some dissimilarity measure with respect to λ
Bayesian inference using optimization
Why? Because we can...
...or because we care about
computational efficiency
deterministic solution
interpretability
...
Variational approximation
A prototypical example is variational approximation, where we use the Kullback-Leibler divergence between q(θ|λ) and p(θ|D) as the loss function

Equivalent formulation: Maximize a lower bound L(λ) for the log marginal likelihood log p(D):

$$\log p(\mathcal{D}) = \mathbb{E}_{q(\theta\mid\lambda)}\!\left[\log \frac{p(\mathcal{D},\theta)}{q(\theta\mid\lambda)}\right] + \mathrm{KL}(q(\theta\mid\lambda)\,\|\,p(\theta\mid\mathcal{D})) \geq \mathbb{E}_{q(\theta\mid\lambda)}\!\left[\log \frac{p(\mathcal{D},\theta)}{q(\theta\mid\lambda)}\right] = \mathcal{L}(\lambda)$$
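The decomposition follows in one step from Bayes' rule, p(θ|D) = p(D,θ)/p(D); spelled out:

```latex
\log p(\mathcal{D})
  = \mathbb{E}_{q(\theta|\lambda)}\!\left[\log \frac{p(\mathcal{D},\theta)}{p(\theta|\mathcal{D})}\right]
  = \underbrace{\mathbb{E}_{q}\!\left[\log \frac{p(\mathcal{D},\theta)}{q(\theta|\lambda)}\right]}_{\mathcal{L}(\lambda)}
  + \underbrace{\mathbb{E}_{q}\!\left[\log \frac{q(\theta|\lambda)}{p(\theta|\mathcal{D})}\right]}_{\mathrm{KL}(q\,\|\,p)\,\geq\,0}
```

Since the KL term is non-negative, maximizing L(λ) tightens the bound, and the gap is exactly the divergence we set out to minimize.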
Variational approximation [pre 2014]
Mean-field approximation $q(\theta) = \prod_{d=1}^{D} q_d(\theta_d)$ leads to a very elegant coordinate ascent algorithm with

$$q_d(\theta_d) \propto e^{\mathbb{E}_{q_{-d}(\theta)}[\log p(\mathcal{D},\theta)]}$$

For conjugate models we can compute the expectations analytically! See e.g. Blei et al. (2016)

...but it is often quite tedious and error-prone, and extending beyond conjugate models is difficult
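As a concrete illustration of such updates, here is a minimal coordinate ascent sketch for a toy conjugate model: a Gaussian with unknown mean µ and precision τ under a Normal-Gamma prior, with mean-field q(µ)q(τ). The update formulas are the standard mean-field results for this model; the variable names are ours, not from the slides.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(loc=2.0, scale=1.0, size=200)    # observed data D
N, xbar = len(x), x.mean()

# Normal-Gamma prior: mu ~ N(mu0, (lam0 * tau)^-1), tau ~ Gamma(a0, b0)
mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0

# Mean-field q(mu) q(tau): q(mu) = N(mu_N, 1/lam_N), q(tau) = Gamma(a_N, b_N)
mu_N = (lam0 * mu0 + N * xbar) / (lam0 + N)     # closed form, independent of E[tau]
a_N = a0 + (N + 1) / 2                          # also available in closed form
E_tau = a0 / b0                                 # initialize E_q[tau]
for _ in range(50):                             # coordinate ascent to a fixed point
    lam_N = (lam0 + N) * E_tau                  # update q(mu)
    b_N = b0 + 0.5 * (np.sum((x - mu_N) ** 2) + N / lam_N
                      + lam0 * ((mu_N - mu0) ** 2 + 1.0 / lam_N))  # update q(tau)
    E_tau = a_N / b_N
```

Each update holds the other factor fixed and takes the expectation of log p(D,θ) under it, exactly the recipe above; the tedium of deriving such expectations by hand is what the rest of the talk avoids.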
Variational approximation in probabilistic programming
Can we do this for arbitrary models and approximations, without

(a) making the mean-field assumption
(b) assuming conjugate models
(c) ...or even knowledge of the model when writing the inference algorithm?

Yes, and under rather mild conditions. Besides being able to evaluate log p(D,θ) we need to assume either of the following:

1. The approximation q(θ|λ) is differentiable with respect to λ
2. The model p(D,θ) is differentiable with respect to θ and the approximation can be reparameterized

Think of Metropolis-Hastings vs Hamiltonian Monte Carlo: assume only evaluation of log p(·) or also its gradients
Variational approximation in probabilistic programming
Monte Carlo approximation for the bound itself is easy using

$$\mathcal{L} \approx \frac{1}{M} \sum_{m=1}^{M} \log \frac{p(\mathcal{D},\theta_m)}{q(\theta_m\mid\lambda)},$$

where θ_m are drawn from the approximation

For optimization we would typically want approximations of the gradient, to be used for (stochastic) gradient ascent

$$\lambda \leftarrow \lambda + \alpha \nabla_\lambda \mathcal{L}$$

Automatic differentiation can be used for both log q(·) and log p(·), but the challenge is in handling the expectation
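The Monte Carlo bound estimate itself takes only a few lines; here is a hedged sketch for a toy conjugate model (Gaussian mean with a standard normal prior), with `log_joint` and `elbo_estimate` as our own illustrative names.

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(loc=1.0, scale=1.0, size=100)

def log_joint(theta):
    # log p(D, theta) for a toy model: theta ~ N(0,1), x_i ~ N(theta, 1)
    log_prior = -0.5 * theta ** 2 - 0.5 * np.log(2 * np.pi)
    log_lik = np.sum(-0.5 * (data[None, :] - theta[:, None]) ** 2
                     - 0.5 * np.log(2 * np.pi), axis=1)
    return log_prior + log_lik

def elbo_estimate(mu, sigma, M=2000):
    # L ~= (1/M) sum_m [log p(D, theta_m) - log q(theta_m | lambda)]
    theta = rng.normal(mu, sigma, size=M)       # theta_m ~ q(theta | mu, sigma)
    log_q = (-0.5 * ((theta - mu) / sigma) ** 2
             - np.log(sigma) - 0.5 * np.log(2 * np.pi))
    return np.mean(log_joint(theta) - log_q)
```

An approximation centered near the data produces a visibly higher bound than one centered far away, which is what the optimizer exploits.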
Black-box VI with score function estimators [Ranganath et al., AISTATS'14]

Since ∇q = q∇ log q, simple algebraic manipulation gives the score function estimator

$$\nabla_\lambda \mathbb{E}_{q(\theta\mid\lambda)}[\log p(\mathcal{D},\theta)] = \mathbb{E}_{q(\theta\mid\lambda)}[\nabla_\lambda \log q(\theta\mid\lambda)\, \log p(\mathcal{D},\theta)] \approx \frac{1}{M} \sum_{m=1}^{M} \nabla_\lambda \log q(\theta_m\mid\lambda)\, \log p(\mathcal{D},\theta_m),$$

where θ_m are again drawn from q(θ|λ)

Practical problem: very high variance. Control variates and other variance reduction techniques help, but only to a degree
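A minimal numerical sketch of the estimator (toy target of our choosing, not from the slides): for q = N(µ, σ²), the score with respect to µ is ∇_µ log q(θ) = (θ − µ)/σ².

```python
import numpy as np

rng = np.random.default_rng(0)

def log_p(theta):
    # Toy unnormalized target: log p(D, theta) = -0.5 * (theta - 2)^2 (assumed)
    return -0.5 * (theta - 2.0) ** 2

def score_grad(mu, sigma, M=5000):
    # E_q[ grad_mu log q(theta | mu, sigma) * log p(D, theta) ]
    theta = rng.normal(mu, sigma, size=M)
    score_mu = (theta - mu) / sigma ** 2   # score of N(mu, sigma^2) w.r.t. mu
    return np.mean(score_mu * log_p(theta))

g = score_grad(mu=0.0, sigma=1.0)
# The exact gradient of E_q[log p] w.r.t. mu at (0, 1) is 2.0; even with
# M=5000 samples the estimate scatters noticeably around that value.
```

Note that only evaluations of log p(·) are needed, never its gradient, which is what makes the estimator "black box".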
Reparameterization [Kingma et al., ICLR'13; Titsias & Lazaro-Gredilla, ICML'14]

Assume q(θ|λ) can be written as a system

$$z \sim \phi(z), \qquad \theta = f(z,\lambda)$$

where φ(z) is some simple distribution that does not depend on λ

Then $\mathcal{L} = \mathbb{E}_q[\log p(\mathcal{D},\theta)] = \mathbb{E}_\phi[\log p(\mathcal{D}, f(z,\lambda))]$ leads to the reparameterization estimator

$$\nabla_\lambda \mathcal{L} = \mathbb{E}_\phi[\nabla_\lambda \log p(\mathcal{D}, f(z,\lambda))] \approx \frac{1}{M} \sum_{m=1}^{M} \nabla_\lambda \log p(\mathcal{D}, f(z_m,\lambda))$$

for z_m ∼ φ(z)

Much lower variance, but requires propagating derivatives through log p(·) and f(·)
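For comparison, a sketch of the reparameterization estimator on the same kind of toy target (our illustrative choice: log p ∝ −0.5(θ−2)², so ∇_θ log p = −(θ−2)); with f(z,λ) = µ + σz the chain rule is written out by hand.

```python
import numpy as np

rng = np.random.default_rng(0)

def grad_log_p(theta):
    # grad_theta log p(D, theta) for the toy target log p = -0.5 * (theta - 2)^2
    return -(theta - 2.0)

def reparam_grad(mu, sigma, M=200):
    # theta = mu + sigma * z with z ~ N(0,1); the chain rule gives
    # grad_mu L = E[grad_theta * 1], grad_sigma L = E[grad_theta * z]
    z = rng.normal(size=M)
    theta = mu + sigma * z
    g = grad_log_p(theta)
    return np.mean(g), np.mean(g * z)

g_mu, g_sigma = reparam_grad(mu=0.0, sigma=1.0)
```

Even with only M = 200 samples (versus thousands for the score function version) the estimates sit close to the exact values, illustrating the variance gap between the two estimators.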
Reparameterization
Simplest example:

$$q(\theta\mid\mu,\sigma) = \mathcal{N}(\mu,\sigma^2) \quad\equiv\quad z \sim \mathcal{N}(0,1),\; \theta = \mu + \sigma z$$

The gradient becomes

$$\frac{\partial \mathcal{L}}{\partial \mu} = \frac{\partial \mathcal{L}}{\partial \theta} \frac{\partial f(z,\lambda)}{\partial \mu} = \frac{\partial \mathcal{L}}{\partial \theta}, \qquad \frac{\partial \mathcal{L}}{\partial \sigma} = \frac{\partial \mathcal{L}}{\partial \theta} \frac{\partial f(z,\lambda)}{\partial \sigma} = \frac{\partial \mathcal{L}}{\partial \theta}\, z$$

and when computing the expectation we need to also take into account the change of volume under the transformation

$$\log\left|\frac{\partial f^{-1}(\theta)}{\partial \mu}\right| = 0, \qquad \log\left|\frac{\partial f^{-1}(\theta)}{\partial \sigma}\right| = \log \sigma^2$$
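The two partial derivatives of f(z, λ) = µ + σz can be checked numerically with central differences; a tiny sanity-check sketch (all values arbitrary):

```python
import numpy as np

# For theta = f(z, (mu, sigma)) = mu + sigma * z, verify numerically that
# d theta / d mu = 1 and d theta / d sigma = z.
def f(z, mu, sigma):
    return mu + sigma * z

z, mu, sigma, eps = 0.7, 1.5, 2.0, 1e-6
d_mu = (f(z, mu + eps, sigma) - f(z, mu - eps, sigma)) / (2 * eps)
d_sigma = (f(z, mu, sigma + eps) - f(z, mu, sigma - eps)) / (2 * eps)
```

In practice an autodiff framework computes exactly these partials, which is why reparameterization fits probabilistic programming so naturally.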
Reparameterization
Reparameterization can be done for:

All one-liner samplers (normal, exponential, Cauchy, ...)
Rejection samplers (gamma, beta, Dirichlet, ...) [Naesseth et al. AISTATS'17]
Chains of transformations (log-normal, inv-gamma, ...)

Example of a chain: z ∼ N(0,1), t = µ + σz, θ = e^t

Stan does variational inference by chaining reparameterization for the normal distribution with pre-defined transformations for handling finite supports [Kucukelbir et al. 2015]
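The log-normal chain above is short enough to write out directly; a sketch that also checks the samples against the known log-normal mean e^{µ+σ²/2}:

```python
import numpy as np

rng = np.random.default_rng(0)

# Log-normal as a chain of transformations: z ~ N(0,1), t = mu + sigma*z,
# theta = exp(t). Every step is differentiable w.r.t. (mu, sigma), so the
# whole chain is reparameterizable.
mu, sigma = 0.5, 0.3
z = rng.normal(size=100_000)
theta = np.exp(mu + sigma * z)

analytic_mean = np.exp(mu + 0.5 * sigma ** 2)   # known log-normal mean
```

Transformations like exp here (or softplus, logistic, etc.) are also what maps an unconstrained normal onto a constrained support, which is the pre-defined-transformation trick used for finite supports.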
Importance sampling estimator [Sakaya & Klami, UAI'17]

The reparameterization estimate is computed using the chain rule as

$$\frac{1}{M} \sum_{m=1}^{M} \nabla_\lambda \log p(\mathcal{D}, f(z_m,\lambda)) = \frac{1}{M} \sum_{m=1}^{M} \frac{\partial f(z_m,\lambda)}{\partial \lambda} \nabla_\theta \log p(\mathcal{D},\theta_m)$$

Typically ∇_θ log p(D,θ_m) is computationally expensive, requiring looping over data samples and computing derivatives over a potentially complex probabilistic program, whereas ∂f(z_m,λ)/∂λ is easy
Importance sampling estimator
Let's take a look at a single step during optimization:

Draw samples z_m ∼ φ(z)
Convert them to θ_m = f(z_m,λ)
Evaluate a_m = ∇_θ log p(D,θ_m) and B_m = ∂f(z_m,λ)/∂λ
Update λ towards (1/M) Σ_m B_m a_m

It seems wasteful to keep on computing ∇_θ log p(D,θ_m), which does not actually depend on λ
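The steps above can be sketched as a plain stochastic gradient ascent loop. This is a hedged toy version: `grad_log_p` stands in for the expensive model gradient, and for brevity the loop ascends only E_q[log p] (the entropy term of the bound is omitted), so σ simply shrinks while µ moves to the mode at 2.

```python
import numpy as np

rng = np.random.default_rng(0)

def grad_log_p(theta):
    # Hypothetical model gradient (the expensive a_m in practice)
    return -(theta - 2.0)

mu, sigma, alpha, M = 0.0, 1.0, 0.1, 64
for step in range(200):
    z = rng.normal(size=M)            # draw z_m ~ phi(z)
    theta = mu + sigma * z            # theta_m = f(z_m, lambda)
    a = grad_log_p(theta)             # a_m = grad_theta log p(D, theta_m)
    # B_m = d f / d lambda = (1, z_m) for lambda = (mu, sigma)
    mu += alpha * np.mean(a)          # update towards (1/M) sum_m B_m a_m
    sigma += alpha * np.mean(a * z)
    sigma = max(sigma, 1e-3)          # keep the scale positive
```

Note that every iteration re-evaluates `grad_log_p`, the expensive step the importance sampling estimator avoids.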
Importance sampling estimator
Alternative optimization step:

Assume a_m = ∇_θ log p(D,θ_m) exists for some θ_m
Solve for z_m = f^{-1}(θ_m,λ)
Evaluate B_m = ∂f(z_m,λ)/∂λ
Update λ towards (1/M) Σ_m B_m a_m

A lot less computation, but we are no longer doing Monte Carlo integration over q(θ|λ)
Importance sampling estimator
Alternative optimization step:

Assume a_m = ∇_θ log p(D,θ_m) exists for some θ_m
...and that they were drawn from some old q(θ|λ_o) with z_m^o
Solve for z_m = f^{-1}(θ_m,λ)
Evaluate B_m = ∂f(z_m,λ)/∂λ
Compute the importance sampling weight w_m = φ(z_m)/φ(z_m^o)
Update λ towards (1/M) Σ_m w_m B_m a_m

Fixed! Still a lot less computation, and now we use importance sampling to compute the expectation over q(θ|λ) using θ_m drawn from a proposal distribution q(θ|λ_o)
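A minimal sketch of re-using stored gradients with these weights, for the Gaussian family θ = µ + σz with σ held fixed so that the weight reduces to exactly φ(z_m)/φ(z_m^o); the target `grad_log_p` is again our toy stand-in.

```python
import numpy as np

rng = np.random.default_rng(0)

def grad_log_p(theta):
    # Hypothetical expensive model gradient, computed once and stored
    return -(theta - 2.0)

def phi(z):
    # Standard normal density
    return np.exp(-0.5 * z ** 2) / np.sqrt(2 * np.pi)

# Gradients a_m evaluated earlier under old parameters (mu_o, sigma_o)
mu_o, sigma_o = 0.0, 1.0
z_old = rng.normal(size=1000)
theta_old = mu_o + sigma_o * z_old
a = grad_log_p(theta_old)                # stored a_m, NOT recomputed below

# Re-use them after a small move of the parameters (mu, sigma)
mu, sigma = 0.1, 1.0
z_new = (theta_old - mu) / sigma         # z_m = f^{-1}(theta_m, lambda)
w = phi(z_new) / phi(z_old)              # importance weights w_m
grad_mu = np.mean(w * a)                 # B_m = 1 for the mu component
```

For this target the exact gradient at µ = 0.1 is 1.9, and the weighted re-used gradients recover it without a single new call to `grad_log_p`.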
Importance-sampled SGD [Sakaya & Klami, UAI'17]

We can use importance-sampled gradient estimates in any gradient-based optimization algorithm

Whenever we draw θ, store the gradient ∇_θ log p(D,θ) and also z_m
When computing gradients, re-use the old ones but weight them
When θ_m is too unlikely under the current q(θ|λ) the weight tends to zero: re-sample θ and re-compute the gradient of L as well
Can also be used with stochastic average gradients, providing a variant that handles stale gradients properly
Works also with score function estimators, but the speedup is way smaller
Importance sampling estimator
Order of magnitude speedup compared to re-computing the gradients every time

...but only when the approximation factorizes into terms of sufficiently low dimensionality – in higher dimensions already small changes in λ drive the weights to zero. Note that log p(·) does not need to factorize, so we only make assumptions about the approximation.
Relationship with deep learning
Variational autoencoders [Kingma et al. ICLR'13] are simple latent variable models that use reparameterization gradients

Model:

$$z_n \sim \mathcal{N}(0, I), \qquad \varepsilon_n \sim \mathcal{N}(0, \sigma^2), \qquad x_n = f(z_n, \eta) + \varepsilon_n$$

Approximation:

$$q(z_n\mid\lambda) = \mathcal{N}(g_\mu(x_n,\lambda_\mu),\, g_\Sigma(x_n,\lambda_\Sigma))$$

Here f(·) and g(·) are neural networks

Amortized inference: when q(z|λ) is parameterized using a function of the inputs
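The amortization idea can be sketched in a few lines: instead of separate variational parameters per data point, one shared function maps each x_n to the parameters of its own q(z_n). The linear "encoder" and its weights below are hypothetical placeholders for a real neural network.

```python
import numpy as np

rng = np.random.default_rng(0)

# A minimal amortized encoder: a shared linear map produces per-data-point
# variational parameters. The weights are arbitrary illustrative values.
W_mu, b_mu = 0.8, 0.0
W_logs, b_logs = 0.1, -1.0

x = rng.normal(size=(5,))                 # five observed data points
mu = W_mu * x + b_mu                      # g_mu(x_n, lambda_mu)
sigma = np.exp(W_logs * x + b_logs)       # g_Sigma via log-scale for positivity

# Reparameterized sample of z_n for every data point at once
eps = rng.normal(size=x.shape)
z = mu + sigma * eps
```

The payoff is that learning the shared weights once amortizes inference over all data points, including unseen ones.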
Relationship with deep learning
A VAE is hence a probabilistic program, but one that searches for point estimates of the neural network parameters

The variational approximation itself is rather naive and cannot model complex posteriors: even though the mean and covariance are nonlinear functions, the distribution itself is still normal

Normalizing flows [Rezende et al., ICML'15] are a more serious attempt at building flexible approximations: use a neural network as a chain of reparameterization transformations
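One layer of such a chain, the planar flow from Rezende & Mohamed, can be sketched directly: f(z) = z + u·tanh(wᵀz + b), with the change of volume log|1 + uᵀψ(z)| where ψ(z) = (1 − tanh²(wᵀz + b))·w. The parameter values below are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

# One planar flow layer: f(z) = z + u * tanh(w . z + b)
u = np.array([0.5, -0.3])
w = np.array([1.0, 0.5])
b = 0.1

def planar_flow(z):
    a = np.tanh(z @ w + b)
    f = z + u * a[:, None]
    # log |det df/dz| = log |1 + u . psi(z)|, psi(z) = (1 - tanh^2) * w
    psi = (1 - a ** 2)[:, None] * w
    log_det = np.log(np.abs(1 + psi @ u))
    return f, log_det

z0 = rng.normal(size=(1000, 2))           # samples from the base density
z1, log_det = planar_flow(z0)
# By change of variables: log q1(z1) = log q0(z0) - log_det
```

Stacking several such layers, each itself a reparameterization transformation, bends the simple base density into a much more flexible approximation while keeping the density tractable.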
Take home messages
Variational inference over arbitrary probabilistic programs is possible, but besides the model we need to somehow specify the approximation as well
Can be done with the same basic tools as deep learning, since all we need is automatic differentiation, sampling from standard distributions, and SGD
Consequently easy to merge with DL models as well, but typically we still take point estimates of the network parameters
References
1. Blei, Kucukelbir, McAuliffe. Variational inference: A review for statisticians, 2016.
2. Ranganath, Gerrish, Blei. Black box variational inference, AISTATS, 2014.
3. Kingma, Welling. Auto-encoding variational Bayes, ICLR, 2013.
4. Titsias, Lazaro-Gredilla. Doubly stochastic variational Bayes for non-conjugate inference, ICML, 2014.
5. Naesseth, Ruiz, Linderman, Blei. Reparameterization gradients through acceptance-rejection sampling algorithms, AISTATS, 2017.
6. Kucukelbir, Ranganath, Gelman, Blei. Automatic variational inference in Stan, 2015.
7. Sakaya, Klami. Importance sampled stochastic optimization for variational inference, UAI, 2017.
8. Rezende, Mohamed. Variational inference with normalizing flows, ICML, 2015.