Counterfactual Risk Minimization


Counterfactual Risk Minimization: Learning from Logged Bandit Feedback

Adith Swaminathan, Thorsten Joachims

Software: http://www.cs.cornell.edu/~adith/poem/

Learning frameworks

Learning h: x ↦ y from different kinds of feedback:

                     Online         Batch
  Full information   Perceptron, …  SVM, …
  Bandit feedback    LinUCB, …      ?

Logged bandit feedback is everywhere!

Example interaction: Alice wants tech news → the system shows "SpaceX launch" → Alice: "Cool!"

  x_i: Query        e.g., x_1 = (Alice, tech)
  y_i: Prediction   e.g., y_1 = (SpaceX)
  δ_i: Feedback     e.g., δ_1 = 1

Goal

• Risk of h: X ↦ Y:  R(h) = E_x[δ(x, h(x))]
• Find h* ∈ H with minimum risk
• Can we find h* using data D collected from a logging system h_0?

  D: x_1 = (Alice, tech),   y_1 = (SpaceX),  δ_1 = 1
     x_2 = (Alice, sports), y_2 = (F1),      δ_2 = 5
     …

Learning by replaying logs?

• Training/evaluation from logged data is counterfactual [Bottou et al.]
• The log records feedback only for the y_i that h_0 actually showed. If a new system h predicts y_1' = (Apple watch) for x_1 = (Alice, tech), what would Alice do? The log cannot say.
• x ~ Pr(X), however y = h_0(x), not h(x)

Stochastic policies to the rescue!

• Stochastic h: X ↦ Δ(Y), with y ~ h(x)
• R(h) = E_x E_{y~h(x)}[δ(x, y)]

[Figure: for the same output y ∈ Y (e.g., "SpaceX launch"), h_0 and h can assign very different probabilities, making it likely under one policy and unlikely under the other.]

Counterfactual risk estimators

• D = {(x_1, y_1, δ_1, p_1), (x_2, y_2, δ_2, p_2), …, (x_n, y_n, δ_n, p_n)}
• p_i = h_0(y_i | x_i) … the propensity [Rosenbaum et al.]

  R̂_D(h) = (1/n) Σ_{i=1}^n δ_i · h(y_i | x_i) / p_i

Basic importance sampling [Owen]:

  E_x E_{y~h}[δ(x, y)] = E_x E_{y~h_0}[δ(x, y) · h(y|x) / h_0(y|x)]

(performance of the new system = samples from the old system, reweighted by importance weights)
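To make the estimator concrete, here is a minimal NumPy sketch of R̂_D(h) (names and toy numbers are illustrative, not code from the POEM package):

```python
import numpy as np

def ips_estimate(delta, p, h_prob):
    """Importance-sampling (IPS) estimate of R(h) from logged bandit data.

    delta  : losses delta_i observed for the logged predictions y_i
    p      : propensities p_i = h_0(y_i | x_i) under the logging policy
    h_prob : probabilities h(y_i | x_i) under the policy being evaluated
    """
    return np.mean(delta * h_prob / p)

# Toy log of three interactions (x_i, y_i, delta_i, p_i):
delta  = np.array([1.0, 5.0, 2.0])   # feedback (smaller is better)
p      = np.array([0.9, 0.9, 0.9])   # logging propensities
h_prob = np.array([0.8, 0.1, 0.5])   # h's probability of the logged y_i

print(ips_estimate(delta, p, h_prob))  # unbiased estimate of R(h)
```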

Story so far

Logged bandit data → fix the h vs. h_0 mismatch → control variance → tractable bound

Importance sampling causes non-uniform variance!

Example log (all propensities p_i = 0.9):

  x_1 = (Alice, sports)  y_1 = (F1)         δ_1 = 5
  x_2 = (Alice, tech)    y_2 = (SpaceX)     δ_2 = 1
  x_3 = (Alice, movies)  y_3 = (Star Wars)  δ_3 = 2
  x_4 = (Alice, tech)    y_4 = (Tesla)      δ_4 = 1

Two hypotheses can score very differently on the same log, depending on how much importance-sampling variance they incur:

  h_1: R̂_D(h_1) = 1        h_2: R̂_D(h_2) = 1.33

Want: an error bound that captures the variance of importance sampling, together with capacity control.
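The effect is easy to reproduce numerically. A small Monte Carlo sketch (a toy setup of my own, not from the talk): two policies with identical true risk, where the one concentrating mass where h_0 rarely goes gets much noisier IPS estimates.

```python
import numpy as np

rng = np.random.default_rng(0)

h0    = np.array([0.9, 0.1])   # logging policy: rarely plays action 1
delta = np.array([1.0, 1.0])   # both actions have loss 1, so R(h) = 1 for any h

h1 = np.array([0.8, 0.2])      # stays close to h_0 -> modest importance weights
h2 = np.array([0.1, 0.9])      # concentrates where h_0 rarely goes -> weights up to 9

def ips(h, n=100):
    """IPS estimate of R(h) from n interactions logged under h_0."""
    a = rng.choice(2, size=n, p=h0)
    return np.mean(delta[a] * h[a] / h0[a])

est1 = [ips(h1) for _ in range(1000)]
est2 = [ips(h2) for _ in range(1000)]
print(np.mean(est1), np.std(est1))  # mean ~1.0, small spread
print(np.mean(est2), np.std(est2))  # mean ~1.0 too, but much larger spread
```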

Counterfactual Risk Minimization

• With high probability over draws of D ~ h_0:

  ∀h ∈ H:  R(h) ≤ R̂_D(h) + O(√(Var_D(h) / n)) + O(C_∞(H) / n)

  (empirical risk + variance regularization + capacity control, where C_∞(H) is a capacity measure of the hypothesis class)

  *conditions apply; see [Maurer et al.]

Learning objective

  h_CRM = argmin_{h ∈ H}  R̂_D(h) + λ √(Var_D(h) / n)

POEM: CRM algorithm for structured prediction

• Conditional exponential models (CRFs): h_w ∈ H_lin with

  h_w(y | x) = exp(w · φ(x, y)) / Z(x; w)

• Policy Optimizer for Exponential Models:

  w* = argmin_w  (1/n) Σ_{i=1}^n δ_i · h_w(y_i | x_i) / p_i  +  λ √(Var_D(h_w) / n)  +  μ ‖w‖²

Good: gradient descent; searches over infinitely many w
Bad: not convex in w
Ugly: resists stochastic optimization
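For intuition, a sketch of the exponential-model policy over a small, explicitly enumerated candidate set (the feature map and candidate enumeration are simplified stand-ins; real structured output spaces are generally too large to enumerate):

```python
import numpy as np

def policy_probs(w, feats):
    """h_w(y|x) = exp(w . phi(x,y)) / Z(x;w) over enumerated candidates y.

    feats : (num_candidates, dim) array whose rows are phi(x, y).
    """
    scores = feats @ w
    scores -= scores.max()          # subtract max for numerical stability
    expw = np.exp(scores)
    return expw / expw.sum()

# Acting as the logging policy h_0: sample a prediction, record its propensity.
rng = np.random.default_rng(0)
feats = rng.normal(size=(4, 3))     # 4 candidate outputs, 3 joint features
w0 = rng.normal(size=3)
probs = policy_probs(w0, feats)
y = rng.choice(4, p=probs)
print(y, probs[y])                  # logged prediction y_i and propensity p_i
```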


Stochastically optimize √(Var(h_w))?

• Taylor-approximate! Write h_w^i for the importance-weighted loss of example i, i.e., h_w^i = δ_i h_w(y_i | x_i) / p_i. Expanding around the current iterate w_t:

  √(Var_D(h_w)) ≤ A_{w_t} Σ_{i=1}^n h_w^i + B_{w_t} Σ_{i=1}^n (h_w^i)² + C_{w_t}

• During an epoch: Adagrad with per-example gradient ∇h_w^i + λ √n (A_{w_t} ∇h_w^i + 2 B_{w_t} h_w^i ∇h_w^i)
• After each epoch: w_{t+1} ← w; recompute A_{w_{t+1}}, B_{w_{t+1}}

[Figure: the linear majorizer is re-fit at each epoch's iterate, w_t → w_{t+1}.]


Experiment

• Supervised → bandit conversion for multi-label datasets [Agarwal et al.] (see the sketch below)
• δ(x, y) = Hamming(y*(x), y)  (smaller is better)
• LibSVM datasets:
  • Scene (few features, labels, and examples)
  • Yeast (many labels)
  • LYRL (many features and examples)
  • TMC (many features, labels, and examples)
• Validate hyperparameters (λ, μ) using R̂_{D_val}(h)
• Report expected Hamming loss on the supervised test set
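A sketch of the supervised → bandit conversion (a standard construction; the per-label sampling here is illustrative of, not identical to, the talk's exact protocol): run h_0 on supervised examples, keeping only its sampled prediction, the Hamming loss against the true labels, and the propensity.

```python
import numpy as np

rng = np.random.default_rng(0)

def log_bandit_data(X, Y_true, h0_label_probs):
    """Convert supervised multi-label data into logged bandit feedback.

    h0_label_probs(x) returns per-label probabilities under h_0;
    labels are sampled independently here for simplicity.
    """
    log = []
    for x, y_star in zip(X, Y_true):
        q = h0_label_probs(x)                     # P(label_j = 1 | x) under h_0
        y = (rng.random(len(q)) < q).astype(int)  # sampled prediction y_i
        p = np.prod(np.where(y == 1, q, 1 - q))   # propensity p_i = h_0(y | x)
        delta = int(np.sum(y != y_star))          # Hamming loss vs. true labels
        log.append((x, y, delta, p))
    return log
```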

Approaches

• Baseline
  • h_0: supervised CRF trained on 5% of the training data
• Proposed
  • IPS (no variance penalty; extends [Bottou et al.])
  • POEM
• Skyline
  • Supervised CRF (independent logistic regression per label)

(1) Does variance regularization help?

Test-set expected Hamming loss δ (smaller is better):

                  Scene   Yeast   LYRL    TMC
  h_0             1.543   5.547   1.463   3.445
  IPS             1.519   4.614   1.118   3.023
  POEM            1.143   4.517   0.996   2.522
  Supervised CRF  0.659   2.822   0.222   1.189

POEM improves over both h_0 and IPS on all four datasets.

(2) Is it efficient?

  Avg time (s)   Scene   Yeast   LYRL     TMC
  POEM(B)        75.20   94.16   561.12   949.95
  POEM(S)         4.71    5.02   120.09   276.13
  CRF             4.86    3.28    62.93    99.18

• POEM(S), the stochastic optimizer, recovers the same performance as POEM(B) at a fraction of the L-BFGS cost
• POEM scales like a supervised CRF, yet learns from bandit feedback

(3) Does generalization improve as n → ∞?

[Figure: test-set expected Hamming loss (δ) vs. replay count (1 to 256), comparing h_0, POEM, and the supervised CRF skyline.]

(4) Does the stochasticity of h_0 affect learning?

[Figure: test-set expected Hamming loss (δ) vs. a parameter multiplier on h_0 (0.5 to 32), sweeping h_0 from stochastic to nearly deterministic, comparing h_0 and POEM.]

Conclusion

• CRM principle for learning from logged bandit feedback
  • Variance regularization
• POEM for structured output prediction
  • Scales like a supervised CRF, learns from bandit feedback

• Contact: adith@cs.cornell.edu
• POEM available at http://www.cs.cornell.edu/~adith/poem/
• Long paper: Counterfactual risk minimization: Learning from logged bandit feedback, http://jmlr.org/proceedings/papers/v37/swaminathan15.html

Thanks!

References

1. Art B. Owen. 2013. Monte Carlo Theory, Methods and Examples.

2. Paul R. Rosenbaum and Donald B. Rubin. 1983. The central role of the propensity score in observational studies for causal effects. Biometrika 70, 41-55.

3. Léon Bottou, Jonas Peters, Joaquin Quiñonero-Candela, Denis X. Charles, D. Max Chickering, Elon Portugaly, Dipankar Ray, Patrice Simard, and Ed Snelson. 2013. Counterfactual reasoning and learning systems: The example of computational advertising. Journal of Machine Learning Research 14(1), 3207-3260.

4. Alekh Agarwal, Daniel Hsu, Satyen Kale, John Langford, Lihong Li, and Robert Schapire. 2014. Taming the monster: A fast and simple algorithm for contextual bandits. Proceedings of the 31st International Conference on Machine Learning, 1638-1646.

5. Andreas Maurer and Massimiliano Pontil. 2009. Empirical Bernstein bounds and sample-variance penalization. Proceedings of the 22nd Conference on Learning Theory.

6. Adith Swaminathan and Thorsten Joachims. 2015. Counterfactual risk minimization: Learning from logged bandit feedback. Proceedings of the 32nd International Conference on Machine Learning.