目录

Flow Matching Guide and Code(项目解析)

本文档为 flow_matching 项目的完整技术文档,涵盖项目结构、算法原理、核心代码解析及使用方式。 项目地址:https://github.com/facebookresearch/flow_matching 论文:Flow Matching Guide and Code (arXiv:2412.06264)


  1. 项目概述
  2. 项目结构
  3. 安装与环境配置
  4. 算法原理
  5. 核心模块详解
  6. 使用指南与代码示例
  7. 依赖关系

flow_matching 是由 Meta (Facebook Research) 开发的 PyTorch 库,实现了 Flow Matching 系列生成模型算法。 该库支持三种核心范式:

  • 连续 Flow Matching:在欧几里得空间中通过仿射概率路径学习速度场
  • 离散 Flow Matching:在离散状态空间(如文本 token)中通过混合概率路径学习转移概率
  • 黎曼 Flow Matching:在流形(球面、环面等)上通过测地线路径学习速度场

核心依赖:Python ≥ 3.9、PyTorch ≥ 2.1、torchdiffeq、numpy。当前版本:1.0.10


flow_matching/                     # 核心库
├── __init__.py                    # 版本号定义 (__version__ = "1.0.10")
├── path/                          # 概率路径模块
│   ├── path.py                    # ProbPath 抽象基类
│   ├── path_sample.py             # PathSample / DiscretePathSample 数据类
│   ├── affine.py                  # AffineProbPath(仿射路径)、CondOTProbPath
│   ├── geodesic.py                # GeodesicProbPath(测地线路径)
│   ├── mixture.py                 # MixtureDiscreteProbPath(离散混合路径)
│   └── scheduler/                 # 调度器子模块
│       ├── scheduler.py           # Scheduler 基类及多种调度器实现
│       └── schedule_transform.py  # ScheduleTransformedModel(调度器变换)
├── loss/                          # 损失函数模块
│   └── generalized_loss.py        # MixturePathGeneralizedKL(广义 KL 损失)
├── solver/                        # 求解器模块
│   ├── solver.py                  # Solver 抽象基类
│   ├── ode_solver.py              # ODESolver(连续 ODE 求解器)
│   ├── discrete_solver.py         # MixtureDiscreteEulerSolver(离散 Euler 求解器)
│   ├── riemannian_ode_solver.py   # RiemannianODESolver(黎曼 ODE 求解器)
│   └── utils.py                   # 求解器工具函数
└── utils/                         # 工具模块
    ├── utils.py                   # expand_tensor_like, gradient, unsqueeze_to_match
    ├── model_wrapper.py           # ModelWrapper 抽象基类
    ├── categorical_sampler.py     # categorical 采样器
    └── manifolds/                 # 流形子模块
        ├── manifold.py            # Manifold 抽象基类、Euclidean
        ├── sphere.py              # Sphere(超球面)
        ├── torus.py               # FlatTorus(平坦环面)
        └── utils.py               # geodesic 测地线工具函数

examples/                          # 示例代码
├── 2d_flow_matching.ipynb                          # 2D 连续 Flow Matching
├── 2d_discrete_flow_matching.ipynb                 # 2D 离散 Flow Matching
├── 2d_riemannian_flow_matching_flat_torus.ipynb    # 2D 黎曼 FM(环面)
├── 2d_riemannian_flow_matching_sphere.ipynb        # 2D 黎曼 FM(球面)
├── 2d_cnf_maximum_likelihood.ipynb                 # 2D CNF 最大似然
├── standalone_flow_matching.ipynb                  # 独立连续 FM 示例
├── standalone_discrete_flow_matching.ipynb         # 独立离散 FM 示例
├── image/                         # 图像生成示例(CIFAR10 / ImageNet)
│   ├── train.py                   # 训练入口
│   ├── models/                    # UNet / Discrete UNet 模型
│   └── training/                  # 训练循环、数据变换、分布式等
└── text/                          # 文本生成示例
    ├── train.py                   # 训练入口
    ├── model/                     # Transformer + Rotary Embedding
    ├── data/                      # 数据加载与 Tokenizer
    └── logic/                     # 训练/生成/评估逻辑
utils (基础工具)
  ├── manifolds (流形定义)
  └── model_wrapper (模型封装)
path (概率路径)  ←── scheduler (调度器)
  ├── loss (损失函数,依赖 path)
  └── solver (求解器,依赖 utils.model_wrapper)

pip install flow_matching
conda env create -f environment.yml
conda activate flow_matching
pip install -e .          # 可编辑模式安装
pre-commit install        # 安装代码规范检查钩子
依赖用途
torch (≥2.1)张量计算与自动微分
numpy数值计算
torchdiffeqODE 数值积分(Euler、Dopri5 等)

Flow Matching 的核心思想是学习一个时间依赖的速度场 ut(x)u_t(x),使得沿该速度场的 ODE 流能将源分布 p0p_0(通常为高斯噪声)变换为目标分布 p1p_1(数据分布)。

给定源样本 X0p0X_0 \sim p_0 和目标样本 X1p1X_1 \sim p_1,定义仿射条件概率路径:

Xt=αtX1+σtX0X_t = \alpha_t X_1 + \sigma_t X_0

其中 αt\alpha_tσt\sigma_t 由调度器(Scheduler)控制。条件速度场为:

X˙t=α˙tX1+σ˙tX0\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0

训练目标是最小化模型预测速度与条件速度之间的 MSE:

L=Et,X0,X1[vθ(Xt,t)X˙t2]\mathcal{L} = \mathbb{E}_{t, X_0, X_1} \left[ \| v_\theta(X_t, t) - \dot{X}_t \|^2 \right]

仿射路径的核心实现在 flow_matching/path/affine.pyAffineProbPath.sample() 中:

# AffineProbPath.sample() 核心逻辑
scheduler_output = self.scheduler(t)  # 获取 α_t, σ_t 及其导数

alpha_t = expand_tensor_like(input_tensor=scheduler_output.alpha_t, expand_to=x_1)
sigma_t = expand_tensor_like(input_tensor=scheduler_output.sigma_t, expand_to=x_1)
d_alpha_t = expand_tensor_like(input_tensor=scheduler_output.d_alpha_t, expand_to=x_1)
d_sigma_t = expand_tensor_like(input_tensor=scheduler_output.d_sigma_t, expand_to=x_1)

# 构造 X_t = σ_t * X_0 + α_t * X_1
x_t = sigma_t * x_0 + alpha_t * x_1
# 条件速度 dX_t = dσ_t * X_0 + dα_t * X_1
dx_t = d_sigma_t * x_0 + d_alpha_t * x_1

return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)

最简单也最常用的调度器是条件最优传输调度器,定义为:

αt=t,σt=1t\alpha_t = t, \quad \sigma_t = 1 - t

此时路径为直线插值 Xt=(1t)X0+tX1X_t = (1-t)X_0 + tX_1,速度为常数 X˙t=X1X0\dot{X}_t = X_1 - X_0

class CondOTScheduler(ConvexScheduler):
    def __call__(self, t: Tensor) -> SchedulerOutput:
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=1 - t,
            d_alpha_t=torch.ones_like(t),
            d_sigma_t=-torch.ones_like(t),
        )

AffineProbPath 提供了六种表示之间的相互转换方法,这在不同训练目标之间切换时非常有用:

方法输入 → 输出公式
target_to_velocityX1X˙tX_1 \to \dot{X}_tX˙t=σ˙tσtXt+α˙tσtσ˙tαtσtX1\dot{X}_t = \frac{\dot{\sigma}_t}{\sigma_t} X_t + \frac{\dot{\alpha}_t \sigma_t - \dot{\sigma}_t \alpha_t}{\sigma_t} X_1
epsilon_to_velocityϵX˙t\epsilon \to \dot{X}_tX˙t=α˙tαtXt+σ˙tαtα˙tσtαtϵ\dot{X}_t = \frac{\dot{\alpha}_t}{\alpha_t} X_t + \frac{\dot{\sigma}_t \alpha_t - \dot{\alpha}_t \sigma_t}{\alpha_t} \epsilon
velocity_to_targetX˙tX1\dot{X}_t \to X_1上式的逆变换
epsilon_to_targetϵX1\epsilon \to X_1X1=1αtXtσtαtϵX_1 = \frac{1}{\alpha_t} X_t - \frac{\sigma_t}{\alpha_t} \epsilon
velocity_to_epsilonX˙tϵ\dot{X}_t \to \epsilon速度到噪声的转换
target_to_epsilonX1ϵX_1 \to \epsilonϵ=1σtXtαtσtX1\epsilon = \frac{1}{\sigma_t} X_t - \frac{\alpha_t}{\sigma_t} X_1

离散 Flow Matching 将 Flow Matching 框架扩展到离散状态空间(如文本 token),使用连续时间马尔可夫链(CTMC)代替 ODE。

在离散空间 S=[K]d\mathcal{S} = [K]^d 上,混合概率路径定义为:

P(Xt=X0)=σt,P(Xt=X1)=1σtP(X_t = X_0) = \sigma_t, \quad P(X_t = X_1) = 1 - \sigma_t

即在时间 tt,每个坐标以概率 σt\sigma_t 保持为源值 X0X_0,以概率 1σt1-\sigma_t 翻转为目标值 X1X_1

条件概率速度场为:

uti(xi,yix1i)=κ˙t1κt[δx1i(xi)δyi(xi)]u_t^i(x^i, y^i | x_1^i) = \frac{\dot{\kappa}_t}{1 - \kappa_t} \left[ \delta_{x_1^i}(x^i) - \delta_{y^i}(x^i) \right]

离散路径采样在 flow_matching/path/mixture.py 中实现:

class MixtureDiscreteProbPath(ProbPath):
    def sample(self, x_0, x_1, t) -> DiscretePathSample:
        sigma_t = self.scheduler(t).sigma_t
        sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1)

        # 每个坐标独立地以概率 σ_t 保持为 X_0,否则翻转为 X_1
        source_indices = torch.rand(size=x_1.shape, device=x_1.device) < sigma_t
        x_t = torch.where(condition=source_indices, input=x_0, other=x_1)

        return DiscretePathSample(x_t=x_t, x_1=x_1, x_0=x_0, t=t)

后验到速度的转换:

def posterior_to_velocity(self, posterior_logits, x_t, t):
    posterior = torch.softmax(posterior_logits, dim=-1)
    x_t = F.one_hot(x_t, num_classes=vocabulary_size)

    scheduler_output = self.scheduler(t)
    kappa_t = scheduler_output.alpha_t
    d_kappa_t = scheduler_output.d_alpha_t

    # u_t = (dκ_t / (1 - κ_t)) * (posterior - x_t)
    return (d_kappa_t / (1 - kappa_t)) * (posterior - x_t)

离散 Flow Matching 使用广义 KL 散度作为训练损失(flow_matching/loss/generalized_loss.py):

i(x1,xt,t)=κ˙t1κt[p1t(xtixt)δx1i(xti)+(1δx1i(xti))logp1t(x1ixt)]\ell_i(x_1, x_t, t) = -\frac{\dot{\kappa}_t}{1-\kappa_t} \left[ p_{1|t}(x_t^i|x_t) - \delta_{x_1^i}(x_t^i) + (1-\delta_{x_1^i}(x_t^i)) \log p_{1|t}(x_1^i|x_t) \right]
class MixturePathGeneralizedKL(_Loss):
    def forward(self, logits, x_1, x_t, t):
        # 提取 log p_{1|t}(x_1|x_t)
        log_p_1t = torch.log_softmax(logits, dim=-1)
        log_p_1t_x1 = torch.gather(log_p_1t, dim=-1, index=x_1.unsqueeze(-1))

        # 提取 p_{1|t}(x_t|x_t)
        p_1t = torch.exp(log_p_1t)
        p_1t_xt = torch.gather(p_1t, dim=-1, index=x_t.unsqueeze(-1))

        # 计算跳跃系数 dκ_t / (1 - κ_t)
        scheduler_output = self.path.scheduler(t)
        jump_coefficient = scheduler_output.d_alpha_t / (1 - scheduler_output.alpha_t)

        delta_x1_xt = (x_t == x_1).to(log_p_1t.dtype)

        loss = -jump_coefficient * (
            p_1t_xt - delta_x1_xt + (1 - delta_x1_xt) * log_p_1t_x1
        )
        return torch.mean(loss)  # 默认 reduction='mean'

黎曼 Flow Matching 将 Flow Matching 扩展到非欧几里得流形上,使用测地线插值代替线性插值。

在流形 M\mathcal{M} 上,测地线概率路径定义为:

Xt=ψt(X0X1)=expX1(κtlogX1(X0))X_t = \psi_t(X_0 | X_1) = \exp_{X_1}(\kappa_t \log_{X_1}(X_0))

其中 exp\explog\log 分别是流形上的指数映射和对数映射,κt\kappa_t 是调度器参数。

测地线路径在 flow_matching/path/geodesic.py 中实现:

class GeodesicProbPath(ProbPath):
    def __init__(self, scheduler: ConvexScheduler, manifold: Manifold):
        self.scheduler = scheduler
        self.manifold = manifold

    def sample(self, x_0, x_1, t) -> PathSample:
        def cond_u(x_0, x_1, t):
            # 构造测地线路径函数
            path = geodesic(self.manifold, x_0, x_1)
            # 用 JVP(Jacobian-Vector Product)自动计算速度
            x_t, dx_t = jvp(
                lambda t: path(self.scheduler(t).alpha_t),
                (t,),
                (torch.ones_like(t).to(t),),
            )
            return x_t, dx_t

        # 使用 vmap 对 batch 维度进行向量化
        x_t, dx_t = vmap(cond_u)(x_0, x_1, t)
        return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)

测地线工具函数(flow_matching/utils/manifolds/utils.py):

def geodesic(manifold, start_point, end_point):
    """生成参数化的测地线曲线函数"""
    shooting_tangent_vec = manifold.logmap(start_point, end_point)

    def path(t):
        tangent_vecs = torch.einsum("i,...k->...ik", t, shooting_tangent_vec)
        points_at_time_t = manifold.expmap(start_point.unsqueeze(-2), tangent_vecs)
        return points_at_time_t

    return path
流形类名空间指数映射对数映射
欧几里得空间EuclideanRD\mathbb{R}^Dexpx(u)=x+u\exp_x(u) = x + ulogx(y)=yx\log_x(y) = y - x
超球面SphereSD1S^{D-1}球面指数映射球面对数映射
平坦环面FlatTorus[0,2π]D[0, 2\pi]^D(x+u)mod2π(x + u) \mod 2\piatan2(sin(yx),cos(yx))\text{atan2}(\sin(y-x), \cos(y-x))

球面的指数映射实现(flow_matching/utils/manifolds/sphere.py):

class Sphere(Manifold):
    def expmap(self, x, u):
        norm_u = u.norm(dim=-1, keepdim=True)
        exp = x * torch.cos(norm_u) + u * torch.sin(norm_u) / norm_u
        retr = self.projx(x + u)
        cond = norm_u > self.EPS[norm_u.dtype]
        return torch.where(cond, exp, retr)  # 数值稳定:小范数时退化为投影

    def logmap(self, x, y):
        u = self.proju(x, y - x)
        dist = self.dist(x, y, keepdim=True)
        cond = dist.gt(self.EPS[x.dtype])
        return torch.where(cond, u * dist / u.norm(dim=-1, keepdim=True).clamp_min(self.EPS[x.dtype]), u)

    def projx(self, x):
        return x / x.norm(dim=-1, keepdim=True)  # 投影到单位球面

    def proju(self, x, u):
        return u - (x * u).sum(dim=-1, keepdim=True) * x  # 投影到切平面

概率路径是 Flow Matching 的核心抽象,定义了如何在源分布和目标分布之间构造插值。

ProbPath (抽象基类)
├── AffineProbPath (仿射路径)
│   └── CondOTProbPath (条件最优传输路径)
├── GeodesicProbPath (测地线路径)
└── MixtureDiscreteProbPath (离散混合路径)

所有路径的 sample() 方法返回 PathSampleDiscretePathSample

@dataclass
class PathSample:
    x_1: Tensor   # 目标样本 X_1
    x_0: Tensor   # 源样本 X_0
    t: Tensor      # 时间 t
    x_t: Tensor    # 路径样本 X_t ~ p_t
    dx_t: Tensor   # 条件速度 dX_t/dt

@dataclass
class DiscretePathSample:
    x_1: Tensor   # 目标样本 X_1
    x_0: Tensor   # 源样本 X_0
    t: Tensor      # 时间 t
    x_t: Tensor    # 路径样本 X_t ~ p_t(无速度,因为是离散空间)

调度器控制概率路径的时间演化参数 αt\alpha_tσt\sigma_t

调度器αt\alpha_tσt\sigma_t特点
CondOTSchedulertt1t1-t最简单的线性插值,条件最优传输
PolynomialConvexScheduler(n)tnt^n1tn1-t^n多项式调度,nn 控制曲线形状
VPSchedulereT/2e^{-T/2}1eT\sqrt{1-e^{-T}}方差保持调度(VP-SDE 等价)
LinearVPSchedulertt1t2\sqrt{1-t^2}线性方差保持
CosineSchedulersin(π2t)\sin(\frac{\pi}{2}t)cos(π2t)\cos(\frac{\pi}{2}t)余弦调度

所有调度器都实现了 snr_inverse 方法,用于从信噪比 SNR=αt/σt\text{SNR} = \alpha_t / \sigma_t 反推时间 tt,这在调度器变换中至关重要。

ScheduleTransformedModel 允许在训练后更换调度器而无需重新训练模型。其核心是尺度-时间(ST)变换:

Xˉr=srXtr,tr=ρ1(ρˉ(r)),sr=σˉrσtr\bar{X}_r = s_r X_{t_r}, \quad t_r = \rho^{-1}(\bar{\rho}(r)), \quad s_r = \frac{\bar{\sigma}_r}{\sigma_{t_r}}

变换后的速度场为:

uˉr(x)=s˙rsrx+srt˙rutr(xsr)\bar{u}_r(x) = \frac{\dot{s}_r}{s_r} x + s_r \dot{t}_r u_{t_r}\left(\frac{x}{s_r}\right)
class ScheduleTransformedModel(ModelWrapper):
    def forward(self, x, t, **extras):
        r = t
        # 新调度器参数
        r_out = self.new_scheduler(t=r)
        # 通过 SNR 反推原始时间
        t = self.original_scheduler.snr_inverse(r_out.alpha_t / r_out.sigma_t)
        # 原始调度器参数
        t_out = self.original_scheduler(t=t)
        # 计算尺度因子和时间导数
        s_r = r_out.sigma_t / t_out.sigma_t
        dt_r = ...  # 时间映射的导数
        ds_r = ...  # 尺度因子的导数
        # 变换速度场
        u_t = self.model(x=x / s_r, t=t, **extras)
        u_r = ds_r * x / s_r + dt_r * s_r * u_t
        return u_r

专为离散 Flow Matching 设计的广义 KL 散度损失。假设模型以 x-prediction 方式训练(即模型输出 p1t(xt)p_{1|t}(\cdot|x_t))。

接口:

loss_fn = MixturePathGeneralizedKL(path=my_discrete_path, reduction='mean')
loss = loss_fn(logits=model_output, x_1=target, x_t=path_sample, t=time)

参数说明:

  • logits:模型输出的 logits,形状 (batch, d, K),其中 K 为词表大小
  • x_1:目标数据,形状 (batch, d)
  • x_t:路径采样点,形状 (batch, d)
  • t:时间,形状 (batch,)

对于连续 Flow Matching,通常直接使用 torch.nn.MSELoss 计算速度匹配损失。

求解器负责在推理阶段从源分布生成目标分布的样本。

Solver (抽象基类, nn.Module)
├── ODESolver (连续 ODE 求解器)
├── MixtureDiscreteEulerSolver (离散 Euler 求解器)
└── RiemannianODESolver (黎曼 ODE 求解器)

基于 torchdiffeq 的通用 ODE 求解器,支持多种数值方法。

solver = ODESolver(velocity_model=my_model)

# 基本采样:从 X_0 ~ p_0 生成 X_1 ~ p_1
x_1 = solver.sample(
    x_init=x_0,                              # 初始噪声
    step_size=1/1000,                         # 步长
    method="euler",                           # 积分方法
    time_grid=torch.tensor([0.0, 1.0]),       # 时间区间
)

支持的积分方法:

方法类型说明
euler固定步长一阶 Euler 方法,最快但精度最低
midpoint固定步长二阶中点法
heun3固定步长三阶 Heun 方法
dopri5自适应步长五阶 Dormand-Prince,精度高但较慢

ODESolver.compute_likelihood() 通过反向积分 ODE 并计算雅可比行列式的迹来计算精确的 log 似然:

logp1(x1)=logp0(x0)+10div(ut(xt))dt\log p_1(x_1) = \log p_0(x_0) + \int_1^0 \text{div}(u_t(x_t)) \, dt
x_0, log_likelihood = solver.compute_likelihood(
    x_1=data_samples,
    log_p0=lambda x: -0.5 * x.pow(2).sum(-1),  # 标准高斯的 log 概率
    step_size=1/1000,
    time_grid=torch.tensor([1.0, 0.0]),          # 必须从 1 积到 0
    exact_divergence=False,                       # 使用 Hutchinson 估计器
)

散度计算有两种模式:

  • exact_divergence=True:精确计算 div(ut)=iutixi\text{div}(u_t) = \sum_i \frac{\partial u_t^i}{\partial x^i},计算量为 O(D)O(D) 次反向传播
  • exact_divergence=False:Hutchinson 估计器 div(ut)zTx(utz)\text{div}(u_t) \approx z^T \nabla_x (u_t \cdot z),仅需 1 次反向传播

离散空间的 CTMC 模拟器,实现了带无散度项的 Euler 步进:

solver = MixtureDiscreteEulerSolver(
    model=my_model,
    path=my_discrete_path,
    vocabulary_size=256,
    source_distribution_p=uniform_dist,  # 可选:用于无散度项
)

x_1 = solver.sample(
    x_init=x_0,
    step_size=1/1000,
    div_free=0.0,          # 无散度项系数,0 表示不使用
    time_grid=torch.tensor([0.0, 1.0]),
)

每步的核心逻辑:

  1. 从模型采样 X1p1t(Xt)X_1 \sim p_{1|t}(\cdot|X_t)
  2. 计算条件速度 ut(xXt,X1)u_t(x|X_t, X_1)
  3. 计算跳跃强度 λ=xXtut(xXt,X1)\lambda = \sum_{x \neq X_t} u_t(x|X_t, X_1)
  4. 以概率 1ehλ1 - e^{-h\lambda} 发生跳跃,跳跃目标按 utu_t 归一化后采样

流形上的 ODE 求解器,支持 Euler、中点法和 RK4,每步都可选择性地将状态投影回流形、将速度投影到切平面:

solver = RiemannianODESolver(
    manifold=Sphere(),
    velocity_model=my_model,
)

x_1 = solver.sample(
    x_init=x_0,
    step_size=0.01,
    method="rk4",     # euler / midpoint / rk4
    projx=True,       # 每步投影到流形
    proju=True,       # 速度投影到切平面
)

RK4 步进的流形版本(_rk4_step):

def _rk4_step(velocity_model, xt, t0, dt, manifold, projx=True, proju=True):
    velocity_fn = lambda x, t: (
        manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t)
    )
    projx_fn = lambda x: manifold.projx(x) if projx else x

    k1 = velocity_fn(xt, t0)
    k2 = velocity_fn(projx_fn(xt + dt * k1 / 3), t0 + dt / 3)
    k3 = velocity_fn(projx_fn(xt + dt * (k2 - k1 / 3)), t0 + dt * 2 / 3)
    k4 = velocity_fn(projx_fn(xt + dt * (k1 - k2 + k3)), t0 + dt)

    return projx_fn(xt + (k1 + 3 * (k2 + k3) + k4) * dt * 0.125)

所有速度场模型必须继承 ModelWrapper,统一接口为 forward(x, t, **extras)

class ModelWrapper(ABC, nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        return self.model(x=x, t=t, **extras)

自定义模型示例:

class MyVelocityModel(ModelWrapper):
    def __init__(self, net):
        super().__init__(net)

    def forward(self, x, t, **extras):
        # 可在此添加时间编码、条件拼接等自定义逻辑
        t_embed = self.time_embedding(t)
        return self.model(torch.cat([x, t_embed], dim=-1))
# expand_tensor_like: 将 1D 张量扩展到与目标张量相同的形状
# 用途:将 (batch_size,) 的时间/调度器参数扩展到 (batch_size, C, H, W) 等
alpha_t = expand_tensor_like(input_tensor=scheduler_output.alpha_t, expand_to=x_1)

# unsqueeze_to_match: 自动添加维度使源张量与目标张量维度匹配
t = unsqueeze_to_match(source=t, target=x_t)

# gradient: 计算梯度的封装,用于似然计算中的散度估计
grad = gradient(output, x, create_graph=True)

基于 torch.multinomial 的分类采样器,支持任意形状的概率张量:

def categorical(probs: Tensor) -> Tensor:
    return torch.multinomial(
        probs.flatten(0, -2), 1, replacement=True
    ).view(*probs.shape[:-1])

import torch
from flow_matching.path import AffineProbPath
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.solver import ODESolver
from flow_matching.utils import ModelWrapper

# 1. 定义模型
class SimpleVelocityNet(ModelWrapper):
    def __init__(self, dim):
        net = torch.nn.Sequential(
            torch.nn.Linear(dim + 1, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, dim),
        )
        super().__init__(net)

    def forward(self, x, t, **extras):
        t_expanded = t.unsqueeze(-1) if t.dim() == 1 else t
        inp = torch.cat([x, t_expanded], dim=-1)
        return self.model(inp)

# 2. 初始化路径和模型
dim = 2
path = AffineProbPath(scheduler=CondOTScheduler())
model = SimpleVelocityNet(dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
mse_loss = torch.nn.MSELoss()

# 3. 训练循环
for epoch in range(1000):
    # 源分布:标准高斯
    x_0 = torch.randn(256, dim)
    # 目标分布:你的数据
    x_1 = sample_from_data(batch_size=256)
    # 随机时间
    t = torch.rand(256)

    # 采样条件路径
    path_sample = path.sample(x_0=x_0, x_1=x_1, t=t)

    # 计算速度匹配损失
    predicted_velocity = model(path_sample.x_t, path_sample.t)
    loss = mse_loss(predicted_velocity, path_sample.dx_t)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 4. 推理:生成样本
solver = ODESolver(velocity_model=model)
x_0 = torch.randn(100, dim)
x_1 = solver.sample(
    x_init=x_0,
    step_size=0.01,
    method="midpoint",
    time_grid=torch.tensor([0.0, 1.0]),
)
import torch
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.loss import MixturePathGeneralizedKL
from flow_matching.solver import MixtureDiscreteEulerSolver

# 1. 初始化
vocab_size = 256
scheduler = PolynomialConvexScheduler(n=1.0)
path = MixtureDiscreteProbPath(scheduler=scheduler)
loss_fn = MixturePathGeneralizedKL(path=path)

# 2. 训练循环
for x_1 in dataloader:  # x_1: (batch, seq_len) 的整数 token
    x_0 = torch.randint(0, vocab_size, x_1.shape)  # 均匀随机源
    t = torch.rand(x_1.shape[0])

    path_sample = path.sample(x_0=x_0, x_1=x_1, t=t)

    # 模型输出 logits: (batch, seq_len, vocab_size)
    logits = model(path_sample.x_t, path_sample.t)
    loss = loss_fn(logits=logits, x_1=x_1, x_t=path_sample.x_t, t=t)

    loss.backward()
    optimizer.step()

# 3. 推理
solver = MixtureDiscreteEulerSolver(
    model=model, path=path, vocabulary_size=vocab_size
)
x_0 = torch.randint(0, vocab_size, (batch_size, seq_len))
x_1 = solver.sample(x_init=x_0, step_size=1/1000)
import torch
from flow_matching.path import GeodesicProbPath
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.solver import RiemannianODESolver
from flow_matching.utils.manifolds import Sphere

# 1. 初始化
manifold = Sphere()
scheduler = CondOTScheduler()
path = GeodesicProbPath(scheduler=scheduler, manifold=manifold)

# 2. 训练
for x_1 in dataloader:  # x_1: 球面上的数据点
    x_0 = manifold.projx(torch.randn_like(x_1))  # 投影到球面
    t = torch.rand(x_1.shape[0])

    path_sample = path.sample(x_0=x_0, x_1=x_1, t=t)

    predicted_velocity = model(path_sample.x_t, path_sample.t)
    loss = mse_loss(predicted_velocity, path_sample.dx_t)
    loss.backward()

# 3. 推理
solver = RiemannianODESolver(manifold=manifold, velocity_model=model)
x_0 = manifold.projx(torch.randn(100, 3))
x_1 = solver.sample(x_init=x_0, step_size=0.01, method="rk4")
from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel

# 模型原本用 CondOT 调度器训练
original_scheduler = CondOTScheduler()
new_scheduler = CosineScheduler()

# 无需重新训练,直接包装模型
transformed_model = ScheduleTransformedModel(
    velocity_model=trained_model,
    original_scheduler=original_scheduler,
    new_scheduler=new_scheduler,
)

# 用新调度器进行推理
solver = ODESolver(velocity_model=transformed_model)
x_1 = solver.sample(x_init=x_0, step_size=1/1000)
import torch
from flow_matching.solver import ODESolver

solver = ODESolver(velocity_model=trained_model)

# 定义源分布的 log 概率(标准高斯)
log_p0 = lambda x: -0.5 * (x.pow(2).sum(-1) + x.shape[-1] * torch.log(torch.tensor(2 * torch.pi)))

# 计算数据点的 log 似然
x_0_recovered, log_likelihood = solver.compute_likelihood(
    x_1=data_samples,
    log_p0=log_p0,
    step_size=1/1000,
    time_grid=torch.tensor([1.0, 0.0]),
    exact_divergence=False,  # Hutchinson 估计器更高效
)
# log_likelihood: 每个样本的 log p_1(x_1)

flow_matching
├── numpy          # 数值计算基础
├── torch (≥2.1)   # 张量计算、自动微分、神经网络
└── torchdiffeq    # ODE 数值积分(odeint 接口)
dev:
├── pre-commit     # Git 提交前代码检查
├── black          # 代码格式化
├── usort / ufmt   # import 排序与格式化
├── flake8         # 代码风格检查
└── pydoclint      # 文档字符串检查
  • 图像示例:torchvision, submitit(分布式训练)
  • 文本示例:hydra-core, wandb(实验管理)
  • Notebook:matplotlib, jupyter, scikit-learn, tqdm

本文档基于 flow_matching v1.0.10 源码生成。 论文引用:Lipman et al., “Flow Matching Guide and Code”, arXiv:2412.06264, 2024.

相关内容