实现diffusion模型(手写数字集)
前段时间阅读了 Denoising Diffusion Probabilistic Models论文,最近在MNIST数据集进行了实现,效果还不错。如图1所示,展示的是去噪的过程,图片左上角是时间戳。当时间戳=1000时候,是一个白噪声;当时间戳=1时候,是一张清晰的手写数字。
图1. 恢复过程(去噪过程)
模型介绍
Diffusion 模型是加利福尼亚大学Jonathan等人提出的,该模型在图像生成、多模态模型方面具有巨大的影响力。比如,给一个模糊的图片,你可以使用Diffusion模型得到更清晰的高分图。 模型主要由两个部分组成,扩散和去噪过程。下面分别描述下两个过程所做的事,
前向扩散过程
- 定义隐变量:定义一个隐变量分布q(xt∣xt−1)q(x_{t}|x_{t-1})q(xt∣xt−1)以描述了时刻t−1t-1t−1到时刻ttt的差异。
- 逐步加噪:从初始图x0x_0x0开始,逐步加入生成噪声ϵt\epsilon_{t}ϵt,直到得到一个完全随机的噪声样本xTx_TxT。
去噪过程
- 定义条件分布:定义一个条件分布q(xt−1∣xt)q(x_{t-1}|x_t)q(xt−1∣xt)以描述如何从噪声样本xtx_{t}xt得到去噪的样本xt−1x_{t-1}xt−1。
- 逐步去噪:从完全随机的噪声样本xTx_{T}xT开始,逐步去除预测噪声ϵt\epsilon_{t}ϵt,直到得到一个清晰的图片x0x_{0}x0。
如上过程描述,那么扩散模型就是一个隐变量模型了。如何理解呢? 我们可以把扩散过程当作是一种编码过程,得到的是噪声空间中图像的潜在表示。同样的,去噪过程当作是一种解码过程,映射到原始数据。 不同之处是扩散模型一步一步加噪或者去噪,而VAE之类是直接一步得到隐变量或者原始数据。直观上理解,将问题转为多阶段处理的过程是不会增加问题的复杂性的。
训练与去噪
下面,我们对模型训练和去噪过程进行解释。
图2. 训练和恢复过程
先看训练过程,如图2所示
- 获取一张图x0x_{0}x0
- 采样一个时间ttt与噪声ϵ\epsilonϵ
- 利用噪声ϵ\epsilonϵ与预测噪声ϵθ\epsilon_{\theta}ϵθ之差产生的梯度更新预测噪声模型的参数
注意,ϵθ\epsilon_{\theta}ϵθ预测的是当前输入xtx_{t}xt和时间戳ttt相对应的噪声,利用该噪声对当前输入xtx_{t}xt进行去噪,并不是某一个时间戳时所添加的噪声,xt=at‾x0+1−at‾ϵx_{t}=\sqrt{\overline{a_{t}}}x_{0}+\sqrt{1-\overline{a_{t}}}\epsilonxt=atx0+1−atϵ,这是通过xtx_{t}xt和xt−1x_{t-1}xt−1的关系式迭代得到的。表面是预测噪声,实际上是预测xt−1x_{t-1}xt−1,只不过转移到预测噪声上了,具体可见李宏毅老师视频课程。
再看去噪过程
- 获取一高斯噪声(噪声空间中的潜在表示)
- 对于一个时间戳ttt,利用预测的噪声对xtx_{t}xt进行去噪(另一个潜在表示)得到xt−1x_{t-1}xt−1,再加一个σtz\sigma_{t}zσtz增加一定的随机性
注意,可以看到xt−1x_{t-1}xt−1是xtx_{t}xt与预测噪声的加权和,权重啥的都是超参数,故而我们只需要预测当前输入xtx_{t}xt和时间戳ttt相对应的噪声。 那为什么要加入噪声zzz? 增加一些随机性,试考虑VAE数据点对应潜表示也是一个概率分布,而不是一个固定表示。这里加个噪声zzz是不是也可以理解了呢。 σt\sigma_{t}σt的值是什么? 查询资料还没找到有标准值,我代码实现的时候用的是βt\beta_{t}βt,考虑xt=1−βt×xt+βt×ϵx_{t}=\sqrt{1-\beta_{t}}\times x_{t}+\sqrt{\beta_t}\times \epsilonxt=1−βt×xt+βt×ϵ。
模型实现
扩散模型使用unet预测噪声,unet网络特色是U型设计和跳层残差,利用U型设计,可以捕捉不同层次的特征(抽象特征或者细节特征),再利用跳层残差,可以有效的融合不同层次特征带来的不同信息。因此,unet在医学图像、语义分割等领域,其效果很好。需要注意的是,这里面的unet是需要考虑时间戳的,因为预测噪声模型需要xtx_{t}xt和ttt作为输入。关于扩散模型unet的细节,大家可以移步参考资料2,那博主非常用心,画了很多图,可以有助于你理解。 关于Diffusion model在MNIST数据集的代码,大家可以去我github查阅。如果有什么问题,也可以留言一起讨论交流。
参考资料
[1] 扩散模型 - Diffusion Model【李宏毅2023】 [2] 深入浅出扩散模型(Diffusion Model)系列:基石DDPM(源码解读篇) - 知乎 (zhihu.com) [3] 深入浅出扩散模型(Diffusion Model)系列:基石DDPM(人人都能看懂的数学原理篇) - 知乎 (zhihu.com)
转载自:https://juejin.cn/post/7363823940606066728