likes
comments
collection
share

实现diffusion模型(手写数字集)

作者站长头像
站长
· 阅读数 44

前段时间阅读了 Denoising Diffusion Probabilistic Models论文,最近在MNIST数据集进行了实现,效果还不错。如图1所示,展示的是去噪的过程,图片左上角是时间戳。当时间戳=1000时候,是一个白噪声;当时间戳=1时候,是一张清晰的手写数字。 实现diffusion模型(手写数字集) 实现diffusion模型(手写数字集)

图1. 恢复过程(去噪过程)

模型介绍

Diffusion 模型是加利福尼亚大学Jonathan等人提出的,该模型在图像生成、多模态模型方面具有巨大的影响力。比如,给一个模糊的图片,你可以使用Diffusion模型得到更清晰的高分图。 模型主要由两个部分组成,扩散和去噪过程。下面分别描述下两个过程所做的事,

前向扩散过程

  • 定义隐变量:定义一个隐变量分布q(xt∣xt−1)q(x_{t}|x_{t-1})q(xtxt1)以描述了时刻t−1t-1t1到时刻ttt的差异。
  • 逐步加噪:从初始图x0x_0x0开始,逐步加入生成噪声ϵt\epsilon_{t}ϵt,直到得到一个完全随机的噪声样本xTx_TxT

去噪过程

  • 定义条件分布:定义一个条件分布q(xt−1∣xt)q(x_{t-1}|x_t)q(xt1xt)以描述如何从噪声样本xtx_{t}xt得到去噪的样本xt−1x_{t-1}xt1
  • 逐步去噪:从完全随机的噪声样本xTx_{T}xT开始,逐步去除预测噪声ϵt\epsilon_{t}ϵt,直到得到一个清晰的图片x0x_{0}x0

如上过程描述,那么扩散模型就是一个隐变量模型了。如何理解呢? 我们可以把扩散过程当作是一种编码过程,得到的是噪声空间中图像的潜在表示。同样的,去噪过程当作是一种解码过程,映射到原始数据。 不同之处是扩散模型一步一步加噪或者去噪,而VAE之类是直接一步得到隐变量或者原始数据。直观上理解,将问题转为多阶段处理的过程是不会增加问题的复杂性的。

训练与去噪

下面,我们对模型训练和去噪过程进行解释。

实现diffusion模型(手写数字集)

图2. 训练和恢复过程

先看训练过程,如图2所示

  • 获取一张图x0x_{0}x0
  • 采样一个时间ttt与噪声ϵ\epsilonϵ
  • 利用噪声ϵ\epsilonϵ与预测噪声ϵθ\epsilon_{\theta}ϵθ之差产生的梯度更新预测噪声模型的参数

注意,ϵθ\epsilon_{\theta}ϵθ预测的是当前输入xtx_{t}xt和时间戳ttt相对应的噪声,利用该噪声对当前输入xtx_{t}xt进行去噪,并不是某一个时间戳时所添加的噪声,xt=at‾x0+1−at‾ϵx_{t}=\sqrt{\overline{a_{t}}}x_{0}+\sqrt{1-\overline{a_{t}}}\epsilonxt=atx0+1atϵ,这是通过xtx_{t}xtxt−1x_{t-1}xt1的关系式迭代得到的。表面是预测噪声,实际上是预测xt−1x_{t-1}xt1,只不过转移到预测噪声上了,具体可见李宏毅老师视频课程。

再看去噪过程

  • 获取一高斯噪声(噪声空间中的潜在表示)
  • 对于一个时间戳ttt,利用预测的噪声对xtx_{t}xt进行去噪(另一个潜在表示)得到xt−1x_{t-1}xt1,再加一个σtz\sigma_{t}zσtz增加一定的随机性

注意,可以看到xt−1x_{t-1}xt1xtx_{t}xt与预测噪声的加权和,权重啥的都是超参数,故而我们只需要预测当前输入xtx_{t}xt和时间戳ttt相对应的噪声。 那为什么要加入噪声zzz 增加一些随机性,试考虑VAE数据点对应潜表示也是一个概率分布,而不是一个固定表示。这里加个噪声zzz是不是也可以理解了呢。 σt\sigma_{t}σt的值是什么? 查询资料还没找到有标准值,我代码实现的时候用的是βt\beta_{t}βt,考虑xt=1−βt×xt+βt×ϵx_{t}=\sqrt{1-\beta_{t}}\times x_{t}+\sqrt{\beta_t}\times \epsilonxt=1βt×xt+βt×ϵ

模型实现

扩散模型使用unet预测噪声,unet网络特色是U型设计和跳层残差,利用U型设计,可以捕捉不同层次的特征(抽象特征或者细节特征),再利用跳层残差,可以有效的融合不同层次特征带来的不同信息。因此,unet在医学图像、语义分割等领域,其效果很好。需要注意的是,这里面的unet是需要考虑时间戳的,因为预测噪声模型需要xtx_{t}xtttt作为输入。关于扩散模型unet的细节,大家可以移步参考资料2,那博主非常用心,画了很多图,可以有助于你理解。 关于Diffusion model在MNIST数据集的代码,大家可以去我github查阅。如果有什么问题,也可以留言一起讨论交流。

参考资料

[1] 扩散模型 - Diffusion Model【李宏毅2023】 [2] 深入浅出扩散模型(Diffusion Model)系列:基石DDPM(源码解读篇) - 知乎 (zhihu.com) [3] 深入浅出扩散模型(Diffusion Model)系列:基石DDPM(人人都能看懂的数学原理篇) - 知乎 (zhihu.com)

转载自:https://juejin.cn/post/7363823940606066728
评论
请登录