KL 散度与离散流匹配中的广义 KL 损失
原理解释
在流匹配(Flow Matching)及相关生成模型的讨论中,我们经常遇到四个概念:散度(div)、KL 散度、熵 和 交叉熵。它们分属不同领域,但彼此联系紧密。下面分别解释其定义、物理意义以及在流匹配(尤其是离散流匹配)中的角色。
1. 散度(div)
定义(向量场的散度)
在连续空间 中,给定速度场 ,其散度定义为各分量对相应坐标的偏导数之和:
它刻画了向量场在点 的“膨胀”程度:若 ,则质量向外扩散;若 ,则质量向内汇聚。
在流匹配中的作用
- 连续流匹配:散度出现在连续性方程中,用于描述概率密度 的演化: 此外,计算生成样本的对数似然时需要沿轨迹积分散度:
- 离散流匹配:由于状态空间是离散的,没有连续的导数,因此没有直接的“散度”算子。但速率矩阵 必须满足 行和为零 的条件: 这保证了概率质量守恒,相当于连续情况下的“无散度”约束。
2. KL 散度(Kullback‑Leibler Divergence)
定义
KL 散度衡量两个概率分布 和 之间的差异:
它满足非负性,且当且仅当 时为零。
在流匹配中的作用
- 作为 Bregman 散度:在流匹配的损失函数中,可以选择不同的 Bregman 散度来度量速度场之间的差异。当选择 KL 散度 作为 Bregman 散度时,条件流匹配损失可以化简为关于后验分布的损失,即 广义 KL 损失。
- 离散流匹配中的广义 KL 损失: 该损失包含了交叉熵项和正则项,其推导基于 KL 散度。
3. 熵(Entropy)
定义
熵度量一个概率分布的不确定性:
在流匹配中的作用
- 熵本身通常不作为直接训练目标,但常出现在损失函数的分解中。例如,交叉熵可以分解为熵与 KL 散度之和:
- 在离散流匹配的广义 KL 损失中,当 时,损失项 正是交叉熵(因为真实分布是点质量,熵为 0,故交叉熵等于 KL 散度)。因此,训练过程实际上是在最小化 KL 散度。
4. 交叉熵(Cross‑Entropy)
定义
交叉熵衡量两个分布 和 之间的“不一致性”:
当 是真实分布(如点质量)时,最小化交叉熵等价于最大化似然。
在流匹配中的作用
- 训练分类器:在离散流匹配中,神经网络被训练来预测后验分布 。当当前 token 不等于目标 时,损失函数中的 就是交叉熵。因此,模型在大部分时间(当 尚未到达目标时)学习预测正确的目标 token,这类似于去噪自编码器。
- 与熵和 KL 的关系:由于真实分布是点质量(熵为 0),交叉熵恰好等于 KL 散度。故最小化交叉熵等价于最小化真实分布与预测分布之间的 KL 散度。
总结对比
| 概念 | 数学定义(离散) | 物理/信息论意义 | 在流匹配中的作用 |
|---|---|---|---|
| 散度(div) | 向量场的膨胀率 | 连续:控制概率流动,计算似然;离散:速率矩阵行和为零 | |
| KL 散度 | 分布间差异 | 作为 Bregman 散度,导出广义 KL 损失 | |
| 熵 | 分布的不确定性 | 隐含在交叉熵中,不直接训练 | |
| 交叉熵 | 预测分布与真实分布的不一致 | 离散流匹配训练中的主要项(当 ) |
在离散流匹配中,我们通常不直接计算散度,而是通过广义 KL 损失(包含交叉熵和正则项)来训练模型,使其学会从当前状态预测目标状态,从而间接构建出满足守恒条件的速率场。
KL散度和交叉熵的区别
KL 散度和交叉熵是信息论中密切相关的两个概念,在机器学习中经常用于衡量两个分布之间的差异,但它们在定义、性质和用途上有所不同。
1. 定义
设 和 是定义在相同离散空间上的两个概率分布(连续情况类似)。
交叉熵(Cross‑Entropy):
它表示用分布 来编码来自分布 的样本时所需的平均比特数(如果对数以 2 为底)。
KL 散度(Kullback‑Leibler Divergence):
其中 是 的熵。
2. 关系
两者通过熵联系起来:
- 当 固定时,最小化交叉熵等价于最小化 KL 散度,因为 是常数。
- 特别地,当 是 one‑hot 分布(例如分类任务中的真实标签)时,,此时交叉熵等于 KL 散度: 这也是为什么分类任务中常将交叉熵损失等同于负对数似然。
3. 主要区别
| 方面 | KL 散度 | 交叉熵 |
|---|---|---|
| 对称性 | 不对称: | 不对称: |
| 非负性 | 总是 ,且等于 0 当且仅当 | 可以小于 0(当使用自然对数时,但通常定义为非负?实际交叉熵可以是任意正数,但 非负) |
| 熵的依赖 | 显式减去 | 包含 在内 |
| 优化目标 | 常用于分布匹配(如变分自编码器中的 KL 项) | 常用于分类、语言模型等(直接最大化似然) |
| 数值稳定性 | 直接计算可能遇到 log(0) 问题,需处理 | 同样有 log(0) 问题,但分类时常用 cross_entropy 函数内部做了稳定处理 |
4. 在机器学习中的应用
交叉熵(Cross‑Entropy Loss)
- 分类任务:真实标签为 one‑hot(),模型输出概率 ,损失为 ,即交叉熵。
- 语言模型:预测下一个词的概率,真实词为 one‑hot,损失为负对数似然,等价于交叉熵。
KL 散度
- 变分自编码器(VAE):ELBO 中包含后验与先验的 KL 散度,作为正则项。
- 知识蒸馏:学生模型输出分布 与教师模型输出分布 之间的 KL 散度。
- 离散流匹配中的广义 KL 损失:虽然不直接是 KL 散度,但源于以 KL 散度作为 Bregman 散度推导而来,形式中包含交叉熵项和正则项。
5. 举例说明
假设真实分布 ,模型预测 。
- 熵 (以 e 为底)。
- 交叉熵 。
- KL 散度 。
6. 总结
- KL 散度 = 交叉熵 − 熵。
- 当真实分布 固定时,优化交叉熵等价于优化 KL 散度(因为熵是常数)。
- 在分类问题中,两者数值相等(因熵为 0),但概念上交叉熵是更直接的损失函数。
- 在离散流匹配的广义 KL 损失中,实际使用的是加权交叉熵加正则项,它源于以 KL 散度为 Bregman 散度的条件流匹配损失,因此保留了“广义 KL”的名称。
PyTorch 代码示例
1. 散度(Divergence)
1.1 连续空间(速度场的散度)
在连续流匹配中,散度 用于对数似然计算。常用的高效方法是 Hutchinson 迹估计:
PyTorch 实现:
def divergence_hutchinson(u_func, x, eps=None):
"""
计算向量场 u 在点 x 处的散度(Hutchinson 估计)。
u_func: 可调用对象,输入 x (batch, d),输出 u (batch, d)
x: 张量,形状 (batch, d),requires_grad=True
eps: 可选,随机噪声,默认从标准正态采样
"""
if eps is None:
eps = torch.randn_like(x)
u = u_func(x) ## (batch, d)
dot = (u * eps).sum(dim=1, keepdim=True) ## (batch, 1)
grad_dot = torch.autograd.grad(dot, x, grad_outputs=torch.ones_like(dot),
create_graph=True)[0] ## (batch, d)
div = (grad_dot * eps).sum(dim=1) ## (batch,)
return div使用示例:
x = torch.randn(16, 10, requires_grad=True) ## 16个10维点
u_func = lambda x: -x ## 简单速度场
div_val = divergence_hutchinson(u_func, x) ## 形状 (16,)1.2 离散空间(速率矩阵的约束)
离散流匹配中无直接的散度计算,但需要检查速率矩阵的行和是否为零:
def check_rate_matrix(u, x):
"""
u: 速度张量,形状 (batch, seq_len, vocab_size)
x: 当前 token 索引,仅用于调试
"""
row_sum = u.sum(dim=-1) ## 每行的和应为0
assert torch.allclose(row_sum, torch.zeros_like(row_sum), atol=1e-6)2. KL 散度(Kullback‑Leibler Divergence)
2.1 离散分布
对于两个离散分布 (真实)和 (预测),KL 散度:
PyTorch 实现(使用 F.kl_div,注意输入是对数概率):
import torch.nn.functional as F
## P 为概率向量(如 one-hot),Q_logits 为未归一化 logits
p_probs = torch.tensor([0.2, 0.3, 0.5]) ## 真实分布
q_logits = torch.tensor([1.0, 2.0, 3.0]) ## 模型输出
q_log_probs = F.log_softmax(q_logits, dim=-1)
kl = F.kl_div(q_log_probs, p_probs, reduction='sum') ## 注意参数顺序:input=log(Q), target=P2.2 连续分布
对于连续密度 和 ,KL 散度可通过蒙特卡洛估计:
PyTorch 示例(假设已知对数密度函数):
def log_p(x): ## 真实分布的对数密度
return -0.5 * (x**2).sum(dim=-1) - 0.5 * x.shape[-1] * np.log(2*np.pi)
def log_q(x): ## 模型分布的对数密度
return -0.5 * ((x - mu)**2).sum(dim=-1) / sigma**2 - 0.5 * x.shape[-1] * np.log(2*np.pi*sigma**2)
samples = torch.randn(1000, 10) ## 从真实分布采样(此处为标准正态)
kl = (log_p(samples) - log_q(samples)).mean()3. 熵(Entropy)
3.1 离散分布
PyTorch 实现:
probs = torch.softmax(logits, dim=-1) ## 形状 (batch, K)
entropy = -(probs * probs.log()).sum(dim=-1).mean() ## 平均熵3.2 连续分布
微分熵:
PyTorch 实现(假设可采样和对数密度已知):
samples = torch.randn(1000, 10) ## 从 p 采样
log_p_vals = log_p(samples) ## 计算对数密度
entropy = -log_p_vals.mean()4. 交叉熵(Cross‑Entropy)
4.1 离散分布
对于真实分布 (常为 one‑hot)和模型预测 ,交叉熵:
PyTorch 实现(使用 F.cross_entropy,它结合了 log_softmax 和 NLL):
## logits: (batch, K), labels: (batch,) 真实类别索引
loss = F.cross_entropy(logits, labels, reduction='mean')当需要显式计算概率时:
probs = F.softmax(logits, dim=-1)
cross_entropy = -(probs[range(len(labels)), labels]).log().mean()4.2 连续分布
对于连续空间,交叉熵定义为:
PyTorch 实现:
samples = torch.randn(1000, 10) ## 从真实分布采样
log_q_vals = log_q(samples) ## 模型分布的对数密度
cross_entropy = -log_q_vals.mean()5. 在离散流匹配(DFM)中的体现
在 DFM 的 广义 KL 损失 中,实际使用的是 加权交叉熵 + 正则项,而非直接调用上述函数。但我们可以将其理解为:
- 当 时,损失项为 ,这正是加权交叉熵。
- 当 时,损失项为 ,可看作对过高置信度的惩罚。
该损失在代码中通常这样实现:
def generalized_kl_loss(logits, x_1, x_t, t, scheduler):
## logits: (batch, seq_len, vocab)
## x_1, x_t: (batch, seq_len)
## t: (batch,)
log_p_1t = F.log_softmax(logits, dim=-1) ## log p(y|x_t)
p_1t = log_p_1t.exp() ## p(y|x_t)
log_p_x1 = torch.gather(log_p_1t, -1, x_1.unsqueeze(-1)).squeeze(-1)
p_xt = torch.gather(p_1t, -1, x_t.unsqueeze(-1)).squeeze(-1)
delta = (x_t == x_1).float()
lam = scheduler(t) ## lambda_t = d_kappa / (1 - kappa)
lam = lam.view(-1, *([1]*(x_1.dim()-1))) ## 广播
loss = -lam * ((1 - delta) * log_p_x1 + (delta - p_xt))
return loss.mean()总结
| 量 | 连续空间 | 离散空间 | 主要用途 |
|---|---|---|---|
| 散度 | Hutchinson 估计或自动微分 | 无直接计算(行和为零) | 似然计算、守恒约束 |
| KL 散度 | 蒙特卡洛估计 | F.kl_div | 分布匹配、变分推断 |
| 熵 | 蒙特卡洛估计 | -(p*log p).sum() | 不确定性度量 |
| 交叉熵 | 蒙特卡洛估计 | F.cross_entropy | 分类、最大似然训练 |
在离散流匹配中,这些概念被融合进广义 KL 损失中,通过加权交叉熵和正则项实现模型训练。理解它们的计算方法有助于调试和扩展 DFM 代码。