目录

MeanFlow

MeanFlow 算法完整技术文档(原理+数学推导+训练推理+工程代码) 本文档为正式技术规格文档,完整复现 MeanFlow 核心理论、数学推导、训练流程、采样逻辑与可运行 Demo。

  • 算法名称:MeanFlow
  • 提出团队:Facebook AI Research(何恺明团队)
  • 发表时间:2025 年
  • 核心定位:一步生成(1‑NFE) 生成式建模框架
  • 基础依托:流匹配(Flow Matching)+ 常微分方程(ODE)
  • 核心创新:从学习瞬时速度场 vv 改为学习区间平均速度场 uu,实现单步高质量生成
  • 扩散模型 / 流匹配:学习瞬时速度场,必须多步欧拉积分才能生成高质量图像,步数越少质量越差。
  • Consistency Models:需要蒸馏、课程学习、多阶段训练,流程复杂且理论不闭合。
  • 一步模型:以往 1 步模型生成质量远低于多步模型。

直接建模时间区间 [r,t][r, t] 上的平均速度,让模型一次性学习“从噪声到数据”的完整位移,而非逐点瞬时变化。

  • ztRdz_t \in \mathbb{R}^d:时间 tt 处的隐变量
  • x0p0x_0 \sim p_0:初始高斯噪声
  • x1pdatax_1 \sim p_{\text{data}}:真实数据分布
  • v(zt,t)v(z_t, t):瞬时速度场(传统 Flow Matching 学习目标)
  • u(zt,r,t)u(z_t, r, t):平均速度场(MeanFlow 学习目标)
  • ODE 动力学:dztdt=v(zt,t)\frac{dz_t}{dt} = v(z_t, t)

平均速度是瞬时速度在区间 [r,t][r, t] 上的积分平均:

u(zt,r,t)=1trrtv(zτ,τ)dτ u(z_t, r, t) = \frac{1}{t - r} \int_{r}^{t} v(z_\tau, \tau) d\tau

物理意义:

  • vv = 某一时刻的瞬间速度
  • uu = 从 rr 走到 tt 的整体平均移动速度

对平均速度定义式关于 tt 求导,利用链式法则与微积分基本定理,可推导出平均速度与瞬时速度的严格恒等式。

v(zt,t)u(zt,r,t)=(tr)(ut+ztuv(zt,t)) v(z_t, t) - u(z_t, r, t) = (t - r) \left( \frac{\partial u}{\partial t} + \nabla_{z_t} u \cdot v(z_t, t) \right)

推导如下。

步骤 1:将定义式两边同乘 (tr)(t - r),得

(tr)u(zt,r,t)=rtv(zτ,τ)dτ. (t - r)\, u(z_t, r, t) = \int_{r}^{t} v(z_\tau, \tau)\, d\tau.

步骤 2:对等式两边关于 tt 求导。

  • 右边:由微积分基本定理,积分对上限 tt 的导数为被积函数在 τ=t\tau = t 处的值,即

    ddtrtv(zτ,τ)dτ=v(zt,t). \frac{d}{dt} \int_{r}^{t} v(z_\tau, \tau)\, d\tau = v(z_t, t).
  • 左边(tr)u(zt,r,t)(t - r)\, u(z_t, r, t)tt 的函数,且 ztz_t 也随 tt 变化(满足 dztdt=v(zt,t)\frac{d z_t}{d t} = v(z_t, t))。由乘积法则,

    ddt[(tr)u(zt,r,t)]=u(zt,r,t)+(tr)ddtu(zt,r,t). \frac{d}{dt}\bigl[ (t - r)\, u(z_t, r, t) \bigr] = u(z_t, r, t) + (t - r)\, \frac{d}{dt} u(z_t, r, t).

    u(zt,r,t)u(z_t, r, t) 关于 tt 求全导数时,uu 既直接依赖 tt,又通过 ztz_t 依赖 tt,故

    ddtu(zt,r,t)=ut+ztudztdt=ut+ztuv(zt,t). \frac{d}{dt} u(z_t, r, t) = \frac{\partial u}{\partial t} + \nabla_{z_t} u \cdot \frac{d z_t}{d t} = \frac{\partial u}{\partial t} + \nabla_{z_t} u \cdot v(z_t, t).

    因此左边等于

    u+(tr)(ut+ztuv(zt,t)). u + (t - r) \left( \frac{\partial u}{\partial t} + \nabla_{z_t} u \cdot v(z_t, t) \right).

步骤 3:左右两边相等,故

u(zt,r,t)+(tr)(ut+ztuv(zt,t))=v(zt,t). u(z_t, r, t) + (t - r) \left( \frac{\partial u}{\partial t} + \nabla_{z_t} u \cdot v(z_t, t) \right) = v(z_t, t).

移项即得核心恒等式

v(zt,t)u(zt,r,t)=(tr)(ut+ztuv(zt,t)) v(z_t, t) - u(z_t, r, t) = (t - r) \left( \frac{\partial u}{\partial t} + \nabla_{z_t} u \cdot v(z_t, t) \right)

该式无近似、无假设、完全严格,是 MeanFlow 训练的理论基础。

从恒等式中解出模型需要拟合的目标平均速度:

utgt=vt(tr)(uθt+ztuθvt) u_{\text{tgt}} = v_t - (t - r) \left( \frac{\partial u_\theta}{\partial t} + \nabla_{z_t}u_\theta \cdot v_t \right)

训练损失为模型输出与目标的 L2 距离:

L(θ)=Er<t,zt,vtuθ(zt,r,t)sg(utgt)22 \mathcal{L}(\theta) = \mathbb{E}_{r<t, z_t, v_t} \left\| u_\theta(z_t, r, t) - \text{sg}(u_{\text{tgt}}) \right\|_2^2
  • sg()\text{sg}(\cdot):停止梯度,保证目标固定
  • vt=x1x0v_t = x_1 - x_0:线性路径下的解析瞬时速度

sg()\text{sg}(\cdot) 详细说明与实现

  • 定义:前向时 sg(x)=x\text{sg}(x) = x(数值不变);反向时 sg(x)x=0\dfrac{\partial \,\text{sg}(x)}{\partial x} = 0,即该节点不向输入传梯度,在计算图中被当作常数。
  • 为何必须用:目标 utgt=vt(tr)(uθ/t+ztuθvt)u_{\text{tgt}} = v_t - (t - r)\big(\partial u_\theta/\partial t + \nabla_{z_t}u_\theta \cdot v_t\big) 依赖 uθu_\theta 及其导数,即依赖 θ\theta。若不做 sg,θL\nabla_\theta \mathcal{L} 会包含“通过 utgtu_{\text{tgt}} 再对 θ\theta 求导”的项,目标会随参数更新而变(移动目标);用 sg 后只对损失里的 uθ(zt,r,t)u_\theta(z_t,r,t) 关于 θ\theta 求导,目标在本次更新中固定,等价于监督学习:拟合给定向量 utgtu_{\text{tgt}}
  • 实现要点:先按公式算出 utgtu_{\text{tgt}}(需要自动微分得到 uθ/t\partial u_\theta/\partial tztuθ\nabla_{z_t} u_\theta),再对 utgtu_{\text{tgt}} 做 stop-gradient,最后算 MSE。下面给出 PyTorch 写法。

PyTorch:用 .detach() 把目标从计算图剥离,反向时梯度不会穿过目标。(ztuθ)vt(\nabla_{z_t} u_\theta)\cdot v_t 为 Jacobian–向量积,用 torch.autograd.functional.jvp 一次算出。

from torch.autograd.functional import jvp

# z_t, t 需 requires_grad=True 以便算 u_tgt 中的导数
u = model(z_t, r, t)   # u_theta(z_t, r, t)

# 时间导数:\partial u / \partial t(保持 z_t 不变)
du_dt = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=False, allow_unused=True)[0]
if du_dt is None:
    du_dt = torch.zeros_like(u)
# 空间 Jacobian–向量积:(\nabla_{z_t} u) · v_t
_, jvp_z = jvp(lambda z: model(z, r, t), z_t, v_t)

u_tgt = v_t - (t - r) * (du_dt + jvp_z)
u_tgt = u_tgt.detach()   # stop-gradient:loss 反向不传到 u_tgt
loss = F.mse_loss(u, u_tgt)
loss.backward()

推理阶段仅需一次前向传播,直接从噪声映射到数据:

z1=z0+uθ(z0,0,1) z_1 = z_0 + u_\theta(z_0, 0, 1)
  • 010 \to 1 代表完整生成过程
  • 无迭代、无积分、1-NFE 完成

 

1. 采样噪声 x0N(0,I)x_0 \sim \mathcal{N}(0, I)
2. 采样数据 x1pdatax_1 \sim p_{\text{data}}
3. 随机采样时间对 0r<t10 \le r < t \le 1
4. 构造线性插值路径:zt=(1t)x0+tx1z_t = (1-t)x_0 + t x_1
5. 计算解析瞬时速度:vt=x1x0v_t = x_1 - x_0
6. 前向计算模型输出 uθ(zt,r,t)u_\theta(z_t, r, t)
7. 利用自动微分计算:

  • uθ/t\partial u_\theta / \partial t
  • ztuθ\nabla_{z_t} u_\theta

8. 构造目标速度 utgtu_{\text{tgt}}
9. 最小化 MSE 损失更新参数

1. 采样高斯噪声 z0N(0,I)z_0 \sim \mathcal{N}(0, I)
2. 前向计算平均速度 u=uθ(z0,0,1)u = u_\theta(z_0, 0, 1)
3. 一步生成:z1=z0+uz_1 = z_0 + u
4. 输出 z1z_1 为最终样本

 

1. 理论完全闭合 从定义直接推导,无启发式、无蒸馏、无课程学习。 2. 一步生成(1‑NFE) 速度与 GAN 相当,质量逼近多步扩散模型。 3. 训练稳定 损失函数平滑,最优解唯一存在。 4. 天然支持条件生成 可直接嵌入 CFG(无分类器引导),无需修改结构。 5. 兼容所有 DiT / U-Net 架构 只需将输出从瞬时速度 vv 改为平均速度 uu

 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad


# ==============================================================================
# 1. MeanFlow 核心模型:支持 z_t + r + t 输入
# ==============================================================================
class MeanFlowModel(nn.Module):
    def __init__(self, input_dim=2, hidden_dim=256):
        super().__init__()
        # 输入:z + 时间r + 时间t → 输出:平均速度场 u
        self.net = nn.Sequential(
            nn.Linear(input_dim + 2, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, z, r, t):
        """
        z:  [B, D]
        r:  [B]
        t:  [B]
        return: u [B, D]
        """
        t = t.unsqueeze(1)
        r = r.unsqueeze(1)
        x = torch.cat([z, r, t], dim=1)
        return self.net(x)


# ==============================================================================
# 2. 核心函数:计算 u_tgt(目标平均速度)
# ==============================================================================
def compute_meanflow_target(model, z_t, r, t, v_t):
    """
    实现论文核心公式:
    u_tgt = v_t - (t - r) * (du/dt + ∇z u · v_t)
    """
    B, D = z_t.shape

    # 开启微分
    z_t = z_t.detach().requires_grad_(True)
    t = t.detach().requires_grad_(True)

    # 前向
    u = model(z_t, r, t)

    # 1. 计算 du/dt
    du_dt = grad(u.sum(), t, create_graph=True)[0]  # [B]
    du_dt = du_dt.view(B, 1).expand(B, D)

    # 2. 计算 ∇_z u · v_t
    du_dz = grad(u.sum(), z_t, create_graph=True)[0]  # [B, D]
    du_dz_v = du_dz * v_t

    # 3. 目标平均速度
    delta_t = (t - r).view(B, 1)
    u_tgt = v_t - delta_t * (du_dt + du_dz_v)

    # 停止梯度,保证目标不变
    return u_tgt.detach()


# ==============================================================================
# 3. 单步训练逻辑
# ==============================================================================
def train_one_step(model, optimizer, x0, x1):
    B = x0.shape[0]
    device = x0.device

    # 采样 r < t
    t = torch.rand(B, device=device)
    r = torch.rand(B, device=device) * t

    # 线性路径
    z_t = (1 - t[:, None]) * x0 + t[:, None] * x1
    v_t = x1 - x0  # 瞬时速度

    # 前向
    u_pred = model(z_t, r, t)

    # 计算目标
    u_tgt = compute_meanflow_target(model, z_t, r, t, v_t)

    # 损失
    loss = F.mse_loss(u_pred, u_tgt)

    # 优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss


# ==============================================================================
# 4. 一步采样(核心!1-NFE)
# ==============================================================================
@torch.no_grad()
def meanflow_sample(model, n_samples, dim, device):
    z0 = torch.randn(n_samples, dim, device=device)
    r = torch.zeros(n_samples, device=device)
    t = torch.ones(n_samples, device=device)
    u = model(z0, r, t)
    z1 = z0 + u
    return z1


# ==============================================================================
# 5. Toy 实验:双高斯分布生成(可直接运行)
# ==============================================================================
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 模型
    model = MeanFlowModel(input_dim=2, hidden_dim=256).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    # 真实数据分布:两个高斯聚类
    def data_sampler(batch):
        x = torch.randn(batch, 2, device=device)
        c = torch.randint(0, 2, (batch,), device=device)
        x[c == 0] += torch.tensor([3.0, 3.0], device=device)
        x[c == 1] -= torch.tensor([3.0, 3.0], device=device)
        return x

    # 训练
    print("Start training...")
    for step in range(10000):
        x0 = torch.randn(256, 2, device=device)
        x1 = data_sampler(256)
        loss = train_one_step(model, opt, x0, x1)
        if step % 500 == 0:
            print(f"Step {step:05d} | Loss {loss:.4f}")

    # 一步采样(仅1次前向)
    samples = meanflow_sample(model, n_samples=1000, dim=2, device=device)
    print("\nSampled points (first 5):")
    print(samples[:5])
模型学习目标采样步数训练复杂度理论
Flow Matching瞬时速度 vv≥20 步干净
Consistency Model自一致性1 步极高(蒸馏)启发式
MeanFlow平均速度 uu1 步完全严格
  • 图像/视频/音频生成
  • 蛋白质结构生成
  • 高分辨率实时 AIGC
  • 端侧部署(低算力、低延迟)
  • 需要单步快速生成的工业场景

相关内容