背景

VAE 是一个隐变量模型(LVM),其核心包含两个组件

$$ x \xrightarrow[q_{\phi}(z|x)]{\text{Encoder}} z \xrightarrow[p_{\theta}(x|z)]{\text{Decoder}} \hat{x} $$

我们可以使用最大似然估计来训练这个模型,似然被定义为:

$$ \log p_{\theta}(x) =

\underbrace{\mathbb{E}_{z \sim p_{\theta}(z|x)}[\log p_{\theta}(x\mid z)]}_{\text{Reconstruction } \mathcal{L}_{\text{rec}}} - \underbrace{\mathbb{D}_{KL}(p_{\theta}(z|x), \Vert, p(z))}_{\text{Regularisation } \mathcal{L}_{\text{reg}}} $$

我们可以对重建损失 $\mathcal{L}_{\text{rec}}$ 作出概率假设,从而进行下一步推演。需要注意的是,我们总是在最大化似然,也就是最大化 $\log p_{\theta}(x)$ 也就是 $\mathcal{L}_{\text{rec}}$,最小化 $\mathcal{L}_{\text{reg}}$。

通常来说,我们会对其进行两种概率假设:正态分布和伯努利分布。

正态分布假设

即,我们认为 $p(x\mid z) \sim \mathcal{N}(\mu, \sigma^2I)$,$p(z) \sim \mathcal{N}(0, I)$。因此对其取对数,得到:

$$ \begin{aligned} \log p_{\theta}(x) &= \mathbb{E}_{z \sim p_{\theta}(z|x)}[\log p_{\theta}(x\mid z)] \\ &= \mathbb{E}_{z \sim p_{\theta}(z|x)}\left[\log \mathcal{N}(x; \mu, \sigma^2I)\right] \\ &= \mathbb{E}_{z \sim p_{\theta}(z|x)}\left[\log \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)\right] \\ &= \mathbb{E}_{z \sim p_{\theta}(z|x)}\left[-\frac{(x-\mu)^2}{2\sigma^2} - \frac{1}{2}\log(2\pi\sigma^2)\right] \\ &\approx \sum_{x\in \mathcal{X}} \left[-\frac{(x-\mu)^2}{2\sigma^2} - \frac{1}{2}\log(2\pi\sigma^2)\right] \\

&= \sum_{x\in \mathcal{X}} \left[-\frac{1}{2\sigma^2} \Vert x-\mu\Vert^2- \frac{1}{2}\log(2\pi\sigma^2)\right] \\

&= \sum_{x\in \mathcal{X}} \left[-\frac{1}{2\sigma^2} \Vert \hat{x}-x\Vert^2- \frac{1}{2}\log(2\pi\sigma^2)\right] \\

&= \sum_{x\in \mathcal{X}} \left[-\frac{1}{2\sigma^2} \text{MSE}(\hat{x}, x)- \frac{1}{2}\log(2\pi\sigma^2)\right] \\ &= \sum_{x\in \mathcal{X}} \left[-\frac{1}{2\sigma^2} \text{MSE}(\hat{x}, x)\right] \end{aligned} $$

伯努利分布假设

即,我们认为 $p(x\mid z) = \prod_{i=1}^D \text{Bernoulli}(x_i; \hat{x}_i)$。

$$ \begin{aligned} p_\theta(x\mid z) &= \prod_{i=1}^D \text{Bernoulli}(x_i; \hat{x}_i) \\ &= \prod^D_{i=1} \hat{x}_i^{x_i} (1-\hat{x}_i)^{1-x_i} \end{aligned} $$

对其取对数,得到:

$$ \begin{aligned} \log p_{\theta}(x) &= \mathbb{E}_{z \sim p_{\theta}(z|x)}[\log p_{\theta}(x\mid z)] \\ &= \log \prod^D_{i=1} \left[\hat{x}_i^{x_i} (1-\hat{x}_i)^{1-x_i}\right] \\ &= \sum_{i=1}^D\left[ x_i \log \hat{x}_i + (1-x_i) \log (1-\hat{x}_i)\right] \\ &= \sum_{i=1}^D\left[ -\text{BCE}(x_i, \hat{x}_i)\right] \\ &= -\text{BCE}(x, \hat{x}) \end{aligned} $$

损失函数

因此可以看出对于假设正态分布的模型,其损失函数为:

$$ \mathcal{L}_{\text{rec}} = -\text{MSE}(\hat{x}, x) $$

对于假设伯努利分布的模型,其损失函数为:

$$ \mathcal{L}_{\text{rec}} = - \text{BCE}(\hat{x}, x) $$