目录

KL 散度与离散流匹配中的广义 KL 损失

系列 - Deep Learning系列

在流匹配(Flow Matching)及相关生成模型的讨论中,我们经常遇到四个概念:散度(div)KL 散度交叉熵。它们分属不同领域,但彼此联系紧密。下面分别解释其定义、物理意义以及在流匹配(尤其是离散流匹配)中的角色。


在连续空间 Rd\mathbb{R}^d 中,给定速度场 ut(x)Rdu_t(x) \in \mathbb{R}^d,其散度定义为各分量对相应坐标的偏导数之和:

div(ut)(x)=i=1dutixi(x). \operatorname{div}(u_t)(x) = \sum_{i=1}^d \frac{\partial u_t^i}{\partial x^i}(x).

它刻画了向量场在点 xx 的“膨胀”程度:若 div>0\operatorname{div}>0,则质量向外扩散;若 div<0\operatorname{div}<0,则质量向内汇聚。

  • 连续流匹配:散度出现在连续性方程中,用于描述概率密度 ptp_t 的演化: ptt+div(ptut)=0. \frac{\partial p_t}{\partial t} + \operatorname{div}(p_t u_t) = 0. 此外,计算生成样本的对数似然时需要沿轨迹积分散度: logp1(x1)=logp0(x0)01div(ut(xt))dt. \log p_1(x_1) = \log p_0(x_0) - \int_0^1 \operatorname{div}(u_t(x_t))\, dt.
  • 离散流匹配:由于状态空间是离散的,没有连续的导数,因此没有直接的“散度”算子。但速率矩阵 ut(yx)u_t(y|x) 必须满足 行和为零 的条件: yut(yx)=0,ut(yx)0 (yx), \sum_{y} u_t(y|x) = 0,\quad u_t(y|x) \ge 0\ (y\neq x), 这保证了概率质量守恒,相当于连续情况下的“无散度”约束。

KL 散度衡量两个概率分布 PPQQ 之间的差异:

DKL(PQ)=xP(x)logP(x)Q(x)(离散),DKL(pq)=p(x)logp(x)q(x)dx(连续). D_{\text{KL}}(P\|Q) = \sum_x P(x)\log\frac{P(x)}{Q(x)} \quad (\text{离散}),\qquad D_{\text{KL}}(p\|q) = \int p(x)\log\frac{p(x)}{q(x)}\,dx \quad (\text{连续}).

它满足非负性,且当且仅当 P=QP=Q 时为零。

  • 作为 Bregman 散度:在流匹配的损失函数中,可以选择不同的 Bregman 散度来度量速度场之间的差异。当选择 KL 散度 作为 Bregman 散度时,条件流匹配损失可以化简为关于后验分布的损失,即 广义 KL 损失
  • 离散流匹配中的广义 KL 损失LGKL=Et,X0,X1,Xtiλt[(11Xti=X1i)(logp1tθ(X1iXt))+(1Xti=X1ip1tθ(XtiXt))], \mathcal{L}_{\text{GKL}} = \mathbb{E}_{t, X_0, X_1, X_t} \sum_i \lambda_t \left[ (1-\mathbf{1}_{X_t^i=X_1^i})(-\log p_{1|t}^\theta(X_1^i|X_t)) + (\mathbf{1}_{X_t^i=X_1^i} - p_{1|t}^\theta(X_t^i|X_t)) \right], 该损失包含了交叉熵项和正则项,其推导基于 KL 散度。

熵度量一个概率分布的不确定性:

H(P)=xP(x)logP(x)(离散),H(p)=p(x)logp(x)dx(连续). H(P) = -\sum_x P(x)\log P(x) \quad (\text{离散}),\qquad H(p) = -\int p(x)\log p(x)\,dx \quad (\text{连续}).
  • 熵本身通常不作为直接训练目标,但常出现在损失函数的分解中。例如,交叉熵可以分解为熵与 KL 散度之和: H(P,Q)=H(P)+DKL(PQ). H(P,Q) = H(P) + D_{\text{KL}}(P\|Q).
  • 在离散流匹配的广义 KL 损失中,当 XtX1X_t \neq X_1 时,损失项 logp1t(X1Xt)-\log p_{1|t}(X_1|X_t) 正是交叉熵(因为真实分布是点质量,熵为 0,故交叉熵等于 KL 散度)。因此,训练过程实际上是在最小化 KL 散度。

交叉熵衡量两个分布 PPQQ 之间的“不一致性”:

H(P,Q)=xP(x)logQ(x)(离散),H(p,q)=p(x)logq(x)dx(连续). H(P,Q) = -\sum_x P(x)\log Q(x) \quad (\text{离散}),\qquad H(p,q) = -\int p(x)\log q(x)\,dx \quad (\text{连续}).

PP 是真实分布(如点质量)时,最小化交叉熵等价于最大化似然。

  • 训练分类器:在离散流匹配中,神经网络被训练来预测后验分布 p1tθ(Xt)p_{1|t}^\theta(\cdot|X_t)。当当前 token XtX_t 不等于目标 X1X_1 时,损失函数中的 logp1tθ(X1Xt)-\log p_{1|t}^\theta(X_1|X_t) 就是交叉熵。因此,模型在大部分时间(当 XtX_t 尚未到达目标时)学习预测正确的目标 token,这类似于去噪自编码器。
  • 与熵和 KL 的关系:由于真实分布是点质量(熵为 0),交叉熵恰好等于 KL 散度。故最小化交叉熵等价于最小化真实分布与预测分布之间的 KL 散度。

概念数学定义(离散)物理/信息论意义在流匹配中的作用
散度(div)iuixi\sum_i \frac{\partial u^i}{\partial x^i}向量场的膨胀率连续:控制概率流动,计算似然;离散:速率矩阵行和为零
KL 散度xP(x)logP(x)Q(x)\sum_x P(x)\log\frac{P(x)}{Q(x)}分布间差异作为 Bregman 散度,导出广义 KL 损失
xP(x)logP(x)-\sum_x P(x)\log P(x)分布的不确定性隐含在交叉熵中,不直接训练
交叉熵xP(x)logQ(x)-\sum_x P(x)\log Q(x)预测分布与真实分布的不一致离散流匹配训练中的主要项(当 XtX1X_t\neq X_1

在离散流匹配中,我们通常不直接计算散度,而是通过广义 KL 损失(包含交叉熵和正则项)来训练模型,使其学会从当前状态预测目标状态,从而间接构建出满足守恒条件的速率场。

KL 散度交叉熵是信息论中密切相关的两个概念,在机器学习中经常用于衡量两个分布之间的差异,但它们在定义、性质和用途上有所不同。


PPQQ 是定义在相同离散空间上的两个概率分布(连续情况类似)。

  • 交叉熵(Cross‑Entropy):

    H(P,Q)=xP(x)logQ(x). H(P, Q) = -\sum_{x} P(x) \log Q(x).

    它表示用分布 QQ 来编码来自分布 PP 的样本时所需的平均比特数(如果对数以 2 为底)。

  • KL 散度(Kullback‑Leibler Divergence):

    DKL(PQ)=xP(x)logP(x)Q(x)=H(P,Q)H(P), D_{\text{KL}}(P \| Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)} = H(P, Q) - H(P),

    其中 H(P)=xP(x)logP(x)H(P) = -\sum_x P(x) \log P(x)PP 的熵。


两者通过熵联系起来:

DKL(PQ)=H(P,Q)H(P). D_{\text{KL}}(P \| Q) = H(P, Q) - H(P).
  • PP 固定时,最小化交叉熵等价于最小化 KL 散度,因为 H(P)H(P) 是常数。
  • 特别地,当 PP 是 one‑hot 分布(例如分类任务中的真实标签)时,H(P)=0H(P)=0,此时交叉熵等于 KL 散度: H(P,Q)=DKL(PQ). H(P, Q) = D_{\text{KL}}(P \| Q). 这也是为什么分类任务中常将交叉熵损失等同于负对数似然。

方面KL 散度交叉熵
对称性不对称:DKL(PQ)DKL(QP)D_{\text{KL}}(P\|Q) \neq D_{\text{KL}}(Q\|P)不对称:H(P,Q)H(Q,P)H(P,Q) \neq H(Q,P)
非负性总是 0\ge 0,且等于 0 当且仅当 P=QP=Q可以小于 0(当使用自然对数时,但通常定义为非负?实际交叉熵可以是任意正数,但 H(P,Q)H(P)H(P,Q) \ge H(P) 非负)
熵的依赖显式减去 H(P)H(P)包含 H(P)H(P) 在内
优化目标常用于分布匹配(如变分自编码器中的 KL 项)常用于分类、语言模型等(直接最大化似然)
数值稳定性直接计算可能遇到 log(0) 问题,需处理同样有 log(0) 问题,但分类时常用 cross_entropy 函数内部做了稳定处理

  • 分类任务:真实标签为 one‑hot(PP),模型输出概率 QQ,损失为 logQ(ytrue)-\log Q(y_{\text{true}}),即交叉熵。
  • 语言模型:预测下一个词的概率,真实词为 one‑hot,损失为负对数似然,等价于交叉熵。
  • 变分自编码器(VAE):ELBO 中包含后验与先验的 KL 散度,作为正则项。
  • 知识蒸馏:学生模型输出分布 QQ 与教师模型输出分布 PP 之间的 KL 散度。
  • 离散流匹配中的广义 KL 损失:虽然不直接是 KL 散度,但源于以 KL 散度作为 Bregman 散度推导而来,形式中包含交叉熵项和正则项。

假设真实分布 P=[0.2,0.3,0.5]P = [0.2, 0.3, 0.5],模型预测 Q=[0.1,0.4,0.5]Q = [0.1, 0.4, 0.5]

  • H(P)=(0.2log0.2+0.3log0.3+0.5log0.5)1.029H(P) = -(0.2\log0.2 + 0.3\log0.3 + 0.5\log0.5) \approx 1.029(以 e 为底)。
  • 交叉熵 H(P,Q)=(0.2log0.1+0.3log0.4+0.5log0.5)1.152H(P,Q) = -(0.2\log0.1 + 0.3\log0.4 + 0.5\log0.5) \approx 1.152
  • KL 散度 DKL(PQ)=1.1521.029=0.123D_{\text{KL}}(P\|Q) = 1.152 - 1.029 = 0.123

  • KL 散度 = 交叉熵 − 熵。
  • 当真实分布 PP 固定时,优化交叉熵等价于优化 KL 散度(因为熵是常数)。
  • 在分类问题中,两者数值相等(因熵为 0),但概念上交叉熵是更直接的损失函数。
  • 在离散流匹配的广义 KL 损失中,实际使用的是加权交叉熵加正则项,它源于以 KL 散度为 Bregman 散度的条件流匹配损失,因此保留了“广义 KL”的名称。

在连续流匹配中,散度 div(ut)\operatorname{div}(u_t) 用于对数似然计算。常用的高效方法是 Hutchinson 迹估计

div(u)(x)=EϵN(0,I)[ϵux(x)ϵ]ϵ(uϵ)x(x). \operatorname{div}(u)(x) = \mathbb{E}_{\epsilon \sim \mathcal{N}(0,I)} \left[ \epsilon^\top \frac{\partial u}{\partial x}(x) \epsilon \right] \approx \epsilon^\top \frac{\partial (u \cdot \epsilon)}{\partial x}(x).

PyTorch 实现

def divergence_hutchinson(u_func, x, eps=None):
    """
    计算向量场 u 在点 x 处的散度(Hutchinson 估计)。
    u_func: 可调用对象,输入 x (batch, d),输出 u (batch, d)
    x: 张量,形状 (batch, d),requires_grad=True
    eps: 可选,随机噪声,默认从标准正态采样
    """
    if eps is None:
        eps = torch.randn_like(x)
    u = u_func(x)                     ## (batch, d)
    dot = (u * eps).sum(dim=1, keepdim=True)  ## (batch, 1)
    grad_dot = torch.autograd.grad(dot, x, grad_outputs=torch.ones_like(dot),
                                    create_graph=True)[0]  ## (batch, d)
    div = (grad_dot * eps).sum(dim=1)                     ## (batch,)
    return div

使用示例:

x = torch.randn(16, 10, requires_grad=True)  ## 16个10维点
u_func = lambda x: -x                         ## 简单速度场
div_val = divergence_hutchinson(u_func, x)     ## 形状 (16,)

离散流匹配中无直接的散度计算,但需要检查速率矩阵的行和是否为零:

def check_rate_matrix(u, x):
    """
    u: 速度张量,形状 (batch, seq_len, vocab_size)
    x: 当前 token 索引,仅用于调试
    """
    row_sum = u.sum(dim=-1)  ## 每行的和应为0
    assert torch.allclose(row_sum, torch.zeros_like(row_sum), atol=1e-6)

对于两个离散分布 PP(真实)和 QQ(预测),KL 散度:

DKL(PQ)=kPklogPkQk. D_{\text{KL}}(P\|Q) = \sum_k P_k \log\frac{P_k}{Q_k}.

PyTorch 实现(使用 F.kl_div,注意输入是对数概率):

import torch.nn.functional as F

## P 为概率向量(如 one-hot),Q_logits 为未归一化 logits
p_probs = torch.tensor([0.2, 0.3, 0.5])  ## 真实分布
q_logits = torch.tensor([1.0, 2.0, 3.0]) ## 模型输出
q_log_probs = F.log_softmax(q_logits, dim=-1)

kl = F.kl_div(q_log_probs, p_probs, reduction='sum')  ## 注意参数顺序:input=log(Q), target=P

对于连续密度 ppqq,KL 散度可通过蒙特卡洛估计:

DKL(pq)1Ni=1N(logp(xi)logq(xi)),xip. D_{\text{KL}}(p\|q) \approx \frac{1}{N}\sum_{i=1}^N \bigl(\log p(x_i) - \log q(x_i)\bigr), \quad x_i \sim p.

PyTorch 示例(假设已知对数密度函数):

def log_p(x):  ## 真实分布的对数密度
    return -0.5 * (x**2).sum(dim=-1) - 0.5 * x.shape[-1] * np.log(2*np.pi)

def log_q(x):  ## 模型分布的对数密度
    return -0.5 * ((x - mu)**2).sum(dim=-1) / sigma**2 - 0.5 * x.shape[-1] * np.log(2*np.pi*sigma**2)

samples = torch.randn(1000, 10)  ## 从真实分布采样(此处为标准正态)
kl = (log_p(samples) - log_q(samples)).mean()

H(P)=kPklogPk. H(P) = -\sum_k P_k \log P_k.

PyTorch 实现

probs = torch.softmax(logits, dim=-1)   ## 形状 (batch, K)
entropy = -(probs * probs.log()).sum(dim=-1).mean()  ## 平均熵

微分熵:

H(p)=p(x)logp(x)dx1Ni=1Nlogp(xi),xip. H(p) = -\int p(x)\log p(x)\,dx \approx -\frac{1}{N}\sum_{i=1}^N \log p(x_i), \quad x_i \sim p.

PyTorch 实现(假设可采样和对数密度已知):

samples = torch.randn(1000, 10)  ## 从 p 采样
log_p_vals = log_p(samples)      ## 计算对数密度
entropy = -log_p_vals.mean()

对于真实分布 PP(常为 one‑hot)和模型预测 QQ,交叉熵:

H(P,Q)=kPklogQk. H(P,Q) = -\sum_k P_k \log Q_k.

PyTorch 实现(使用 F.cross_entropy,它结合了 log_softmax 和 NLL):

## logits: (batch, K), labels: (batch,) 真实类别索引
loss = F.cross_entropy(logits, labels, reduction='mean')

当需要显式计算概率时:

probs = F.softmax(logits, dim=-1)
cross_entropy = -(probs[range(len(labels)), labels]).log().mean()

对于连续空间,交叉熵定义为:

H(p,q)=p(x)logq(x)dx1Ni=1Nlogq(xi),xip. H(p,q) = -\int p(x)\log q(x)\,dx \approx -\frac{1}{N}\sum_{i=1}^N \log q(x_i), \quad x_i \sim p.

PyTorch 实现

samples = torch.randn(1000, 10)  ## 从真实分布采样
log_q_vals = log_q(samples)      ## 模型分布的对数密度
cross_entropy = -log_q_vals.mean()

在 DFM 的 广义 KL 损失 中,实际使用的是 加权交叉熵 + 正则项,而非直接调用上述函数。但我们可以将其理解为:

  • XtX1X_t \neq X_1 时,损失项为 λtlogp1tθ(X1Xt)-\lambda_t \log p_{1|t}^\theta(X_1|X_t),这正是加权交叉熵
  • Xt=X1X_t = X_1 时,损失项为 λt(1p1tθ(XtXt))\lambda_t (1 - p_{1|t}^\theta(X_t|X_t)),可看作对过高置信度的惩罚。

该损失在代码中通常这样实现:

def generalized_kl_loss(logits, x_1, x_t, t, scheduler):
    ## logits: (batch, seq_len, vocab)
    ## x_1, x_t: (batch, seq_len)
    ## t: (batch,)
    log_p_1t = F.log_softmax(logits, dim=-1)                      ## log p(y|x_t)
    p_1t = log_p_1t.exp()                                         ## p(y|x_t)
    log_p_x1 = torch.gather(log_p_1t, -1, x_1.unsqueeze(-1)).squeeze(-1)
    p_xt = torch.gather(p_1t, -1, x_t.unsqueeze(-1)).squeeze(-1)
    delta = (x_t == x_1).float()
    lam = scheduler(t)                                            ## lambda_t = d_kappa / (1 - kappa)
    lam = lam.view(-1, *([1]*(x_1.dim()-1)))                     ## 广播
    loss = -lam * ((1 - delta) * log_p_x1 + (delta - p_xt))
    return loss.mean()

连续空间离散空间主要用途
散度Hutchinson 估计或自动微分无直接计算(行和为零)似然计算、守恒约束
KL 散度蒙特卡洛估计F.kl_div分布匹配、变分推断
蒙特卡洛估计-(p*log p).sum()不确定性度量
交叉熵蒙特卡洛估计F.cross_entropy分类、最大似然训练

在离散流匹配中,这些概念被融合进广义 KL 损失中,通过加权交叉熵和正则项实现模型训练。理解它们的计算方法有助于调试和扩展 DFM 代码。

相关内容