Paper Reading: Analytic-DPM

正文索引 [隐藏]

🚧

Background

在之前的DDPM, DDIM的模型中, Backward Progress的方差都是固定的或者是有规律的数字, 本身没有被考虑进动态的Backward Progress中, 这篇文章用很深厚的功力告诉我们仅仅依赖于和之前一样的KL散度的Loss设计, 我们是可以推出Backward Progress的分析上的最优的条件期望和方差的, 接下来这篇博客将会直接进行一个论文的抄.

Basic Knowledge

首先我们关注DDIM文章中对扩散模型的一个扩充的形式

$$
\begin{aligned}
&q_{\lambda}\left(\boldsymbol{x}_{1: N} \mid \boldsymbol{x}_{0}\right)=q_{\lambda}\left(\boldsymbol{x}_{N} \mid \boldsymbol{x}_{0}\right) \prod_{n=2}^{N} q_{\lambda}\left(\boldsymbol{x}_{n-1} \mid \boldsymbol{x}_{n}, \boldsymbol{x}_{0}\right) \\
&q_{\lambda}\left(\boldsymbol{x}_{N} \mid \boldsymbol{x}_{0}\right)=\mathcal{N}\left(\boldsymbol{x}_{N} \mid \sqrt{\bar{\alpha}} \boldsymbol{x}_{0}, \bar{\beta}_{N} \boldsymbol{I}\right) \\
&q_{\lambda}\left(\boldsymbol{x}_{n-1} \mid \boldsymbol{x}_{n}, \boldsymbol{x}_{0}\right)=\mathcal{N}\left(\boldsymbol{x}_{n-1} \mid \tilde{\boldsymbol{\mu}}_{n}\left(\boldsymbol{x}_{n}, \boldsymbol{x}_{0}\right), \lambda_{n}^{2} \boldsymbol{I}\right) \\
&\tilde{\boldsymbol{\mu}}_{n}\left(\boldsymbol{x}_{n}, \boldsymbol{x}_{0}\right)=\sqrt{\bar{\alpha}_{n-1}} \boldsymbol{x}_{0}+\sqrt{\bar{\beta}_{n-1}-\lambda_{n}^{2}} \cdot \frac{\boldsymbol{x}_{n}-\sqrt{\bar{\alpha}_{n}} \boldsymbol{x}_{0}}{\sqrt{\bar{\beta}_{n}}}
\end{aligned}
$$

并且上述形式是基于保证如下的分布形式所设计出来的

$$
q_{\lambda}\left(\boldsymbol{x}_{n} \mid \boldsymbol{x}_{0}\right)=\mathcal{N}\left(\boldsymbol{x}_{n} \mid \sqrt{\bar{\alpha}_{n}} \boldsymbol{x}_{0}, \bar{\beta}_{n} \boldsymbol{I}\right)
$$

其中$ \bar{\alpha}_{n}:=\prod_{i=1}^{n} \alpha_{i}$, 并且$\bar{\beta}_{n}:=1-\bar{\alpha}_{n}$.
熟悉相关数学形式的同学也很容易知道DDPM是上述形式中$\lambda_{n}^{2}=\tilde{\beta}_{n}$,  $\tilde{\beta}_{n}:=\frac{\bar{\beta}_{n-1}}{\bar{\beta}_{n}} \beta_{n}$的特例, 此时Forward Progress可以是马尔科夫的. 而当$\lambda_{n}^{2}=0$时, 就是DDIM模型

虽然正向过程未必是马尔科夫过程了, 但我们依旧假设我们的Reverse Progress是Markov的(或者说我们假设我们能够拿到一个不准的$x_0$), 我们建了一个神经网络去学习这样的Reverse Progress, 从终态的标准高斯分布$p\left(\boldsymbol{x}_{N}\right)=\mathcal{N}\left(\boldsymbol{x}_{N} \mid \mathbf{0}, \boldsymbol{I}\right)$ 出发:

$$p\left(\boldsymbol{x}_{0: N}\right)=p\left(\boldsymbol{x}_{N}\right) \prod_{n=1}^{N} p\left(\boldsymbol{x}_{n-1} \mid \boldsymbol{x}_{n}\right), \quad p\left(\boldsymbol{x}_{n-1} \mid \boldsymbol{x}_{n}\right)=\mathcal{N}\left(\boldsymbol{x}_{n-1} \mid \boldsymbol{\mu}_{n}\left(\boldsymbol{x}_{n}\right), \sigma_{n}^{2} \boldsymbol{I}\right)$$

在之前的理论中, 我们仅考虑神经网络去拟合均值, 用一个预测噪声的网络, 或者说是Score-based model $\boldsymbol{s}_n(\boldsymbol{x}_n)$来表示这样的均值:

$$
\boldsymbol{\mu}_{n}\left(\boldsymbol{x}_{n}\right)=\tilde{\boldsymbol{\mu}}_{n}\left(\boldsymbol{x}_{n}, \frac{1}{\sqrt{\bar{\alpha}_{n}}}\left(\boldsymbol{x}_{n}+\bar{\beta}_{n} \boldsymbol{s}_{n}\left(\boldsymbol{x}_{n}\right)\right)\right)
$$

训练的Loss由ELBo表示, 经过简单的推到有如下的形式:

$$
L_{\mathrm{vb}}=\mathbb{E}_{q}\left[-\log p\left(\boldsymbol{x}_{0} \mid \boldsymbol{x}_{1}\right)+\sum_{n=2}^{N} D_{\mathrm{KL}}\left(q\left(\boldsymbol{x}_{n-1} \mid \boldsymbol{x}_{0}, \boldsymbol{x}_{n}\right) \| p\left(\boldsymbol{x}_{n-1} \mid \boldsymbol{x}_{n}\right)\right)+D_{\mathrm{KL}}\left(q\left(\boldsymbol{x}_{N} \mid \boldsymbol{x}_{0}\right) \| p\left(\boldsymbol{x}_{N}\right)\right)\right]
$$

而Yang Song等人实际上是直接对比的Forward Progress和Reverse Progress联合分布的KL散度, 事实上这两者是等价的:

$$
\min _{\left\{\boldsymbol{\mu}_{n}, \sigma_{n}^{2}\right\}_{n=1}^{N}} L_{\mathrm{vb}} \Leftrightarrow \min _{\left\{\boldsymbol{\mu}_{n}, \sigma_{n}^{2}\right\}_{n=1}^{N}} D_{\mathrm{KL}}\left(q\left(\boldsymbol{x}_{0: N}\right) \| p\left(\boldsymbol{x}_{0: N}\right)\right)
$$

而在实际操作中, Ho 等人发现扔掉系数效果更好, 有了采样的算法, 这里形式化如下:

$$
\min _{\left\{\boldsymbol{s}_{n}\right\}_{n=1}^{N}} \mathbb{E}_{n} \bar{\beta}_{n} \mathbb{E}_{q_{n}\left(\boldsymbol{x}_{n}\right)}\left\|\boldsymbol{s}_{n}\left(\boldsymbol{x}_{n}\right)-\nabla_{\boldsymbol{x}_{n}} \log q_{n}\left(\boldsymbol{x}_{n}\right)\right\|^{2}=\mathbb{E}_{n, \boldsymbol{x}_{0}, \boldsymbol{\epsilon}}\left\|\boldsymbol{\epsilon}+\sqrt{\bar{\beta}_{n}} \boldsymbol{s}_{n}\left(\boldsymbol{x}_{n}\right)\right\|^{2}+c
$$

其中n是1到N的均匀分布, $\boldsymbol{x}_{n}=\sqrt{\bar{\alpha}_{n}} \boldsymbol{x}_{0}+\sqrt{\bar{\beta}_{n}} \boldsymbol{\epsilon}$, c是常数, Yang Song 等人发现他们简化后的事实上每个$\boldsymbol{s}_n$是跟他们提出的基于Langevin动力学的对应步的最优解一致, 也就是$\boldsymbol{s}^*_n(\boldsymbol{x}_n) =\nabla_{\boldsymbol{x}_{n}} \log q_{n}\left(\boldsymbol{x}_{n}\right) $

Theory

这篇文章claim说之前只优化均值只能算是一个次优方案, 提出了一个分析上的最优解:

定理 1: 相同的ELBoLoss, 上述扩充形式的最优解如下:

$$
\begin{aligned}
&\boldsymbol{\mu}_{n}^{*}\left(\boldsymbol{x}_{n}\right)=\tilde{\boldsymbol{\mu}}_{n}\left(\boldsymbol{x}_{n}, \frac{1}{\sqrt{\bar{\alpha}_{n}}}\left(\boldsymbol{x}_{n}+\bar{\beta}_{n} \nabla_{\boldsymbol{x}_{n}} \log q_{n}\left(\boldsymbol{x}_{n}\right)\right)\right) \\
&\sigma_{n}^{* 2}=\lambda_{n}^{2}+\left(\sqrt{\frac{\bar{\beta}_{n}}{\alpha_{n}}}-\sqrt{\bar{\beta}_{n-1}-\lambda_{n}^{2}}\right)^{2}\left(1-\bar{\beta}_{n} \mathbb{E}_{q_{n}\left(\boldsymbol{x}_{n}\right)} \frac{\|\nabla_{\boldsymbol{x}_{n}} \log q_{n}\left(\boldsymbol{x}_{n}\right)\|^{2}}{d}\right)
\end{aligned}
$$

其中上式的$ \mathbb{E}_{q_{n}\left(\boldsymbol{x}_{n}\right)} \frac{\|\nabla_{\boldsymbol{x}_{n}} \log q_{n}\left(\boldsymbol{x}_{n}\right)\|^{2}}{d}$通过蒙特卡洛方法进行估计, 注意到我们的实现中是通过网络来构建Score function $ \boldsymbol{s}_{n}\left(\boldsymbol{x}_{n}\right)$来逼近$\nabla_{\boldsymbol{x}_{n}} \log q_{n}\left(\boldsymbol{x}_{n}\right)$的, 于是我们通过蒙特卡洛方法估计出如下参数

$$
\Gamma_{n}=\frac{1}{M} \sum_{m=1}^{M} \frac{\left\|\boldsymbol{s}_{n}\left(\boldsymbol{x}_{n, m}\right)\right\|^{2}}{d}, \quad \boldsymbol{x}_{n, m} \stackrel{i i d}{\sim} q_{n}\left(\boldsymbol{x}_{n}\right)
$$

就能每一步选最准的方差了:

$$
\hat{\sigma}_{n}^{2}=\lambda_{n}^{2}+\left(\sqrt{\frac{\bar{\beta}_{n}}{\alpha_{n}}}-\sqrt{\bar{\beta}_{n-1}-\lambda_{n}^{2}}\right)^{2}\left(1-\bar{\beta}_{n} \Gamma_{n}\right)
$$

但是这会有另外一个问题, 就是蒙特卡洛方法可能不准, 我们通过定理去估计上下界做一个截断来增强算法的准确性:
定理2. 我们算出来的最优方差有这样较紧致的上下界:

$$
\lambda_{n}^{2} \leq \sigma_{n}^{* 2} \leq \lambda_{n}^{2}+\left(\sqrt{\frac{\bar{\beta}_{n}}{\alpha_{n}}}-\sqrt{\bar{\beta}_{n-1}-\lambda_{n}^{2}}\right)^{2}
$$

而且当我们初始的数据分布有界$[a,b]^d$的时候, 我们可以有更紧致的上界:

$$
\sigma_{n}^{* 2} \leq \lambda_{n}^{2}+\left(\sqrt{\bar{\alpha}_{n-1}}-\sqrt{\bar{\beta}_{n-1}-\lambda_{n}^{2}} \cdot \sqrt{\frac{\bar{\alpha}_{n}}{\bar{\beta}_{n}}}\right)^{2}\left(\frac{b-a}{2}\right)^{2}
$$

详细的证明就不抄了(懒狗发言
证明思路抄一下…
定理二的证明思路很naive啊, 就是估计一波就完事了
定理一的证明思路大概是这样的
  1. 首先呢我们先应用上moment matching的方法, 给出一般分布和高斯分布之间KL散度用moment表示的闭形式
  2. 之后呢作者小心地将$q(\boldsymbol{x}_{n-1}\mid \boldsymbol{x}_{n})$ 的moment用$q(\boldsymbol{x}_{0}\mid \boldsymbol{x}_{n})$表示了出来
  3. 然后$q(\boldsymbol{x}_{0}\mid \boldsymbol{x}_{n})$是可以用score function完全表出的