
Understanding the EM Algorithm | Part II

In the previous post, we studied two examples in depth, two coins and GMM, gave the algorithms for solving them (without proof), and derived a generalized form of the EM algorithm for a family of similar problems: maximizing the log-likelihood when some variables are hidden. This post focuses on the proof: why are the algorithms used in the two examples mathematically equivalent to the final form of EM?

💡 Algorithm 3: Expectation-Maximization (EM) — Final version

  • Input: Observations $\{x^{(1)},\dots,x^{(n)}\}$, initial guess $\theta_0$ for the parameters
  • Output: Parameter value $\theta^*$ that (locally) maximizes the log-likelihood

    $$\mathcal{L}(\theta)=\sum_{i=1}^n \log P(x^{(i)};\theta)$$

  • For $t=0,1,2,\dots$ (until convergence):
    • E-Step: For each $i=1,\dots,n$, compute the hidden posterior:

      $$q^{(i)}(z^{(i)})=P(z^{(i)}\mid x^{(i)};\theta_t)$$

    • M-Step: Compute the maximizer of the evidence lower bound (ELBO):

      $$\theta_{t+1}=\arg\max_\theta \sum_{i=1}^n \mathbb{E}_{z^{(i)}\sim q^{(i)}(z^{(i)})}\log P(x^{(i)}, z^{(i)};\theta)$$

      (The ELBO also contains the entropy of each $q^{(i)}$, which does not depend on $\theta$ and therefore does not change the argmax.)
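To make the E-step/M-step split concrete, here is a small numerical check (a sketch, not part of the original derivation) of the fact behind Algorithm 3: for any distribution $q$ over $z$, the ELBO $\sum_z q(z)\log\frac{P(x,z;\theta)}{q(z)}$ is at most $\log P(x;\theta)$, with equality exactly when $q$ is the posterior $P(z\mid x;\theta)$. The setup is an assumption chosen for illustration: a single two-coin trial with $x=5$ heads out of 10 tosses, hypothetical biases $p_A=0.6$, $p_B=0.5$, and a uniform prior over coins.

```python
import math

# Hypothetical two-coin setup (assumed for illustration):
# one trial of 10 tosses with x = 5 heads, uniform prior over coins.
N_TOSSES = 10
x = 5
prior = {"A": 0.5, "B": 0.5}
theta = {"A": 0.6, "B": 0.5}  # assumed head probabilities p_A, p_B


def log_joint(x, z, theta):
    """log P(x, z; theta) = log P(z) + log Binomial(x; N_TOSSES, p_z)."""
    p = theta[z]
    return (math.log(prior[z])
            + math.log(math.comb(N_TOSSES, x))
            + x * math.log(p) + (N_TOSSES - x) * math.log(1 - p))


def log_likelihood(x, theta):
    """log P(x; theta) = log sum_z P(x, z; theta)."""
    return math.log(sum(math.exp(log_joint(x, z, theta)) for z in prior))


def elbo(q, x, theta):
    """sum_z q(z) * [log P(x, z; theta) - log q(z)], skipping q(z) = 0 terms."""
    return sum(qz * (log_joint(x, z, theta) - math.log(qz))
               for z, qz in q.items() if qz > 0)


# E-step: set q to the posterior P(z | x; theta)
ll = log_likelihood(x, theta)
posterior = {z: math.exp(log_joint(x, z, theta) - ll) for z in prior}

print(round(posterior["A"], 2))                      # 0.45
print(math.isclose(elbo(posterior, x, theta), ll))   # True: the bound is tight

# Any other q gives a strictly smaller ELBO
other_q = {"A": 0.9, "B": 0.1}
print(elbo(other_q, x, theta) < ll)                  # True
```

The gap between $\log P(x;\theta)$ and the ELBO is the KL divergence from $q$ to the posterior, which is why choosing the posterior in the E-step makes the bound tight before the M-step pushes it up.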

Proof for two coins

E-Step

Recall that in the E-step we computed the probability that each trial used each coin, $q^{(1)}(z^{(1)}),\dots,q^{(5)}(z^{(5)})$, where $z^{(i)}\in\{A,B\}$ is the coin tossed in the $i$-th trial:

| | Trial #1 ($i=1$) | Trial #2 ($i=2$) | Trial #3 ($i=3$) | Trial #4 ($i=4$) | Trial #5 ($i=5$) |
|---|---|---|---|---|---|
| Observation: number of heads $x^{(i)}$ | 5 | 9 | 8 | 4 | 7 |
| Hidden posterior $q^{(i)}(z^{(i)}=A)=P(\text{coin A}\mid\text{observation})$ | 0.45 | 0.80 | 0.73 | 0.35 | 0.65 |
| Hidden posterior $q^{(i)}(z^{(i)}=B)=P(\text{coin B}\mid\text{observation})$ | 0.55 | 0.20 | 0.27 | 0.65 | 0.35 |

These values are the posterior probabilities P(trial used a given coin | observation) under the current parameter estimates, which is exactly what the E-step of Algorithm 3 computes.
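As a sketch, the posteriors in the table can be recomputed with Bayes' rule. The initial guesses $p_A=0.6$, $p_B=0.5$ and the uniform prior over coins are assumptions here (they match the classic example in [1] and reproduce the table to two decimals). The binomial coefficient $\binom{10}{x^{(i)}}$ is the same for both coins, so it cancels in the ratio and is omitted.

```python
# Hypothetical initial guesses (assumed; not restated in this section)
P_A, P_B = 0.6, 0.5
N_TOSSES = 10
heads = [5, 9, 8, 4, 7]  # x^(i) for trials 1..5


def e_step(heads, p_a, p_b, n=N_TOSSES):
    """Return [(q(z=A), q(z=B)) for each trial], assuming a uniform prior over coins."""
    posteriors = []
    for x in heads:
        # Binomial likelihoods without the C(n, x) factor, which cancels
        like_a = p_a ** x * (1 - p_a) ** (n - x)
        like_b = p_b ** x * (1 - p_b) ** (n - x)
        a = like_a / (like_a + like_b)
        posteriors.append((a, 1 - a))
    return posteriors


for i, (a, b) in enumerate(e_step(heads, P_A, P_B), start=1):
    print(f"trial {i}: q(A)={a:.2f}  q(B)={b:.2f}")
```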

M-Step

It is slightly trickier to see why the "soft count" solution is equivalent to the M-step in Algorithm 3. To verify this, we expand the log term in the M-step objective as follows (our derivation follows the UMich EECS 545 lecture notes [2], adapted to our notation):

$$
\begin{aligned}
\theta_{t+1}
&= \arg\max_\theta \sum_{i=1}^n \mathbb{E}_{z^{(i)}\sim q^{(i)}} \log P(x^{(i)}, z^{(i)}; \theta)\\
&= \arg\max_\theta \sum_{i=1}^n \mathbb{E}_{z^{(i)}\sim q^{(i)}} \log \left[ P(z^{(i)})\cdot P(x^{(i)}\mid z^{(i)};\theta)\right]\\
&= \arg\max_\theta \sum_{i=1}^n \mathbb{E}_{z^{(i)}\sim q^{(i)}} \Big[\underbrace{\log P(z^{(i)})}_{\substack{\text{constant wrt }\theta\\\text{can be omitted}}}+\log P(x^{(i)}\mid z^{(i)};\theta)\Big]\\
&= \arg\max_\theta \sum_{i=1}^n \mathbb{E}_{z^{(i)}\sim q^{(i)}} \log P(x^{(i)}\mid z^{(i)};\theta)\\
&= \arg\max_\theta \sum_{i=1}^n \mathbb{E}_{z^{(i)}\sim q^{(i)}} \log\left[ \binom{10}{x^{(i)}}\, p_{z^{(i)}}^{x^{(i)}}\,(1-p_{z^{(i)}})^{10-x^{(i)}}\right]\\
&= \arg\max_\theta \sum_{i=1}^n \mathbb{E}_{z^{(i)}\sim q^{(i)}} \Big[ \underbrace{\log \binom{10}{x^{(i)}}}_{\substack{\text{constant wrt }\theta\\\text{can be omitted}}}+x^{(i)}\log p_{z^{(i)}}+(10-x^{(i)})\log (1-p_{z^{(i)}}) \Big]\\
&= \arg\max_\theta \sum_{i=1}^n \mathbb{E}_{z^{(i)}\sim q^{(i)}} \left[x^{(i)}\log p_{z^{(i)}}+(10-x^{(i)})\log (1-p_{z^{(i)}})\right]\\
&= \arg\max_\theta \sum_{i=1}^n \Big\{\textcolor{darkred}{q^{(i)}(z^{(i)}=A)}\left[x^{(i)}\log p_A+(10-x^{(i)})\log (1-p_A)\right]\\
&\qquad\qquad\qquad\quad+\textcolor{green}{q^{(i)}(z^{(i)}=B)}\left[x^{(i)}\log p_B+(10-x^{(i)})\log (1-p_B)\right]\Big\}\\
&= \arg\max_\theta \sum_{i=1}^n \Big\{\textcolor{darkred}{a^{(i)}}\left[x^{(i)}\log p_A+(10-x^{(i)})\log (1-p_A)\right]\\
&\qquad\qquad\qquad\quad+\textcolor{green}{b^{(i)}}\left[x^{(i)}\log p_B+(10-x^{(i)})\log (1-p_B)\right]\Big\}
\end{aligned}
$$

where $a^{(i)}=q^{(i)}(z^{(i)}=A)$ and $b^{(i)}=q^{(i)}(z^{(i)}=B)$ are shorthand for the posteriors in the table above.

Taking the partial derivatives with respect to $\theta=\{p_A, p_B\}$ and setting them to zero, we have:

$$
\begin{aligned}
\frac{\partial\, \mathbb{E}_q[\log P(X,Z;\theta)]}{\partial p_A} &= \frac{\sum_{i=1}^n a^{(i)}x^{(i)}}{p_A}-\frac{\sum_{i=1}^n a^{(i)}(10-x^{(i)})}{1-p_A}=0\\
\frac{\partial\, \mathbb{E}_q[\log P(X,Z;\theta)]}{\partial p_B} &= \frac{\sum_{i=1}^n b^{(i)}x^{(i)}}{p_B}-\frac{\sum_{i=1}^n b^{(i)}(10-x^{(i)})}{1-p_B}=0
\end{aligned}
$$

Solving this, we get:

$$
\begin{aligned}
p_A &= \frac{\sum_{i=1}^n a^{(i)} \cdot x^{(i)}}{\sum_{i=1}^n a^{(i)}\cdot 10}\\
p_B &= \frac{\sum_{i=1}^n b^{(i)} \cdot x^{(i)}}{\sum_{i=1}^n b^{(i)}\cdot 10}
\end{aligned}
$$
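A quick numerical sanity check of this derivation (a sketch, not from the original post): plug the rounded posteriors $a^{(i)}$, $b^{(i)}$ from the table into the weighted objective from the last line of the derivation, and confirm that a brute-force grid search lands on the closed-form $p_A$ and $p_B$. Because the objective splits into a term involving only $p_A$ and a term involving only $p_B$, each coin can be searched separately. The grid step of 0.001 is an arbitrary choice.

```python
import math

N_TOSSES = 10
heads = [5, 9, 8, 4, 7]
a = [0.45, 0.80, 0.73, 0.35, 0.65]  # q(z=A) from the table (rounded)
b = [0.55, 0.20, 0.27, 0.65, 0.35]  # q(z=B) from the table (rounded)


def weighted_objective(p, weights):
    """sum_i w_i * [x_i log p + (N - x_i) log(1 - p)] for a single coin."""
    return sum(w * (x * math.log(p) + (N_TOSSES - x) * math.log(1 - p))
               for w, x in zip(weights, heads))


def closed_form(weights):
    """p = sum_i w_i x_i / (N * sum_i w_i), from setting the derivative to zero."""
    return sum(w * x for w, x in zip(weights, heads)) / (N_TOSSES * sum(weights))


grid = [k / 1000 for k in range(1, 1000)]  # p in (0, 1), step 0.001
for name, w in [("A", a), ("B", b)]:
    best = max(grid, key=lambda p: weighted_objective(p, w))
    print(name, round(closed_form(w), 3), best)
```

With the rounded table values, $p_B$ comes out at about 0.582; the small difference from the 0.581 in the worked example comes from rounding the posteriors to two decimals.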

which is consistent with what we previously did using the “soft count”:

$$
\begin{aligned}
p_A^{(1)} &= \frac{\textcolor{darkred}{0.45}\times 5+\textcolor{darkred}{0.80}\times9+\textcolor{darkred}{0.73}\times8+\textcolor{darkred}{0.35}\times4+\textcolor{darkred}{0.65}\times 7}{(\textcolor{darkred}{0.45+0.80+0.73+0.35+0.65})\times10}\approx 0.713\\[1ex]
p_B^{(1)} &= \frac{\textcolor{green}{0.55}\times5+\textcolor{green}{0.20}\times9+\textcolor{green}{0.27}\times8+\textcolor{green}{0.65}\times4+\textcolor{green}{0.35}\times7}{(\textcolor{green}{0.55+0.20+0.27+0.65+0.35})\times 10}\approx 0.581
\end{aligned}
$$

Therefore, our M-step is also consistent with Algorithm 3.
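Putting the two steps together, here is a minimal sketch of Algorithm 3 for the two-coin model. As before, the initial guesses $p_A=0.6$, $p_B=0.5$ and the fixed uniform prior over coins are assumptions. The sketch also records the log-likelihood $\mathcal{L}(\theta)$ after every iteration, which EM should never decrease.

```python
import math

N_TOSSES = 10
heads = [5, 9, 8, 4, 7]


def log_likelihood(p_a, p_b):
    """L(theta) = sum_i log P(x_i; theta), with a uniform prior over the two coins."""
    total = 0.0
    for x in heads:
        c = math.comb(N_TOSSES, x)
        like_a = c * p_a ** x * (1 - p_a) ** (N_TOSSES - x)
        like_b = c * p_b ** x * (1 - p_b) ** (N_TOSSES - x)
        total += math.log(0.5 * like_a + 0.5 * like_b)
    return total


def em_step(p_a, p_b):
    """One E-step (posteriors) followed by one M-step (soft counts)."""
    num_a = den_a = num_b = den_b = 0.0
    for x in heads:
        like_a = p_a ** x * (1 - p_a) ** (N_TOSSES - x)
        like_b = p_b ** x * (1 - p_b) ** (N_TOSSES - x)
        a = like_a / (like_a + like_b)  # q(z=A); the uniform prior cancels
        b = 1 - a                       # q(z=B)
        num_a += a * x
        den_a += a * N_TOSSES
        num_b += b * x
        den_b += b * N_TOSSES
    return num_a / den_a, num_b / den_b


p_a, p_b = 0.6, 0.5  # assumed initial guesses
history = [log_likelihood(p_a, p_b)]
for t in range(10):
    p_a, p_b = em_step(p_a, p_b)
    history.append(log_likelihood(p_a, p_b))
    if t == 0:
        print(f"after 1 iteration: p_A={p_a:.3f}, p_B={p_b:.3f}")

print(f"after 10 iterations: p_A={p_a:.3f}, p_B={p_b:.3f}")
print(all(h1 >= h0 - 1e-12 for h0, h1 in zip(history, history[1:])))
```

The first iteration reproduces the soft-count numbers above. Where the estimates end up after more iterations can depend on the initial guess, since EM only guarantees convergence to a local optimum.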

Proof for GMM

E-Step

As with the two coins, the hidden posterior was computed as the posterior probability P(student is a given gender | height), which is consistent with Algorithm 3:

| | Student #1 ($i=1$) | Student #2 ($i=2$) | Student #3 ($i=3$) | Student #4 ($i=4$) | Student #5 ($i=5$) | Student #6 ($i=6$) |
|---|---|---|---|---|---|---|
| Observation: height $x^{(i)}$ | 168 | 180 | 170 | 172 | 178 | 176 |
| Hidden posterior $q^{(i)}(z^{(i)}=B)=P(\text{boy}\mid\text{height})$ | 0.26 | 1.00 | 0.50 | 0.74 | 0.99 | 0.96 |
| Hidden posterior $q^{(i)}(z^{(i)}=G)=P(\text{girl}\mid\text{height})$ | 0.74 | 0.00 | 0.50 | 0.26 | 0.01 | 0.04 |
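The same Bayes-rule computation gives the GMM posteriors. This section does not restate the current parameter estimates, so the sketch below uses hypothetical values ($\mu_B=175$, $\mu_G=165$, $\sigma_B=\sigma_G=5$, equal priors) purely to show the mechanics; it is not meant to reproduce the table above.

```python
import math


def normal_pdf(x, mu, sigma):
    """Density of N(mu, sigma^2) evaluated at x."""
    return math.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))


def gmm_e_step(heights, params, prior_boy=0.5):
    """Return [(q(z=B), q(z=G)) for each height]; params maps "B"/"G" to (mu, sigma)."""
    posteriors = []
    for x in heights:
        joint_b = prior_boy * normal_pdf(x, *params["B"])
        joint_g = (1 - prior_boy) * normal_pdf(x, *params["G"])
        b = joint_b / (joint_b + joint_g)
        posteriors.append((b, 1 - b))
    return posteriors


heights = [168, 180, 170, 172, 178, 176]
# Hypothetical current parameters (NOT the ones used to build the table above)
params = {"B": (175.0, 5.0), "G": (165.0, 5.0)}

for x, (b, g) in zip(heights, gmm_e_step(heights, params)):
    print(f"height {x}: P(boy)={b:.2f}  P(girl)={g:.2f}")
```

With equal priors and equal standard deviations, a height exactly halfway between the two means (170 here) gets a posterior of 0.5 for each gender, and taller heights get a higher posterior of being a boy.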

M-Step

Now let’s see why the weighted mean and standard deviation solve the GMM M-step. I’m omitting the first few steps since they are identical to the two-coin example (as there, the prior over genders is not part of $\theta$, so $\log P(z^{(i)})$ drops out):

$$
\begin{aligned}
\theta_{t+1}
&= \arg\max_\theta \sum_{i=1}^n \mathbb{E}_{z^{(i)}\sim q^{(i)}} \log P(x^{(i)}\mid z^{(i)};\theta)\\
&= \arg\max_\theta \sum_{i=1}^n \mathbb{E}_{z^{(i)}\sim q^{(i)}} \log \mathcal{N}\left(x^{(i)}; \mu_{z^{(i)}}, \sigma_{z^{(i)}}^2\right)\\
&= \arg\max_\theta \sum_{i=1}^n \mathbb{E}_{z^{(i)}\sim q^{(i)}} \log \left[ \frac{1}{\sigma_{z^{(i)}}\sqrt{2\pi}} \exp\left( -\frac{(x^{(i)}-\mu_{z^{(i)}})^2}{2\sigma_{z^{(i)}}^2} \right) \right]\\
&= \arg\max_\theta \sum_{i=1}^n \mathbb{E}_{z^{(i)}\sim q^{(i)}} \left[ \log \frac{1}{\sigma_{z^{(i)}}\sqrt{2\pi}} - \frac{(x^{(i)}-\mu_{z^{(i)}})^2}{2\sigma_{z^{(i)}}^2} \right]\\
&= \arg\max_\theta \sum_{i=1}^n \Big[ \textcolor{darkred}{q^{(i)}(z^{(i)}=B)} \left( \log \frac{1}{\sigma_B\sqrt{2\pi}} - \frac{(x^{(i)}-\mu_B)^2}{2\sigma_B^2} \right)\\
&\qquad\qquad\qquad + \textcolor{green}{q^{(i)}(z^{(i)}=G)} \left( \log \frac{1}{\sigma_G\sqrt{2\pi}} - \frac{(x^{(i)}-\mu_G)^2}{2\sigma_G^2} \right) \Big]\\
&= \arg\max_\theta \sum_{i=1}^n \Big[ \textcolor{darkred}{b^{(i)}} \left( \log \frac{1}{\sigma_B\sqrt{2\pi}} - \frac{(x^{(i)}-\mu_B)^2}{2\sigma_B^2} \right) + \textcolor{green}{g^{(i)}} \left( \log \frac{1}{\sigma_G\sqrt{2\pi}} - \frac{(x^{(i)}-\mu_G)^2}{2\sigma_G^2} \right) \Big]
\end{aligned}
$$

where $b^{(i)}=q^{(i)}(z^{(i)}=B)$ and $g^{(i)}=q^{(i)}(z^{(i)}=G)$ are shorthand for the posteriors in the table above.

Taking the partial derivatives with respect to $\theta=\{\mu_B, \mu_G, \sigma_B, \sigma_G\}$ and setting them to zero, we have:

$$
\begin{aligned}
\frac{\partial\, \mathbb{E}_q[\log P(X,Z;\theta)]}{\partial \mu_B} &= \frac{\sum_{i=1}^n b^{(i)}(x^{(i)}-\mu_B)}{\sigma_B^2} = 0\\
\frac{\partial\, \mathbb{E}_q[\log P(X,Z;\theta)]}{\partial \mu_G} &= \frac{\sum_{i=1}^n g^{(i)}(x^{(i)}-\mu_G)}{\sigma_G^2} = 0\\
\frac{\partial\, \mathbb{E}_q[\log P(X,Z;\theta)]}{\partial \sigma_B} &= -\frac{\sum_{i=1}^n b^{(i)}}{\sigma_B} + \frac{\sum_{i=1}^n b^{(i)}(x^{(i)}-\mu_B)^2}{\sigma_B^3} = 0\\
\frac{\partial\, \mathbb{E}_q[\log P(X,Z;\theta)]}{\partial \sigma_G} &= -\frac{\sum_{i=1}^n g^{(i)}}{\sigma_G} + \frac{\sum_{i=1}^n g^{(i)}(x^{(i)}-\mu_G)^2}{\sigma_G^3} = 0
\end{aligned}
$$

Solving these, we get:

$$
\begin{aligned}
\mu_B &= \frac{\sum_{i=1}^n b^{(i)}x^{(i)}}{\sum_{i=1}^n b^{(i)}}\\
\mu_G &= \frac{\sum_{i=1}^n g^{(i)}x^{(i)}}{\sum_{i=1}^n g^{(i)}}\\
\sigma_B^2 &= \frac{\sum_{i=1}^n b^{(i)} (x^{(i)}-\mu_B)^2}{\sum_{i=1}^n b^{(i)}}\\
\sigma_G^2 &= \frac{\sum_{i=1}^n g^{(i)} (x^{(i)}-\mu_G)^2}{\sum_{i=1}^n g^{(i)}}
\end{aligned}
$$

which is consistent with what we previously did using the “weighted” mean and standard deviation:

$$
\begin{aligned}
\mu_B^{(1)} &= \frac{\textcolor{darkred}{0.26}\times168+\textcolor{darkred}{1.00}\times180+\dots+\textcolor{darkred}{0.96}\times176}{\textcolor{darkred}{0.26+1.00+\dots+0.96}} \approx 175.5\\[1ex]
\mu_G^{(1)} &= \frac{\textcolor{green}{0.74}\times168+\textcolor{green}{0.00}\times180+\dots+\textcolor{green}{0.04}\times176}{\textcolor{green}{0.74+0.00+\dots+0.04}} \approx 169.6\\[1ex]
\sigma_B^{(1)} &= \sqrt{\frac{\textcolor{darkred}{0.26}\times(168-175.5)^2+\dots+\textcolor{darkred}{0.96}\times(176-175.5)^2}{\textcolor{darkred}{0.26+1.00+\dots+0.96}}}\approx 3.83\\[1ex]
\sigma_G^{(1)} &= \sqrt{\frac{\textcolor{green}{0.74}\times(168-169.6)^2+\dots+\textcolor{green}{0.04}\times(176-169.6)^2}{\textcolor{green}{0.74+0.00+\dots+0.04}}}\approx 2.04
\end{aligned}
$$

Therefore, our M-step is also consistent with Algorithm 3.
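The weighted updates can be checked with a short sketch that plugs the rounded table posteriors into the closed-form solution. Because those posteriors are rounded to two decimals, the standard deviations it prints (about 3.84 for boys and 1.94 for girls) differ from the 3.83 and 2.04 in the worked example, most noticeably for girls, likely because the small girl posteriors at the tallest heights are rounded to 0.00 and 0.01 while their squared deviations are large. The means agree.

```python
import math

heights = [168, 180, 170, 172, 178, 176]
boy = [0.26, 1.00, 0.50, 0.74, 0.99, 0.96]   # q(z=B) from the table (rounded)
girl = [0.74, 0.00, 0.50, 0.26, 0.01, 0.04]  # q(z=G) from the table (rounded)


def weighted_mean_std(xs, weights):
    """Closed-form M-step for one Gaussian component: weighted mean and std."""
    total = sum(weights)
    mu = sum(w * x for w, x in zip(weights, xs)) / total
    var = sum(w * (x - mu) ** 2 for w, x in zip(weights, xs)) / total
    return mu, math.sqrt(var)


mu_b, sigma_b = weighted_mean_std(heights, boy)
mu_g, sigma_g = weighted_mean_std(heights, girl)
print(f"boys:  mu={mu_b:.1f}, sigma={sigma_b:.2f}")
print(f"girls: mu={mu_g:.1f}, sigma={sigma_g:.2f}")
```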

Closing Thoughts

This wraps up my notes on the EM algorithm. I hope you find them helpful!

As I said at the beginning of Part I, EM is straightforward to understand through examples, yet its underlying math is tricky to grasp fully. That is why I put the examples and their ad-hoc algorithms first, and only then proved that they are consistent with a more general theory. Fun fact: although Part II is math heavy, I found it much easier to write than Part I, which focused on intuition. This doesn't mean the math is easy; it simply means that the intuition behind the math can be extremely hard to put into words.

I first learned the EM algorithm in Stanford CS229: Machine Learning, and later became a course assistant for Stanford CS221: Artificial Intelligence: Principles and Techniques, where I had to explain this algorithm to many students. In CS229, I learned to work through the proof and use EM to solve several math problems, yet failed to develop a good intuition; in CS221, I was finally able to appreciate the beauty of this algorithm through a few "Explain Like I'm Five" (ELI5) examples similar to the two coins, which I believe helped not only me but also quite a few CS221 students.

Even so, I found it quite challenging to explain the intuition in words rather than in math. I found some excellent tutorials online, yet most of them are still overwhelmingly math heavy for beginners, while those with ELI5 examples usually skip the math derivations. So I decided to write my own learning notes, hoping they would be accessible enough for beginners yet rigorous enough for readers seeking math proofs. Feel free to comment below or send me an email if you have any feedback or questions!

References

[1] What is the expectation maximization algorithm? Chuong B. Do, Serafim Batzoglou. Nature Biotechnology, 2008.

[2] Expectation Maximization. Benjamin Bray. UMich EECS 545: Machine Learning course notes, 2016.

[3] The EM algorithm. Tengyu Ma, Andrew Ng. Stanford CS 229: Machine Learning course notes, 2019.

[4] Bayesian networks: EM algorithm. Stanford CS 221: Artificial Intelligence: Principles and Techniques slides, 2021.

[5] How to understand the EM algorithm intuitively? (如何感性地理解EM算法?) 工程师milter. Jianshu (简书) blog post, 2017.

[6] Coin Flipping and EM. Karl Rosaen, chansoo. UMich EECS 545: Machine Learning materials (Jupyter notebook).