Learning Mixtures of Product Distributions
Jon Feldman
Columbia University
Rocco Servedio
Columbia University
Ryan O’Donnell
IAS
Learning Distributions
• There is an unknown distribution P over R^n, or maybe
just over {0,1}^n.
• An algorithm gets access to random samples from P.
• In time polynomial in n/ε it should output a hypothesis
distribution Q which (w.h.p.) is ε-close to P.
[Technical details later.]
Learning Distributions
Hopeless in general!
Learning Classes of Distributions
• Since this is hopeless in general, one assumes that P comes
from a class of distributions C.
• We speak of whether C is polynomial-time learnable or not;
this means that there is one algorithm that learns every P
in C.
• Some easily learnable classes:
– C = {Gaussians over R^n}
– C = {Product distributions over {0,1}^n}
Learning Distributions
Learning product distributions over {0,1}^n
E.g. n = 3. Samples: 010, 011, 011, 111, 010, 011, 010, 010, 111, 000
Hypothesis: [.2 .9 .5]
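A minimal sketch of how the hypothesis on this slide can be obtained: for a single product distribution over bits, the empirical mean of each coordinate is a natural estimate. The function name `estimate_means` is illustrative and not from the talk.

```python
def estimate_means(samples):
    """Return the empirical mean of each coordinate of a list of bit-tuples."""
    n = len(samples[0])
    return [sum(x[j] for x in samples) / len(samples) for j in range(n)]

# The ten 3-bit samples shown on the slide, one string per draw.
samples = [tuple(int(b) for b in s) for s in
           ["010", "011", "011", "111", "010",
            "011", "010", "010", "111", "000"]]

means = estimate_means(samples)
print(means)
```

On these ten samples the per-coordinate frequencies are 2/10, 9/10 and 5/10, matching the hypothesis above.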
Mixtures of product distributions
Fix k ≥ 2 and let π_1 + π_2 + … + π_k = 1.
The π-mixture of distributions P^1, …, P^k is:
– Draw i according to mixture weights π_i.
– Draw from P^i.
In the case of product distributions over {0,1}^n (μ^i_j is the mean of coordinate j under component i):
π_1 [ μ^1_1 μ^1_2 μ^1_3 … μ^1_n ]
π_2 [ μ^2_1 μ^2_2 μ^2_3 … μ^2_n ]
…
π_k [ μ^k_1 μ^k_2 μ^k_3 … μ^k_n ]
Learning mixture example
E.g. n = 4. Samples: 1100, 0001, 0101, 0110, 0001, 1110, 0101, 0011, 1110, 1010
True distribution:
60% [ .8 .8 .6 .2 ]
40% [ .2 .4 .3 .8 ]
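The two-stage sampling process for this example mixture can be sketched as follows. This is an illustration only: the function name `sample_mixture`, the fixed seed and the sample count are assumptions, not part of the talk.

```python
import random

def sample_mixture(weights, means, rng):
    """Draw one point: pick component i with probability weights[i],
    then draw each bit independently with that component's means."""
    i = rng.choices(range(len(weights)), weights=weights)[0]
    return tuple(1 if rng.random() < mu else 0 for mu in means[i])

weights = [0.6, 0.4]
means = [[0.8, 0.8, 0.6, 0.2],
         [0.2, 0.4, 0.3, 0.8]]

rng = random.Random(0)  # fixed seed so the run is repeatable
draws = [sample_mixture(weights, means, rng) for _ in range(20000)]

# Per-coordinate frequencies of the draws.
empirical = [sum(x[j] for x in draws) / len(draws) for j in range(4)]
# Per-coordinate means of the mixture itself: sum_i pi_i * mu^i_j.
expected = [sum(w * mu[j] for w, mu in zip(weights, means)) for j in range(4)]
print(empirical, expected)
```

Note that the per-coordinate frequencies only reveal the weighted average of the component means, which is why learning the mixture needs more than single-coordinate statistics.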
Prior work
• [KMRRSS94]: learned in time poly(n/ε, 2^k) in the special
case that there is a number p < ½ such that every μ^i_j is
either p or 1−p.
• [FM99]: learned mixtures of 2 product distributions over
{0,1}^n in polynomial time (with a few minor technical
deficiencies).
• [CGG98]: learned a generalization of mixtures of 2 product
distributions over {0,1}^n, no deficiencies.
The latter two leave mixtures of 3+ as an open problem: there is a
qualitative difference between 2 & 3. [FM99] also leaves
open learning mixtures of Gaussians and other distributions over R^n.
Our results
• A poly(n/ε) time algorithm learning a mixture of k product
distributions over {0,1}^n for any constant k.
• Evidence that getting a poly(n/ε) algorithm for k = ω(1)
[even in the case where μ’s are in {0, ½, 1}] will be very
hard (if possible).
• Generalizations:
– Let C_1, …, C_n be “nice” classes of distributions over R
(…definable in terms of O(1) moments…) Algorithm
learns mixture of O(1) distributions in C_1 × · · · × C_n.
– Only pairwise independence of coords is used…
Technical definitions
When is a hypothesis distribution Q “ε-close” to the target
distribution P ?
• L1 distance? ∫ |P(x) − Q(x)|.
• KL divergence: KL(P || Q) = ∫ P(x) log[P(x)/Q(x)].
Getting a KL-close hypothesis is more stringent:
fact: L1 ≤ O(KL^½).
We learn under KL divergence, which leads to some technical
advantages (and some technical difficulties).
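A small numeric check of the two distance measures, as a sketch. The two distributions over {0,1}^2 are made up for illustration, and the specific constant used (Pinsker's inequality, L1 ≤ sqrt(2·KL) with natural logarithms) is one standard form of the "fact" above, not something stated on the slide.

```python
import math

def l1_distance(p, q):
    """Sum over x of |P(x) - Q(x)| for distributions given as dicts."""
    return sum(abs(p[x] - q[x]) for x in p)

def kl_divergence(p, q):
    """KL(P || Q) in nats; assumes q[x] > 0 wherever p[x] > 0."""
    return sum(p[x] * math.log(p[x] / q[x]) for x in p if p[x] > 0)

# Hypothetical target P and hypothesis Q over {0,1}^2.
p = {"00": 0.5, "01": 0.2, "10": 0.2, "11": 0.1}
q = {"00": 0.25, "01": 0.25, "10": 0.25, "11": 0.25}

l1 = l1_distance(p, q)
kl = kl_divergence(p, q)
pinsker_bound = math.sqrt(2 * kl)  # Pinsker: L1 <= sqrt(2 * KL)
print(l1, kl, pinsker_bound)
```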
Learning distributions summary
• Learning a class of distributions C.
• Let P be any distribution in the class.
• Given ε and δ > 0.
• Get samples and do poly(n/ε, log(1/δ)) work.
• With probability at least 1−δ output a hypothesis Q which
satisfies KL(P || Q) < ε.
Some intuition for k = 2
Idea: Find two coordinates j and j' to “key off.”
• Suppose you notice that the bits in coords j and j' are very
frequently different.
• Then probably most of the …0…1… examples come from
one mixture component and most of the …1…0… examples come from
the other component.
• Use this separation to estimate all other means.
More details for the intuition
Suppose you somehow “know” the following three things:
– The mixture weights are 60% / 40%.
– There are j and j' such that the means satisfy
det [ p_j p_j' ; q_j q_j' ] > ε.
– The values p_j, p_j', q_j, q_j' themselves.
More details for the intuition
Main algorithmic idea:
For each coord m, estimate (to within ε^2) the correlation
between j & m and between j' & m:
corr(j, m) = (.6 p_j) p_m + (.4 q_j) q_m
corr(j', m) = (.6 p_j') p_m + (.4 q_j') q_m
Solve this system of equations for p_m, q_m. Done!
Since the determinant is > ε, any error in the correlation
estimates does not blow up too much.
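The two-equation system can be solved directly, for instance by Cramer's rule. The sketch below uses exact correlations computed from the 60%/40% example mixture given earlier in the talk; in the algorithm these would be sample estimates. The function name `solve_means` and its argument names are illustrative.

```python
def solve_means(w_p, w_q, pj, pj2, qj, qj2, corr_j_m, corr_j2_m):
    """Solve the 2x2 linear system for (p_m, q_m) by Cramer's rule.

    corr(j,  m) = (w_p * pj)  * p_m + (w_q * qj)  * q_m
    corr(j', m) = (w_p * pj2) * p_m + (w_q * qj2) * q_m
    """
    a, b = w_p * pj, w_q * qj      # coefficients in the corr(j, m) equation
    c, d = w_p * pj2, w_q * qj2    # coefficients in the corr(j', m) equation
    det = a * d - b * c
    if abs(det) < 1e-12:
        raise ValueError("system is (nearly) singular")
    p_m = (corr_j_m * d - b * corr_j2_m) / det
    q_m = (a * corr_j2_m - c * corr_j_m) / det
    return p_m, q_m

# Example mixture: 60% [.8 .8 .6 .2], 40% [.2 .4 .3 .8].
# Key off coordinates j = 1 and j' = 4; recover coordinate m = 3.
pj, pj2, qj, qj2 = 0.8, 0.2, 0.2, 0.8
true_pm, true_qm = 0.6, 0.3
corr_j_m = 0.6 * pj * true_pm + 0.4 * qj * true_qm
corr_j2_m = 0.6 * pj2 * true_pm + 0.4 * qj2 * true_qm

p_m, q_m = solve_means(0.6, 0.4, pj, pj2, qj, qj2, corr_j_m, corr_j2_m)
print(p_m, q_m)
```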
Two questions
1. This assumes that there is some 2×2 submatrix which is far
from singular. In general, no reason to believe this is the
case.
– But if not, then one set of means is very nearly a multiple
of the other set; problem becomes very easy.
2. How did we know π1, π2? How did we know which j and j'
were good? How did we know the 4 means pj, pj', qj, qj'?
Guessing
Just guess. I.e., “try” “all” possibilities.
• Guess if the 2 × n matrix is essentially rank 1 or not.
• Guess π_1, π_2 to within ε^2. (Time: 1/ε^4.)
• Guess correct j, j'. (Time: n^2.)
• Guess p_j, p_j', q_j, q_j' to within ε^2. (Time: 1/ε^8.)
Solve the system of equations in every case.
Time: poly(n/ε).
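"Trying all possibilities" amounts to enumerating a grid. The sketch below is a simplification under stated assumptions: it uses a coarse grid with spacing 1/steps instead of ε^2, unordered pairs j < j', and sets π_2 = 1 − π_1 rather than guessing both weights independently. The names `grid` and `all_guesses` are hypothetical.

```python
from itertools import combinations, product

def grid(steps):
    """Evenly spaced guesses 0, 1/steps, ..., 1."""
    return [i / steps for i in range(steps + 1)]

def all_guesses(n, steps):
    """Yield (pi1, pi2, j, j2, pj, pj2, qj, qj2) for every grid combination."""
    g = grid(steps)
    for pi1 in g:
        for j, j2 in combinations(range(n), 2):
            for pj, pj2, qj, qj2 in product(g, repeat=4):
                yield (pi1, 1 - pi1, j, j2, pj, pj2, qj, qj2)

# With n = 3 coordinates and a grid of 5 values:
# 5 weight guesses * 3 coordinate pairs * 5^4 mean guesses.
num_candidates = sum(1 for _ in all_guesses(n=3, steps=4))
print(num_candidates)
```

Each tuple yielded here would be fed to the equation-solving step to produce one candidate hypothesis.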
Checking guesses
• After this we get a whole bunch of candidate hypotheses.
• When we get lucky and make all the right guesses, the
resulting candidate hypothesis will be a good one – say, will
be ε-close in KL to the truth.
Can we pick the (or, a) candidate hypothesis which is KL-close
to the truth? I.e., can we guess and check?
Yes – use a Maximum Likelihood test…
Checking with ML
Suppose Q is a candidate hypothesis for P.
Estimate its “log likelihood”:
log Π_{x ∈ S} Q(x)
= Σ_{x ∈ S} log Q(x)
≈ |S| E[log Q(x)]
= |S| ∫ P(x) log Q(x)
= |S| [ ∫ P log P − KL(P || Q) ].
Checking with ML cont’d
• By Chernoff bounds, if we take enough samples, all
candidate hypotheses Q will have their “estimated log-
likelihoods” close to their expectations.
– Any KL-close Q will look very good in the ML test.
– Anything which looks good in the ML test is KL-close.
• Thus assuming there is an ε-close candidate hypothesis
among guesses, we find an O(ε)-close candidate hypothesis.
• I.e., we can guess and check.
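The guess-and-check step can be sketched as picking the candidate with the largest log-likelihood on the sample. This is a toy illustration: the candidates below are made up, each is a (weights, means) pair, and the sketch assumes every candidate assigns positive probability to every sample (otherwise `math.log` would fail). The function names are not from the talk.

```python
import math

def mixture_prob(x, weights, means):
    """Probability of bit-tuple x under a mixture of product distributions."""
    total = 0.0
    for w, mu in zip(weights, means):
        prob = w
        for bit, m in zip(x, mu):
            prob *= m if bit == 1 else 1 - m
        total += prob
    return total

def log_likelihood(samples, weights, means):
    """Sum of log Q(x) over the sample, for the hypothesis Q = (weights, means)."""
    return sum(math.log(mixture_prob(x, weights, means)) for x in samples)

def pick_best(samples, candidates):
    """Return the candidate hypothesis with the highest log-likelihood."""
    return max(candidates, key=lambda c: log_likelihood(samples, *c))

samples = [tuple(int(b) for b in s) for s in
           ["010", "011", "011", "111", "010",
            "011", "010", "010", "111", "000"]]
candidates = [
    ([1.0], [[0.2, 0.9, 0.5]]),   # matches the empirical coordinate means
    ([1.0], [[0.5, 0.5, 0.5]]),   # uniform
    ([1.0], [[0.8, 0.1, 0.5]]),   # badly wrong
]
best = pick_best(samples, candidates)
print(best)
```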
Overview of the algorithm
We now give the precise algorithm for learning a mixture of k
product distributions, along with intuition for why it works.
Intuitively:
– Estimate all the pairwise correlations of bits.
– Guess a number of parameters of the mixture distribution.
– Use guesses, correlation estimates to solve for remaining
parameters.
– Show that whenever guesses are close, the resulting parameter
estimations give a close-in-KL candidate hypothesis.
– Check candidates with ML algorithm, pick best one.
The algorithm
1. Estimate all pairwise correlations corr(j, j') to
within (ε/n)^k. (Time: (n/ε)^k.)
Note: corr(j, j') = Σ_{i=1..k} π_i μ^i_j μ^i_j'
= ⟨ μ~_j , μ~_j' ⟩,
where μ~_j = ( (π_i)^½ μ^i_j )_{i=1..k}
2. Guess all π_i to within (ε/n)^k. (Time: (n/ε)^{k^2}.)
Now it suffices to estimate all vectors μ~_j, j = 1… n.
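The identity in the note of step 1 can be checked numerically: the correlation of two distinct coordinates equals the dot product of the corresponding weight-scaled mean vectors. The sketch below uses the 60%/40% example mixture from earlier in the talk; the helper names are illustrative.

```python
import math

weights = [0.6, 0.4]
means = [[0.8, 0.8, 0.6, 0.2],
         [0.2, 0.4, 0.3, 0.8]]

def corr(j, j2):
    """E[x_j x_j'] for distinct coordinates j != j' of the mixture."""
    return sum(w * mu[j] * mu[j2] for w, mu in zip(weights, means))

def scaled_column(j):
    """The vector with entries sqrt(pi_i) * mu^i_j for i = 1..k."""
    return [math.sqrt(w) * mu[j] for w, mu in zip(weights, means)]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

lhs = corr(0, 1)
rhs = dot(scaled_column(0), scaled_column(1))
print(lhs, rhs)
```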
Guessing matrices from most of their Gram matrices
Let A be the k × n matrix of the μ~^i_j’s, where μ~^i_j = (π_i)^½ μ^i_j:
A = [ μ~_1 μ~_2 … μ~_n ] (column j is the vector μ~_j)
After estimating all correlations, we know all dot products of
distinct columns of A to high accuracy.
Goal: determine all entries of A, making only O(1) guesses.
Two remarks
1. This is the final problem, where all the main action and
technical challenge lies. Note that all we ever do with the
samples is estimate pairwise correlations.
2. If we knew the dot products of the columns of A with
themselves, we’d have the whole matrix A^T A. That would
be great; we could just factor it and recover A exactly.
Unfortunately, there doesn’t seem to be any way to get at
these quantities – Σ_{i=1..k} π_i (μ^i_j)^2.
Keying off a nonsingular submatrix
Idea: find a nonsingular k × k matrix to “key off.”
As before, the “usual” case is that A has full rank.
• Then there is a k × k nonsingular submatrix A_J.
• Guess this matrix (time: n^k) and all its entries to
within (ε/n)^k (time: (n/ε)^{k^3} – final running time).
• Now use this submatrix and correlation estimates to find
all other entries of A:
for all m, A_J^T A_m = ( corr(m, j) )_{j ∈ J}
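A sketch of the full-rank solve for k = 3, under simplifying assumptions: the submatrix A_J and the correlations are taken to be exact (in the algorithm A_J is guessed and the correlations are estimated), and the weights and means below are made up for illustration. The linear solver is plain Gaussian elimination with partial pivoting.

```python
import math

def solve(M, b):
    """Solve M x = b by Gaussian elimination with partial pivoting."""
    k = len(b)
    aug = [row[:] + [rhs] for row, rhs in zip(M, b)]
    for col in range(k):
        pivot = max(range(col, k), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        for r in range(col + 1, k):
            factor = aug[r][col] / aug[col][col]
            for c in range(col, k + 1):
                aug[r][c] -= factor * aug[col][c]
    x = [0.0] * k
    for r in reversed(range(k)):
        tail = sum(aug[r][c] * x[c] for c in range(r + 1, k))
        x[r] = (aug[r][k] - tail) / aug[r][r]
    return x

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

# Hypothetical mixture with k = 3 components over n = 5 coordinates.
weights = [0.5, 0.3, 0.2]
means = [[0.9, 0.1, 0.1, 0.3, 0.7],
         [0.1, 0.9, 0.1, 0.6, 0.2],
         [0.1, 0.1, 0.9, 0.5, 0.5]]
k = 3
# A[i][j] = sqrt(pi_i) * mu^i_j
A = [[math.sqrt(w) * m for m in row] for w, row in zip(weights, means)]

def column(j):
    return [A[i][j] for i in range(k)]

J = [0, 1, 2]                       # columns of the "key" submatrix A_J
A_J_T = [column(j) for j in J]      # rows of A_J^T are the columns of A in J
m = 3                               # coordinate to recover
b = [dot(column(j), column(m)) for j in J]   # corr(m, j) for j in J
recovered = solve(A_J_T, b)
print(recovered, column(m))
```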
Non-full rank case
But what if A is not full rank? (Or in actual analysis, if A is
extremely close to being rank deficient.) A genuine
problem.
Then the column space of A has an orthogonal complement of dimension 0 < d ≤ k,
spanned by some orthonormal vectors u_1, …, u_d.
3. Guess d and the vectors u_1, …, u_d.
Now adjoin these columns to A, getting a full rank matrix:
A' = [ A | u_1 u_2 … u_d ]
Non-full rank case – cont’d
Now A' has full rank and we can do the full rank case!
Why do we still know all pairwise dot products of A'’s columns?
– Dot products of the u’s with A’s columns are 0!
– Dot products of distinct u’s are 0, and each u has dot
product 1 with itself. (Don’t need the latter.)
4. Guess a k × k submatrix of A' and all its
entries. Use these to solve for all other
entries.
The actual analysis
• The actual analysis of this algorithm is quite delicate.
• There are some linear algebra & numerical analysis ideas.
• The main issue is: The degree to which A is “essentially” of
rank k – d is similar to the degree to which all guessed
vectors u really do have dot product 0 with A’s original
columns.
• The key is to find a large multiplicative “gap”
between A’s singular values, and treat its location
as the essential rank of A.
• This is where the necessary accuracy (ε/n)^k comes in.
Can we learn a mixture of ω(1)?
Claim: Let T be a decision tree on {0,1}^n with k leaves.
Then the uniform distribution over the inputs which make T
output 1 is a mixture of at most k product distributions.
Indeed, all of these product distributions have coordinate means 0, ½, or 1.
[Figure: a decision tree T querying x1, x2, x3, with leaves labeled 0 and 1. The resulting mixture:]
2/3: [0, 0, ½, ½, ½, …]
1/3: [1, 1, 0, ½, ½, …]
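The claim can be sketched in code: each leaf that outputs 1 fixes the bits on its root-to-leaf path and leaves the other coordinates uniform, and its mixture weight is proportional to the fraction of inputs reaching it. The two paths below are an assumption, chosen to be consistent with the mixture shown above (the tree figure itself is not recoverable); the function name `tree_to_mixture` is illustrative.

```python
from fractions import Fraction

def tree_to_mixture(one_leaves, n):
    """Convert the 1-leaves of a decision tree into a mixture.

    one_leaves: list of dicts {coordinate index: fixed bit}, one per
    root-to-leaf path that ends in output 1.
    Returns (weights, means) with exact rational entries.
    """
    half = Fraction(1, 2)
    # A path fixing d bits is reached by a 2^-d fraction of all inputs.
    masses = [half ** len(path) for path in one_leaves]
    total = sum(masses)
    weights = [mass / total for mass in masses]
    means = [[Fraction(path[j]) if j in path else half for j in range(n)]
             for path in one_leaves]
    return weights, means

# Assumed 1-leaves: (x1 = 0, x2 = 0) and (x1 = 1, x2 = 1, x3 = 0).
one_leaves = [{0: 0, 1: 0}, {0: 1, 1: 1, 2: 0}]
weights, means = tree_to_mixture(one_leaves, n=5)
print(weights, means)
```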
Learning DTs under uniform
Cor: If one can learn a mixture of k product distributions over
{0,1}^n (even 0/½/1 ones) in poly(n) time, one can PAC-learn
k-leaf decision trees under uniform in poly(n) time.
PAC-learning ω(1)-size DTs under uniform is an extremely
notorious problem:
– easier than learning ω(1)-term DNF under uniform, a
20-year-old problem;
– essentially equivalent to learning ω(1)-juntas under
uniform; worth $1000 from A. Blum to solve.
Generalizations
We gave an algorithm that guessed the means of an unknown
mixture of k product distributions.
What assumptions did we really need?
• pairwise independence of coords
• means fell in a bounded range [-poly(n), poly(n)]
• 1-d distributions (and pairwise products of same) are
“samplable” – can find true correlations by estimation
• the means defined the 1-d distributions
The last of these is rarely true. But…
Higher moments
• Suppose we ran the algorithm and got N guesses for the
means of all the distributions.
• Now run the algorithm again, but whenever you get the
point (x_1, …, x_n), treat it as (x_1^2, …, x_n^2).
• You will get N guesses for the second moments!
• Cross product the two lists, get N^2 guesses for the
(mean, second moment) pairs.
• Guess and check, as always.
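The bookkeeping on this slide can be sketched as follows. The guess lists are placeholders standing in for the N outputs of the two runs of the algorithm, and the function names are hypothetical.

```python
from itertools import product

def square_coordinates(x):
    """Map a sample (x_1, ..., x_n) to (x_1^2, ..., x_n^2) for the second run."""
    return tuple(v * v for v in x)

def pair_guesses(mean_guesses, second_moment_guesses):
    """All (mean guess, second-moment guess) pairs: N * N candidates."""
    return list(product(mean_guesses, second_moment_guesses))

# Placeholder outputs of the two runs, N = 3 guesses each.
mean_guesses = [(0.0, 1.0), (0.5, 0.5), (1.0, 0.0)]
second_moment_guesses = [(1.0, 2.0), (1.5, 1.5), (2.0, 1.0)]

pairs = pair_guesses(mean_guesses, second_moment_guesses)
print(len(pairs), square_coordinates((2.0, -3.0)))
```

Each pair would then be turned into a candidate hypothesis and passed to the same maximum likelihood check as before.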
Generalizations
Let C_1, …, C_n be families of distributions on R which have the
following “niceness properties”:
– means bounded in [-poly(n), poly(n)]
– sharp tail bounds / samplability
– defined by O(1) moments; closeness in moments ⇒ closeness in KL
– … more technical concerns…
Should be able to learn O(1)-mixtures from C_1 × · · · × C_n in
the same time.
Definitely can learn mixtures of axis-aligned Gaussians, mixtures
of distributions on O(1)-sized sets.
Open questions
• Quantify some nice properties of families of distributions
over R which this algorithm can learn.
• Simplify algorithm:
– Simpler analysis?
– Faster? n^{k^2} ? n^k ? n^{log k} ???
– Specific fast results for k = 2, 3.
• Solve other distribution-learning problems.