In the previous post, we walked through two examples, two coins and GMM, presented the algorithms for solving them (without proof), and derived a generalized form of the EM algorithm for a family of similar problems: maximizing the log-likelihood in the presence of hidden variables. This post focuses on the proof: why are the algorithms showcased in the two examples mathematically equivalent to the final form of EM?
💡 Algorithm 3: Expectation-Maximization (EM) — Final version
Input: Observations $\{x^{(1)}, \dots, x^{(n)}\}$, initial guess for the parameter $\theta_0$
Output: Optimal parameter value $\theta^*$ that maximizes the log-likelihood
$$L(\theta) = \sum_{i=1}^{n} \log P(x^{(i)}; \theta)$$
For $t = 1 \dots T$ (until convergence):
E-Step: For each $i = 1 \dots n$, compute the hidden posterior:
$$q^{(i)}(z^{(i)}) = P(z^{(i)} \mid x^{(i)}; \theta_t)$$
M-Step: Compute the maximizer of the evidence lower bound (ELBO):
$$\theta_{t+1} = \arg\max_{\theta} \sum_{i=1}^{n} \mathbb{E}_{z^{(i)} \sim q^{(i)}} \left[ \log P(x^{(i)}, z^{(i)}; \theta) \right]$$
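To make the loop structure concrete before diving into the proofs, here is a minimal Python sketch of Algorithm 3. It is a schematic of my own (not code from Part I); the `e_step` and `m_step` callables are placeholders whose concrete form depends on the model, as the two examples below spell out.

```python
def run_em(x, theta0, e_step, m_step, num_iters=20):
    """Generic EM loop in the shape of Algorithm 3.

    e_step(x, theta) -> hidden posteriors q, one per observation
    m_step(x, q)     -> new parameters maximizing the expected
                        complete-data log-likelihood (the ELBO with q fixed)
    """
    theta = theta0
    for _ in range(num_iters):      # or: loop until theta stops changing
        q = e_step(x, theta)        # E-step: q_i(z) = P(z | x_i; theta)
        theta = m_step(x, q)        # M-step: argmax over theta
    return theta
```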
Proof for Two Coins
E-Step
Recall that in the E-step we computed the probability of each trial belonging to each coin, $q^{(1)}(z^{(1)}) \dots q^{(5)}(z^{(5)})$, where $z^{(i)} \in \{A, B\}$ is the coin tossed in the $i$-th trial:
| Trial | Observation: number of heads $x^{(i)}$ | Hidden posterior $q^{(i)}(z^{(i)}{=}A) = P(\text{coin } A \mid \text{observation})$ | Hidden posterior $q^{(i)}(z^{(i)}{=}B) = P(\text{coin } B \mid \text{observation})$ |
|---|---|---|---|
| #1 ($i=1$) | 5 | 0.45 | 0.55 |
| #2 ($i=2$) | 9 | 0.80 | 0.20 |
| #3 ($i=3$) | 8 | 0.73 | 0.27 |
| #4 ($i=4$) | 4 | 0.35 | 0.65 |
| #5 ($i=5$) | 7 | 0.65 | 0.35 |
These values were computed using the posterior probability $P(\text{trial is a certain coin} \mid \text{observation})$, which is exactly the hidden posterior $q^{(i)}(z^{(i)}) = P(z^{(i)} \mid x^{(i)}; \theta_t)$ in Algorithm 3's E-step.
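To make this concrete, here is a short Python sketch of the E-step computation. It assumes 10 tosses per trial and current guesses $p_A = 0.6$, $p_B = 0.5$ with a uniform prior over the two coins; these are the values used in Do and Batzoglou's tutorial [1], whose two-coins numbers this example matches, and the variable names are mine:

```python
from math import comb

# Observations: number of heads in 10 tosses per trial (from the table above)
x = [5, 9, 8, 4, 7]

# Assumed current parameter guesses (values from the Nature tutorial [1]);
# the prior over the two coins is uniform, so it cancels in the posterior.
p_A, p_B = 0.6, 0.5

def binom_likelihood(heads, p, n=10):
    """P(observation | coin with head probability p)."""
    return comb(n, heads) * p**heads * (1 - p) ** (n - heads)

for i, heads in enumerate(x, start=1):
    like_A = binom_likelihood(heads, p_A)
    like_B = binom_likelihood(heads, p_B)
    q_A = like_A / (like_A + like_B)    # hidden posterior q(z = A)
    print(f"Trial #{i}: q(A) = {q_A:.2f}, q(B) = {1 - q_A:.2f}")
```

Running this reproduces the posteriors in the table above (0.45, 0.80, 0.73, 0.35, 0.65 for coin A).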
M-Step
It is slightly trickier to see why the “soft count” solution is equivalent to the M-step in Algorithm 3. To verify this, we expand the log term in the M-step objective as follows (our derivation follows the UMich EECS 545 lecture notes, adapted to our notation):
$$
\begin{aligned}
\theta_{t+1} &= \arg\max_{\theta} \sum_{i=1}^{n} \mathbb{E}_{z^{(i)} \sim q^{(i)}} \left[ \log P(x^{(i)}, z^{(i)}; \theta) \right] \\
&= \arg\max_{\theta} \sum_{i=1}^{n} \mathbb{E}_{z^{(i)} \sim q^{(i)}} \left[ \log \left( P(z^{(i)}) \cdot P(x^{(i)} \mid z^{(i)}; \theta) \right) \right] \\
&= \arg\max_{\theta} \sum_{i=1}^{n} \mathbb{E}_{z^{(i)} \sim q^{(i)}} \Big[ \underbrace{\log P(z^{(i)})}_{\text{constant w.r.t. } \theta \text{, can be omitted}} + \log P(x^{(i)} \mid z^{(i)}; \theta) \Big] \\
&= \arg\max_{\theta} \sum_{i=1}^{n} \mathbb{E}_{z^{(i)} \sim q^{(i)}} \left[ \log P(x^{(i)} \mid z^{(i)}; \theta) \right] \\
&= \arg\max_{\theta} \sum_{i=1}^{n} \mathbb{E}_{z^{(i)} \sim q^{(i)}} \left[ \log \left( \binom{10}{x^{(i)}} \, p_{z^{(i)}}^{x^{(i)}} \, (1 - p_{z^{(i)}})^{10 - x^{(i)}} \right) \right] \\
&= \arg\max_{\theta} \sum_{i=1}^{n} \mathbb{E}_{z^{(i)} \sim q^{(i)}} \Big[ \underbrace{\log \binom{10}{x^{(i)}}}_{\text{constant w.r.t. } \theta \text{, can be omitted}} + x^{(i)} \log p_{z^{(i)}} + (10 - x^{(i)}) \log (1 - p_{z^{(i)}}) \Big] \\
&= \arg\max_{\theta} \sum_{i=1}^{n} \mathbb{E}_{z^{(i)} \sim q^{(i)}} \left[ x^{(i)} \log p_{z^{(i)}} + (10 - x^{(i)}) \log (1 - p_{z^{(i)}}) \right] \\
&= \arg\max_{\theta} \sum_{i=1}^{n} \Big\{ q^{(i)}(z^{(i)}{=}A) \left[ x^{(i)} \log p_A + (10 - x^{(i)}) \log (1 - p_A) \right] + q^{(i)}(z^{(i)}{=}B) \left[ x^{(i)} \log p_B + (10 - x^{(i)}) \log (1 - p_B) \right] \Big\} \\
&= \arg\max_{\theta} \sum_{i=1}^{n} \Big\{ a^{(i)} \left[ x^{(i)} \log p_A + (10 - x^{(i)}) \log (1 - p_A) \right] + b^{(i)} \left[ x^{(i)} \log p_B + (10 - x^{(i)}) \log (1 - p_B) \right] \Big\}
\end{aligned}
$$
where $a^{(i)} := q^{(i)}(z^{(i)}{=}A)$ and $b^{(i)} := q^{(i)}(z^{(i)}{=}B)$ are the hidden posteriors from the E-step.
Taking the partial derivatives with respect to $\theta = \{p_A, p_B\}$ and setting them to zero, we have:
$$
\frac{\partial}{\partial p_A} \sum_{i=1}^{n} a^{(i)} \left[ x^{(i)} \log p_A + (10 - x^{(i)}) \log (1 - p_A) \right]
= \sum_{i=1}^{n} a^{(i)} \left[ \frac{x^{(i)}}{p_A} - \frac{10 - x^{(i)}}{1 - p_A} \right] = 0
\;\Longrightarrow\;
p_A = \frac{\sum_{i=1}^{n} a^{(i)} x^{(i)}}{10 \sum_{i=1}^{n} a^{(i)}}
$$
and symmetrically $p_B = \dfrac{\sum_{i=1}^{n} b^{(i)} x^{(i)}}{10 \sum_{i=1}^{n} b^{(i)}}$, which is exactly the “soft count” update: the soft count of heads attributed to each coin divided by the soft count of total tosses attributed to it.
Therefore, our M-step is also consistent with Algorithm 3.
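As a numerical sanity check, here is a short Python sketch of this soft-count update, plugging in the head counts and posteriors from the E-step table above (variable names are mine):

```python
# Posteriors from the E-step table and the observed head counts
x = [5, 9, 8, 4, 7]                      # heads out of 10 tosses per trial
a = [0.45, 0.80, 0.73, 0.35, 0.65]       # q(z = A) per trial
b = [0.55, 0.20, 0.27, 0.65, 0.35]       # q(z = B) per trial

# Soft-count update derived above: expected heads / expected total tosses
p_A = sum(ai * xi for ai, xi in zip(a, x)) / (10 * sum(a))
p_B = sum(bi * xi for bi, xi in zip(b, x)) / (10 * sum(b))

print(f"p_A = {p_A:.2f}, p_B = {p_B:.2f}")
```

This yields $p_A \approx 0.71$ and $p_B \approx 0.58$.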
Proof for GMM
E-Step
As in the two-coins example, the hidden posterior was computed using the posterior probability $P(\text{student is a certain gender} \mid \text{observation})$, which is consistent with Algorithm 3 (a code sketch follows the table):
| Student | Observation: height $x^{(i)}$ (cm) | Hidden posterior $q^{(i)}(z^{(i)}{=}B) = P(\text{boy} \mid \text{height})$ | Hidden posterior $q^{(i)}(z^{(i)}{=}G) = P(\text{girl} \mid \text{height})$ |
|---|---|---|---|
| #1 ($i=1$) | 168 | 0.26 | 0.74 |
| #2 ($i=2$) | 180 | 1.00 | 0.00 |
| #3 ($i=3$) | 170 | 0.50 | 0.50 |
| #4 ($i=4$) | 172 | 0.74 | 0.26 |
| #5 ($i=5$) | 178 | 0.99 | 0.01 |
| #6 ($i=6$) | 176 | 0.96 | 0.04 |
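For completeness, here is a minimal Python sketch of how such a posterior is computed for a two-component GMM. Since the Part I parameter values are not repeated in this post, the numbers passed in below are placeholders of my own and will not reproduce the table exactly:

```python
from math import exp, pi, sqrt

def normal_pdf(x, mu, sigma):
    """Density of a Gaussian N(mu, sigma^2) at x."""
    return exp(-((x - mu) ** 2) / (2 * sigma**2)) / (sigma * sqrt(2 * pi))

def posterior_boy(height, mu_b, sigma_b, mu_g, sigma_g, prior_b=0.5):
    """E-step posterior P(boy | height) for a two-component GMM."""
    joint_b = prior_b * normal_pdf(height, mu_b, sigma_b)
    joint_g = (1 - prior_b) * normal_pdf(height, mu_g, sigma_g)
    return joint_b / (joint_b + joint_g)

# Placeholder parameters (NOT the Part I values), for illustration only
for h in [168, 180, 170, 172, 178, 176]:
    q_b = posterior_boy(h, mu_b=175, sigma_b=4, mu_g=165, sigma_g=4)
    print(f"height {h}: q(boy) = {q_b:.2f}, q(girl) = {1 - q_b:.2f}")
```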
M-Step
Now let’s see why the weighted mean and standard deviation solve the GMM M-step. Omitting the first few steps, which are identical to the two-coins example, the objective becomes:
$$
\theta_{t+1} = \arg\max_{\theta} \sum_{i=1}^{n} \Big\{ q^{(i)}(z^{(i)}{=}B) \log \mathcal{N}(x^{(i)}; \mu_B, \sigma_B^2) + q^{(i)}(z^{(i)}{=}G) \log \mathcal{N}(x^{(i)}; \mu_G, \sigma_G^2) \Big\}
$$
Taking the partial derivatives with respect to $\mu_B$ and $\sigma_B$ (and likewise for $\mu_G$, $\sigma_G$) and setting them to zero yields the weighted mean and standard deviation:
$$
\mu_B = \frac{\sum_{i=1}^{n} q^{(i)}(z^{(i)}{=}B)\, x^{(i)}}{\sum_{i=1}^{n} q^{(i)}(z^{(i)}{=}B)}, \qquad
\sigma_B^2 = \frac{\sum_{i=1}^{n} q^{(i)}(z^{(i)}{=}B)\, \big(x^{(i)} - \mu_B\big)^2}{\sum_{i=1}^{n} q^{(i)}(z^{(i)}{=}B)}
$$
Therefore, our M-step is also consistent with Algorithm 3.
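Here is a short Python sketch of this weighted update, plugging in the heights and posteriors from the E-step table above (variable names are mine):

```python
from math import sqrt

heights = [168, 180, 170, 172, 178, 176]
q_boy   = [0.26, 1.00, 0.50, 0.74, 0.99, 0.96]   # posteriors from the E-step table
q_girl  = [1 - q for q in q_boy]                 # complements, matching the table

def weighted_update(x, w):
    """Weighted mean and standard deviation, as derived in the M-step above."""
    total = sum(w)
    mu = sum(wi * xi for wi, xi in zip(w, x)) / total
    var = sum(wi * (xi - mu) ** 2 for wi, xi in zip(w, x)) / total
    return mu, sqrt(var)

mu_b, sigma_b = weighted_update(heights, q_boy)
mu_g, sigma_g = weighted_update(heights, q_girl)
print(f"boys:  mu = {mu_b:.1f}, sigma = {sigma_b:.1f}")
print(f"girls: mu = {mu_g:.1f}, sigma = {sigma_g:.1f}")
```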
Closing Thoughts
This wraps up my notes on the EM algorithm; I hope you find them helpful!
As I said at the beginning of Part I, EM is straightforward to understand through examples, yet the math behind it is tricky to fully grasp. That is why I chose to put the examples and the ad-hoc algorithms first, and then prove that they are consistent with a more generalized theory. A fun fact: I found Part II, despite being math heavy, much easier to write than Part I, which focused on explaining the intuition. This doesn't mean the math is easy; it simply means that sometimes the intuition behind the math is extremely hard to put into words.
I first learned the EM algorithm from Stanford CS229: Machine Learning, and later became a course assistant for Stanford CS221: Artificial Intelligence: Principles and Techniques, where I needed to explain this algorithm to many students. In CS229, I learned to navigate the proof and use EM to solve several math problems, yet failed to develop a good intuition; in CS221, I was finally able to fully appreciate the beauty of this algorithm through a few “Explain Like I’m Five” (ELI5) examples similar to two coins, which I believe was helpful not only to me but also to quite a few CS221 students.
Even so, I found it quite challenging to explain the intuition in words (rather than in math). After some research, I found some excellent tutorials online, yet most of them are still overwhelmingly math heavy for beginners, while those with ELI5 examples usually skip the math derivations. Therefore, I decided to write my own learning notes, with the hope that they will be accessible enough for beginners yet rigorous enough for readers seeking math proofs. Feel free to comment below or send me an email if you have any feedback or questions!
References
[1] Chuong B. Do, Serafim Batzoglou. What is the expectation maximization algorithm? Nature Biotechnology, 2008. [paper]