Representation Learning with Latent Variable Models

In this post, I discuss what makes the posterior distribution of a directed latent variable model a useful representation, and raise some questions that deserve careful study in the case of undirected latent variable models.

As Ferenc Huszár discussed in his blog post, training a latent variable model (LVM) via maximum likelihood estimation (MLE) does not necessarily lead to a “useful” representation. In fact, in the limit of an infinitely expressive model family, learning a useful representation of the data and achieving a high likelihood (or generating high-quality samples) are orthogonal goals. For VAEs, what makes the latent representation useful is the auto-encoder model structure together with the ELBO maximization objective (Tschannen et al., 2018). What makes a representation useful in the case of undirected LVMs, however, is still mysterious to me.

Uninformative latent variables in LVMs

In the literature, it is reported that in the context of VAEs, a powerful decoder $p_\theta(x\mid z)$ such as a PixelCNN can generate high-quality samples while the mutual information $\mathbb{I}(x;z)$ between the generated sample $x$ and the conditioning latent variable $z$ is very low. This is fatal to representation learning, since the representation then contains almost no information about the original data. The effect is intuitive if you consider an LVM $p_\theta(x, z) = p_\theta(x\mid z)p_\theta(z)$: if the decoder $p_\theta(x\mid z)$ remains a powerful density estimator of $p_{\text{data}}(x)$ (e.g. an EBM) even after the weights connecting it to the latent variable $z$ are set to $0$, the LVM implicitly becomes $p_\theta(x, z) = p_\theta(x)p_\theta(z)$, which means independence and zero mutual information.
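To make the factorization argument concrete, here is a minimal numerical check on a toy discrete example I made up for illustration: when the conditionals $p_\theta(x\mid z)$ do not depend on $z$, the joint factorizes and the mutual information is exactly zero.

```python
import numpy as np

def mutual_information(p_xz):
    """I(x; z) in nats for a discrete joint distribution p_xz[x, z]."""
    p_x = p_xz.sum(axis=1, keepdims=True)
    p_z = p_xz.sum(axis=0, keepdims=True)
    mask = p_xz > 0
    return np.sum(p_xz[mask] * np.log(p_xz[mask] / (p_x @ p_z)[mask]))

p_z = np.array([0.5, 0.5])

# A "decoder" that ignores z: every column of p(x|z) is the same distribution over x.
decoder_ignoring_z = np.array([[0.7, 0.7],
                               [0.3, 0.3]])
# A decoder that actually uses z.
decoder_using_z = np.array([[0.9, 0.2],
                            [0.1, 0.8]])

for name, dec in [("ignores z", decoder_ignoring_z), ("uses z", decoder_using_z)]:
    joint = dec * p_z                       # p(x, z) = p(x | z) p(z)
    print(name, "I(x;z) =", mutual_information(joint))
# "ignores z" prints I(x;z) = 0.0: the joint factorizes as p(x) p(z).
```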

I also came up with another way to (approximately) explain the above effect. Take the latent variable model to be a mixture distribution whose marginal $p_\theta(x) = \int p_\theta(x\mid z)p_\theta(z)\,\mathrm{d}z$ is used to model $p_{\text{data}}(x)$. Consider the learning process as minimizing the reverse KL divergence $\operatorname{KL}\left(p_\theta(x) \parallel p_{\text{data}}(x)\right)$ (MLE actually minimizes the forward KL divergence $\operatorname{KL}\left(p_{\text{data}}(x) \parallel p_\theta(x)\right)$, but conceptually the two have a similar effect on the model):

$$
\begin{aligned}
\operatorname{KL}\left(p_\theta(x) \parallel p_{\text{data}}(x)\right)
&= \int p_\theta(z) \int p_\theta(x\mid z) \log \frac{p_\theta(x)}{p_\text{data}(x)}\,\mathrm{d}x\,\mathrm{d}z \\
&= \int p_\theta(z) \int p_\theta(x\mid z) \log \frac{p_\theta(x\mid z)\,p_\theta(x)}{p_\text{data}(x)\,p_\theta(x\mid z)}\,\mathrm{d}x\,\mathrm{d}z \\
&= \int p_\theta(z) \int p_\theta(x\mid z) \log \frac{p_\theta(x\mid z)}{p_\text{data}(x)}\,\mathrm{d}x\,\mathrm{d}z - \int p_\theta(z) \int p_\theta(x\mid z) \log \frac{p_\theta(x\mid z)}{p_\theta(x)}\,\mathrm{d}x\,\mathrm{d}z \\
&= \mathbb{E}_{p_\theta(z)}\left[ \operatorname{KL}\left(p_\theta(x\mid z) \parallel p_{\text{data}}(x)\right) \right] - \mathbb{I}\left( x;z \right) \\
&\geq 0.
\end{aligned}
$$

This gives an upper bound on the mutual information between $x$ and $z$:

$$
\mathbb{I}(x;z) \leq \mathbb{E}_{p_\theta(z)}\left[ \operatorname{KL}\left(p_\theta(x\mid z) \parallel p_{\text{data}}(x)\right) \right].
$$

Thus, when the decoder $p_\theta(x\mid z)$ can by itself model $p_\text{data}(x)$ well in the sense of this averaged KL divergence, little information about $x$ is contained in $z$.
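Here is a quick numerical sanity check of this bound on a small discrete toy model (all distributions below are random and made up purely for illustration):

```python
import numpy as np

def kl(p, q):
    """KL(p || q) in nats for discrete distributions."""
    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

rng = np.random.default_rng(0)
n_x, n_z = 6, 3

p_data = rng.dirichlet(np.ones(n_x))                   # "data" distribution
p_z = rng.dirichlet(np.ones(n_z))                      # prior p_theta(z)
p_x_given_z = rng.dirichlet(np.ones(n_x), size=n_z).T  # decoder, columns indexed by z

p_x = (p_x_given_z * p_z).sum(axis=1)                  # marginal p_theta(x)

avg_kl = sum(p_z[k] * kl(p_x_given_z[:, k], p_data) for k in range(n_z))
mi = sum(p_z[k] * kl(p_x_given_z[:, k], p_x) for k in range(n_z))  # I(x; z) under the model

print("KL(p_theta(x) || p_data)        =", kl(p_x, p_data))
print("E_z[KL(p_theta(x|z) || p_data)] =", avg_kl)
print("I(x; z)                         =", mi)
# The gap avg_kl - mi equals KL(p_theta(x) || p_data) >= 0, so the bound holds.
```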

Information theoretical effects of maximizing ELBO

When training a VAE, since the marginal likelihood is intractable for direct MLE, we instead maximize the evidence lower bound (ELBO):

$$
\begin{aligned}
\mathbb{E}_{p_{\text{data}}(x)}\left[ \mathbb{E}_{q_\phi(z\mid x)}\left[ \log \frac{p_\theta(x,z)}{q_\phi(z\mid x)} \right] \right]
&\approx \underbrace{\frac{1}{N}\sum_{n=1}^N \mathbb{E}_{q_\phi(z_n \mid x_n)}\left[ \log {p_\theta(x_n\mid z_n)} \right]}_{\text{A}} - \underbrace{\frac{1}{N}\sum_{n=1}^N\operatorname{KL}\left( q_\phi(z_n \mid x_n) \parallel p_\theta(z_n)\right)}_{\text{B}}.
\end{aligned}
$$

Term A is the reconstruction term, which tends to increase the mutual information between $x$ and $z$, while term B serves as a regularizer that encourages a disentangled posterior.
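In practice, the two terms are usually estimated per batch with the reparameterization trick. Below is a minimal sketch, assuming a diagonal-Gaussian encoder, a Bernoulli decoder, and a standard-normal prior; `encoder` and `decoder` are placeholder callables, not a specific architecture:

```python
import torch
import torch.nn.functional as F

def elbo_terms(x, encoder, decoder):
    """One-sample Monte Carlo estimate of term A and term B for a batch x.

    Assumes: encoder(x) -> (mu, logvar) of a diagonal Gaussian q_phi(z|x),
    decoder(z) -> Bernoulli logits over x, prior p_theta(z) = N(0, I).
    """
    mu, logvar = encoder(x)
    std = torch.exp(0.5 * logvar)
    z = mu + std * torch.randn_like(std)          # reparameterization trick

    # Term A: E_q[log p_theta(x|z)], here a Bernoulli log-likelihood.
    logits = decoder(z)
    term_A = -F.binary_cross_entropy_with_logits(logits, x, reduction="none").sum(dim=-1)

    # Term B: KL(q_phi(z|x) || N(0, I)), closed form for diagonal Gaussians.
    term_B = 0.5 * (mu.pow(2) + logvar.exp() - 1.0 - logvar).sum(dim=-1)

    elbo = term_A - term_B
    return elbo.mean(), term_A.mean(), term_B.mean()
```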

Let’s examine term B further. First, define the aggregated posterior as $q_\phi(z) = \int q_\phi(z\mid x)p_{\text{data}}(x)\,\mathrm{d}x$. Then we have:

$$
\begin{aligned}
\operatorname{KL}\left( q_\phi(z) \parallel p_\theta(z) \right)
&= \int\left(\int q_\phi(z\mid x) p_{\text{data}}(x)\,\mathrm{d}x \right)\log \frac{q_\phi(z)}{p_\theta(z)}\,\mathrm{d}z \\
&= \iint q_\phi(z\mid x)p_{\text{data}}(x)\log \frac{q_\phi(z)\,q_\phi(z\mid x)}{p_\theta(z)\,q_\phi(z\mid x)}\,\mathrm{d}x\,\mathrm{d}z \\
&= \iint q_\phi(z\mid x)p_{\text{data}}(x)\log\frac{q_\phi(z\mid x)}{p_\theta(z)}\,\mathrm{d}x\,\mathrm{d}z - \iint q_\phi(z\mid x)p_{\text{data}}(x)\log \frac{q_\phi(z\mid x)p_{\text{data}}(x)}{q_\phi(z)p_{\text{data}}(x)}\,\mathrm{d}x\,\mathrm{d}z \\
&= \underbrace{\mathbb{E}_{p_\text{data}(x)}\left[ \operatorname{KL}\left( q_\phi(z\mid x)\parallel p_\theta(z) \right) \right]}_{\text{B}} - \mathbb{I}(x;z),
\end{aligned}
$$

where the mutual information is now taken under the joint $q_\phi(z\mid x)p_{\text{data}}(x)$. Rearranging, the regularizer on the posterior distribution becomes

$$
\text{B} = \operatorname{KL}\left( q_\phi(z) \parallel p_\theta(z) \right) + \mathbb{I}(x;z),
$$

so it penalizes the model for high mutual information between $x$ and $z$. Note that without this penalty, the model would try to encode everything about $x$ into $z$ to achieve low reconstruction error, which does not by itself lead to a useful representation.
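This decomposition is easy to verify numerically on a discrete toy problem (again, every table below is made up just for the check):

```python
import numpy as np

def kl(p, q):
    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

rng = np.random.default_rng(1)
n_x, n_z = 5, 4

p_data = rng.dirichlet(np.ones(n_x))                  # data distribution over x
prior = rng.dirichlet(np.ones(n_z))                   # p_theta(z)
q_z_given_x = rng.dirichlet(np.ones(n_z), size=n_x)   # encoder, rows indexed by x

q_z = p_data @ q_z_given_x                            # aggregated posterior q_phi(z)

term_B = sum(p_data[i] * kl(q_z_given_x[i], prior) for i in range(n_x))
mi = sum(p_data[i] * kl(q_z_given_x[i], q_z) for i in range(n_x))   # I(x; z)

# B = KL(q_phi(z) || p_theta(z)) + I(x; z): the two numbers below agree.
print(term_B, kl(q_z, prior) + mi)
```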

In (Alemi et al., 2018), the authors analyze the ELBO through rate-distortion theory, where distortion $D$ and rate $R$ are defined as:

$$
\begin{aligned}
D &= -\mathbb{E}_{p_{\text{data}}(x)}\left[ \mathbb{E}_{q_\phi(z\mid x)}\left[ \log p_\theta(x\mid z) \right] \right] = -\text{A}, \\
R &= \mathbb{E}_{p_{\text{data}}(x)}\left[ \operatorname{KL}\left( q_\phi(z\mid x) \parallel p_\theta(z) \right) \right] = \text{B}, \\
\operatorname{ELBO} &= -(D + R) = \text{A} - \text{B}.
\end{aligned}
$$

Note that when we train a VAE by maximizing the ELBO, we are minimizing the sum of $D$ and $R$. As shown in the paper, when the decoder $p_\theta(x\mid z)$ is powerful enough, the $R$ term (which equals $\text{B}$) can be pushed to $0$; this forces $\mathbb{I}(x;z) = 0$, because $\mathbb{I}(x;z) = \text{B} - \operatorname{KL}\left( q_\phi(z) \parallel p_\theta(z) \right)$ and both the KL divergence and the mutual information are non-negative. Adding a weighting parameter $\beta$ between $D$ and $R$ leads to a tradeoff:

$$
\begin{aligned}
-(D + \beta R) = {\frac{1}{N}\sum_{n=1}^N \mathbb{E}_{q_\phi(z_n \mid x_n)}\left[ \log {p_\theta(x_n\mid z_n)} \right]} - {\frac{1}{N}\sum_{n=1}^N\beta\operatorname{KL}\left( q_\phi(z_n \mid x_n) \parallel p_\theta(z_n)\right)},
\end{aligned}
$$

which is exactly the $\beta$-VAE objective (Higgins et al., 2017).
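Assuming the hypothetical `elbo_terms` helper sketched above is in scope, the $\beta$-weighted objective is a one-line change:

```python
def beta_vae_loss(x, encoder, decoder, beta=4.0):
    """Negative beta-VAE objective, i.e. D + beta * R, to be minimized.

    beta > 1 trades reconstruction quality (distortion D) for a lower rate R;
    beta = 1 recovers the ordinary negative ELBO.
    """
    _, term_A, term_B = elbo_terms(x, encoder, decoder)
    return -(term_A - beta * term_B)
```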

Open questions: What about undirected LVMs?

The original hope of generative representation learning is that if we can generate all the data we have seen, then we may implicitly learn a representation that can be used to answer any question about the data. And in the sense of pure generative modeling, both directed and undirected latent variable models can perform well.

However, undirected LVMs (i.e. energy-based latent variable models, EBLVMs) are not trained by maximizing the ELBO, so the theory above cannot be used to analyze their representational properties. It is interesting to study the following questions:

  1. What makes the learned representation in EBLVMs useful, the model or the learning algorithm? In the discussion of directed LVMs above, we can say that it is the auto-encoder structure, accompanied by the ELBO maximization objective, that makes the representation useful. Does a bipartite graphical structure (as in RBMs) play a role similar to that of the auto-encoder structure in VAEs?
  2. In (Wu et al., 2021), (Liao et al., 2022), and (Lee et al., 2023), the latent variables of the EBLVMs can be marginalized out and the marginal energy functions are analytically available. This allows these models to be trained as ordinary EBMs with no latent variables. When the training procedure has no explicit connection with $z$, the representational usefulness of $p_\theta(z\mid x)$ depends heavily on the extra constraints provided by the model structure. It is reported in (Wu et al., 2021) that posterior collapse may happen in their conjugate energy-based model and the mutual information between $x$ and the encoded $z$ is low, while in (Liao et al., 2022), when the GRBM is trained with the marginal energy, it tends to map $x$ to an almost deterministic latent code $z$ that preserves high mutual information (a sketch of such a marginal energy is given after this list). So different designs of the joint energy function lead to different levels of coupling between $x$ and $z$; which part of an EBLVM is essential for that kind of coupling? Does the problem have any connection with latent variable identifiability (Wang et al., 2021)?
  3. High mutual information between $x$ and $z$ does not necessarily lead to a good representation. Can we establish a rate-distortion theory for undirected LVMs similar to that of (Alemi et al., 2018) and (Tschannen et al., 2018)?
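As a concrete illustration of question 2, here is a minimal sketch of an analytically marginalized energy for a Gaussian-Bernoulli RBM, assuming one common parameterization of the joint energy (conventions differ across papers, so treat this as illustrative rather than as the exact model of Liao et al., 2022). The free energy can be trained like any latent-free EBM, while the closed-form posterior $p_\theta(h\mid x)$ serves as the representation:

```python
import numpy as np

def softplus(a):
    return np.logaddexp(0.0, a)

def grbm_free_energy(x, W, b, c, log_sigma2):
    """Marginal (free) energy F(x) = -log sum_h exp(-E(x, h)) of a Gaussian-Bernoulli RBM.

    Assumed parameterization of the joint energy:
        E(x, h) = sum_i (x_i - b_i)^2 / (2 sigma_i^2)
                  - sum_j c_j h_j - sum_{ij} (x_i / sigma_i^2) W_ij h_j,
    with the binary hidden units h summed out analytically.
    """
    sigma2 = np.exp(log_sigma2)
    quadratic = 0.5 * np.sum((x - b) ** 2 / sigma2, axis=-1)
    pre_activation = c + (x / sigma2) @ W          # shape (..., n_hidden)
    return quadratic - np.sum(softplus(pre_activation), axis=-1)

def grbm_posterior(x, W, c, log_sigma2):
    """p(h_j = 1 | x): the model's 'encoder', available in closed form."""
    sigma2 = np.exp(log_sigma2)
    return 1.0 / (1.0 + np.exp(-(c + (x / sigma2) @ W)))
```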
References

  1. Tschannen, M., Bachem, O., & Lucic, M. (2018). Recent advances in autoencoder-based representation learning. arXiv preprint arXiv:1812.05069.
  2. Alemi, A., Poole, B., Fischer, I., Dillon, J., Saurous, R. A., & Murphy, K. (2018). Fixing a broken ELBO. International Conference on Machine Learning, 159–168.
  3. Higgins, I., Matthey, L., Pal, A., Burgess, C., Glorot, X., Botvinick, M., Mohamed, S., & Lerchner, A. (2017). beta-VAE: Learning basic visual concepts with a constrained variational framework. International Conference on Learning Representations.
  4. Wu, H., Esmaeili, B., Wick, M., Tristan, J.-B., & Van De Meent, J.-W. (2021). Conjugate energy-based models. International Conference on Machine Learning, 11228–11239.
  5. Liao, R., Kornblith, S., Ren, M., Fleet, D. J., & Hinton, G. (2022). Gaussian-Bernoulli RBMs without tears. arXiv preprint arXiv:2210.10318.
  6. Lee, H., Jeong, J., Park, S., & Shin, J. (2023). Guiding energy-based models via contrastive latent variables. arXiv preprint arXiv:2303.03023.
  7. Wang, Y., Blei, D., & Cunningham, J. P. (2021). Posterior collapse and latent variable non-identifiability. Advances in Neural Information Processing Systems, 34, 5443–5455.