Variational Inference

The posterior in Bayesian models with latent variables is quite hard to compute if we do not choose a conjugate prior for the likelihood function. However, we need some approximate version of the posterior in order to make inferences later on. One approach is to find a distribution $q \in Q$ (where $Q$ is a family of distributions) that minimizes the KL divergence to the posterior $p(z \mid X)$:

$$\min_{q \in Q} \mathrm{KL}\big(q(z) \,\|\, p(z \mid X)\big),$$

where $z$ is the latent variable and $X$ is the data. Conveniently, we do not need to worry about the normalization constant $Z = p(X)$:

$$\mathrm{KL}\big(q(z) \,\|\, p(z \mid X)\big) = \int q(z)\,\log \frac{q(z)}{p(X \mid z)\,p(z)/Z}\,dz = \int q(z)\,\log \frac{q(z)}{p(X \mid z)\,p(z)}\,dz + \log Z,$$

since $\int q(z)\,dz = 1$. Because $\log Z$ does not depend on $q$,

$$\min_{q \in Q} \mathrm{KL}\big(q(z) \,\|\, p(z \mid X)\big) = \min_{q \in Q} \mathrm{KL}\big(q(z) \,\|\, p(X \mid z)\,p(z)\big),$$

where $p(X \mid z)\,p(z) = p(X, z)$ is the unnormalized posterior. We only need to work with the likelihood and the prior in this case.
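The cancellation of $\log Z$ can be checked numerically on a toy discrete model. A minimal sketch, with invented numbers for the unnormalized posterior $p(X \mid z)\,p(z)$: the objective against the true posterior and the objective against the unnormalized target differ by exactly the constant $\log Z$, so they share the same minimizer.

```python
import numpy as np

# Toy discrete model: latent z takes 3 values, data X is fixed.
# Hypothetical unnormalized posterior p(X|z) p(z) (numbers are invented).
unnorm = np.array([0.2, 0.5, 0.1])
Z = unnorm.sum()                 # normalization constant Z = p(X)
posterior = unnorm / Z           # true posterior p(z|X)

q = np.array([0.3, 0.4, 0.3])    # some candidate q(z)

kl_posterior = np.sum(q * np.log(q / posterior))   # KL(q || p(z|X))
kl_unnorm = np.sum(q * np.log(q / unnorm))         # "KL" against p(X|z) p(z)

# The two objectives differ only by the constant log Z,
# so minimizing one over q minimizes the other.
print(np.isclose(kl_posterior, kl_unnorm + np.log(Z)))  # True
```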

E-step in EM

The E-step in EM utilizes the variational inference technique. The E-step requires us to calculate the posterior of the latent variables, which can be hard to compute. VI can help us approximate that distribution once we restrict ourselves to a family of distributions $Q$: instead of setting

$$q(z) = p(z \mid x, \theta),$$

we take

$$q(z) = \arg\min_{q \in Q} \mathrm{KL}\big(q(z) \,\|\, p(z \mid x, \theta)\big).$$

This is called Variational EM.
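As a point of contrast, the exact E-step is tractable for a simple mixture model. A minimal sketch with an assumed two-component Gaussian mixture (all parameter values hypothetical): the posterior $p(z_n = k \mid x_n, \theta)$ is just the normalized per-component likelihood. When this posterior is intractable, variational EM would replace it with the closest member of a restricted family $Q$.

```python
import numpy as np

def normal_pdf(x, mu, sigma):
    return np.exp(-(x - mu) ** 2 / (2 * sigma ** 2)) / (sigma * np.sqrt(2 * np.pi))

# Hypothetical two-component mixture parameters theta (invented values).
pis = np.array([0.6, 0.4])       # mixing weights p(z = k)
mus = np.array([-2.0, 3.0])      # component means
sigmas = np.array([1.0, 1.0])    # component standard deviations

x = np.array([-1.5, 2.8, 0.1])   # observed data points

# Exact E-step: q(z_n = k) = p(z_n = k | x_n, theta), tractable here.
lik = normal_pdf(x[:, None], mus[None, :], sigmas[None, :]) * pis[None, :]
resp = lik / lik.sum(axis=1, keepdims=True)   # responsibilities, shape (3, 2)

print(np.allclose(resp.sum(axis=1), 1.0))     # each row is a distribution
```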

Mean Field Approximations

This is a Variational Inference method where we assume the distribution $q$ factorizes over the latent variables across all $d$ dimensions, i.e.,

$$Q = \Big\{ q \;\Big|\; q(z) = \prod_{i=1}^{d} q_i(z_i) \Big\}, \qquad q^*(z) = \arg\min_{q \in Q} \mathrm{KL}\Big(\prod_{i=1}^{d} q_i(z_i) \,\Big\|\, p(z)\Big),$$

and we minimize the KL divergence by coordinate descent: first find the minimum over $q_1$ keeping everything else fixed, then over $q_2$, and so on. We repeat this loop until convergence.
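The coordinate loop can be sketched numerically. This is a minimal illustration, not a definitive implementation: it uses a hypothetical random joint $p(z_1, z_2)$ over two ternary latents, applies the closed-form factor update $q_k \propto \exp(\mathbb{E}_{q_{-k}}[\log p(z)])$ derived in the next section, and checks that the KL divergence never increases across coordinate updates.

```python
import numpy as np

# Hypothetical joint p(z1, z2) over two ternary latents (random, invented).
rng = np.random.default_rng(1)
p = rng.random((3, 3))
p /= p.sum()

q1 = np.ones(3) / 3              # initial factors: uniform
q2 = np.ones(3) / 3

def kl(q1, q2):
    """KL( q1(z1) q2(z2) || p(z1, z2) ) for discrete distributions."""
    q = np.outer(q1, q2)
    return np.sum(q * np.log(q / p))

def update(log_expect):
    """Normalize exp(E[log p]) into a distribution (softmax, stabilized)."""
    t = np.exp(log_expect - log_expect.max())
    return t / t.sum()

kls = [kl(q1, q2)]
for _ in range(20):
    q1 = update(np.log(p) @ q2)      # minimize over q1 with q2 fixed
    q2 = update(np.log(p).T @ q1)    # minimize over q2 with q1 fixed
    kls.append(kl(q1, q2))

# Coordinate descent: KL(q1 q2 || p) is non-increasing at every step.
print(all(b <= a + 1e-12 for a, b in zip(kls, kls[1:])))  # True
```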

Since we are minimizing over one component at a time, the functional form of the optimal $q_k(z_k)$ can be derived as follows (here $dz_{-k}$ denotes integration over all components of $z$ except $z_k$):

$$\begin{aligned}
\min_{q_k} \mathrm{KL}\Big(\prod_{j=1}^{d} q_j(z_j) \,\Big\|\, p(z)\Big)
&= \min_{q_k} \int \Big(\prod_{j=1}^{d} q_j(z_j)\Big) \log \frac{\prod_{j=1}^{d} q_j(z_j)}{p(z)}\,dz \\
&= \min_{q_k} \Bigg\{ \sum_{i=1}^{d} \int \Big(\prod_{j=1}^{d} q_j(z_j)\Big) \log q_i(z_i)\,dz - \int \Big(\prod_{j=1}^{d} q_j(z_j)\Big) \log p(z)\,dz \Bigg\} \\
&= \min_{q_k} \Bigg\{ \int q_k(z_k)\,\log q_k(z_k) \Big( \int \prod_{j \ne k} q_j(z_j)\,dz_{-k} \Big)\,dz_k \\
&\qquad\quad + \sum_{i \ne k} \int q_i(z_i)\,\log q_i(z_i) \Big( \int \prod_{j \ne i} q_j(z_j)\,dz_{-i} \Big)\,dz_i
- \int \Big(\prod_{j=1}^{d} q_j(z_j)\Big) \log p(z)\,dz \Bigg\}.
\end{aligned}$$

Note that

$$\int \prod_{j \ne k} q_j(z_j)\,dz_{-k} = \prod_{j \ne k} \int q_j(z_j)\,dz_j = 1,$$

so the above becomes

$$\begin{aligned}
&\min_{q_k} \Bigg\{ \int q_k(z_k)\,\log q_k(z_k)\,dz_k + \sum_{i \ne k} \int q_i(z_i)\,\log q_i(z_i)\,dz_i
- \int q_k(z_k) \Big( \int \prod_{j \ne k} q_j(z_j)\,\log p(z)\,dz_{-k} \Big)\,dz_k \Bigg\} \\
&= \min_{q_k} \int q_k(z_k) \Bigg[ \log q_k(z_k) - \int \prod_{j \ne k} q_j(z_j)\,\log p(z)\,dz_{-k} \Bigg]\,dz_k,
\end{aligned}$$

because $\sum_{i \ne k} \int q_i(z_i)\,\log q_i(z_i)\,dz_i$ is constant with respect to $q_k$.
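The normalization fact used above — the integral of a product of normalized factors equals the product of their individual integrals, i.e. 1 — can be checked numerically. A sketch with two assumed Gaussian factors on a grid:

```python
import numpy as np

# Grid over a range wide enough to capture both factors.
z = np.linspace(-6, 6, 1201)
dz = z[1] - z[0]

def gauss(mu, s):
    return np.exp(-(z - mu) ** 2 / (2 * s ** 2)) / (s * np.sqrt(2 * np.pi))

# Two hypothetical normalized factors q1(z1), q2(z2) (parameters invented).
q1, q2 = gauss(0.0, 1.0), gauss(1.0, 0.5)

joint = np.outer(q1, q2)                     # factorized q(z1, z2)
total = joint.sum() * dz * dz                # double integral of the product
prod = (q1.sum() * dz) * (q2.sum() * dz)     # product of single integrals

# Both equal 1 up to grid discretization error.
print(np.isclose(total, prod) and np.isclose(total, 1.0, atol=1e-3))  # True
```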

We now make three observations

  1. adding a constant to the minimization objective does not change the minimizer

  2. $\int q_k(z_k)\,dz_k = 1$ because $q_k$ is a probability distribution

  3. $\int \prod_{j \ne k} q_j(z_j)\,\log p(z)\,dz_{-k} = \mathbb{E}_{q_{-k}}[\log p(z)]$, since $\prod_{j \ne k} q_j(z_j)$ is a valid probability distribution. Notice that the expectation integrates over all of $z$ except $z_k$ and is thus a function of $z_k$ alone. The subscript $q_{-k}$ just means that we take the expectation with respect to the product $\prod_{j \ne k} q_j(z_j)$. We convert this expectation to a positive function by taking the exponent, and then obtain a valid probability distribution as

     $$t(z_k) = \frac{\exp\big(\mathbb{E}_{q_{-k}}[\log p(z)]\big)}{\int \exp\big(\mathbb{E}_{q_{-k}}[\log p(z)]\big)\,dz_k}.$$

     The denominator of this expression is a constant, which we can introduce into the objective without any alteration.

Rewriting the minimization objective so far:

$$\begin{aligned}
&\min_{q_k} \Bigg\{ \int q_k(z_k) \Big[ \log q_k(z_k) - \log \exp\big(\mathbb{E}_{q_{-k}}[\log p(z)]\big) \Big]\,dz_k
+ \log\Big( \int \exp\big(\mathbb{E}_{q_{-k}}[\log p(z)]\big)\,dz_k \Big) \int q_k(z_k)\,dz_k \Bigg\} \\
&= \min_{q_k} \int q_k(z_k)\,\log \frac{q_k(z_k)}{\exp\big(\mathbb{E}_{q_{-k}}[\log p(z)]\big) \Big/ \int \exp\big(\mathbb{E}_{q_{-k}}[\log p(z)]\big)\,dz_k}\,dz_k \\
&= \min_{q_k} \int q_k(z_k)\,\log \frac{q_k(z_k)}{t(z_k)}\,dz_k.
\end{aligned}$$

Hence

$$\min_{q_k} \mathrm{KL}\Big(\prod_{j=1}^{d} q_j(z_j) \,\Big\|\, p(z)\Big) = \min_{q_k} \mathrm{KL}\big(q_k(z_k) \,\|\, t(z_k)\big),$$

and we know that KL divergence is minimized when the two distributions coincide, i.e.

$$q_k(z_k) = t(z_k) = \frac{\exp\big(\mathbb{E}_{q_{-k}}[\log p(z)]\big)}{\int \exp\big(\mathbb{E}_{q_{-k}}[\log p(z)]\big)\,dz_k}
\quad\Longleftrightarrow\quad
\log q_k(z_k) = \mathbb{E}_{q_{-k}}[\log p(z)] - \text{const},$$

which is the expectation of the log posterior on $z$ over all factors of $q$ except the component currently being minimized.
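The optimality of this closed-form update can be checked numerically on a toy discrete model. A minimal sketch, assuming a hypothetical joint $p(z_1, z_2)$ over two binary latents (numbers invented): the optimal $q_1$ is the normalized exponent of $\mathbb{E}_{q_2}[\log p]$, and no randomly sampled alternative factor achieves a lower KL.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical joint p(z1, z2) over two binary latents (invented numbers).
p = np.array([[0.30, 0.20],
              [0.15, 0.35]])

q2 = np.array([0.5, 0.5])        # current factor for z2, held fixed

# Optimal q1: log q1(z1) = E_{q2}[log p(z1, z2)] - const
log_t = (np.log(p) * q2[None, :]).sum(axis=1)   # E_{-1}[log p], a function of z1
q1 = np.exp(log_t - log_t.max())
q1 /= q1.sum()                                   # normalize: q1(z1) = t(z1)

def kl(q1, q2, p):
    """KL( q1(z1) q2(z2) || p(z1, z2) ) for discrete distributions."""
    q = np.outer(q1, q2)
    return np.sum(q * np.log(q / p))

best = kl(q1, q2, p)
# No alternative q1 from a random sample beats the closed-form update.
others = rng.dirichlet(np.ones(2), size=200)
print(all(kl(alt, q2, p) >= best - 1e-12 for alt in others))  # True
```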