Flow Matching Guide and Code(项目解析)
本文档为
flow_matching项目的完整技术文档,涵盖项目结构、算法原理、核心代码解析及使用方式。 项目地址:https://github.com/facebookresearch/flow_matching 论文:Flow Matching Guide and Code (arXiv:2412.06264)
目录
- 项目概述
- 项目结构
- 安装与环境配置
- 算法原理
- 4.1 连续 Flow Matching
- 4.2 离散 Flow Matching
- 4.3 黎曼 Flow Matching
- 核心模块详解
- 5.1 概率路径 (path)
- 5.2 调度器 (scheduler)
- 5.3 损失函数 (loss)
- 5.4 求解器 (solver)
- 5.5 工具模块 (utils)
- 使用指南与代码示例
- 依赖关系
1. 项目概述
flow_matching 是由 Meta (Facebook Research) 开发的 PyTorch 库,实现了 Flow Matching 系列生成模型算法。
该库支持三种核心范式:
- 连续 Flow Matching:在欧几里得空间中通过仿射概率路径学习速度场
- 离散 Flow Matching:在离散状态空间(如文本 token)中通过混合概率路径学习转移概率
- 黎曼 Flow Matching:在流形(球面、环面等)上通过测地线路径学习速度场
核心依赖:Python ≥ 3.9、PyTorch ≥ 2.1、torchdiffeq、numpy。当前版本:1.0.10。
2. 项目结构
flow_matching/ # 核心库
├── __init__.py # 版本号定义 (__version__ = "1.0.10")
├── path/ # 概率路径模块
│ ├── path.py # ProbPath 抽象基类
│ ├── path_sample.py # PathSample / DiscretePathSample 数据类
│ ├── affine.py # AffineProbPath(仿射路径)、CondOTProbPath
│ ├── geodesic.py # GeodesicProbPath(测地线路径)
│ ├── mixture.py # MixtureDiscreteProbPath(离散混合路径)
│ └── scheduler/ # 调度器子模块
│ ├── scheduler.py # Scheduler 基类及多种调度器实现
│ └── schedule_transform.py # ScheduleTransformedModel(调度器变换)
├── loss/ # 损失函数模块
│ └── generalized_loss.py # MixturePathGeneralizedKL(广义 KL 损失)
├── solver/ # 求解器模块
│ ├── solver.py # Solver 抽象基类
│ ├── ode_solver.py # ODESolver(连续 ODE 求解器)
│ ├── discrete_solver.py # MixtureDiscreteEulerSolver(离散 Euler 求解器)
│ ├── riemannian_ode_solver.py # RiemannianODESolver(黎曼 ODE 求解器)
│ └── utils.py # 求解器工具函数
└── utils/ # 工具模块
├── utils.py # expand_tensor_like, gradient, unsqueeze_to_match
├── model_wrapper.py # ModelWrapper 抽象基类
├── categorical_sampler.py # categorical 采样器
└── manifolds/ # 流形子模块
├── manifold.py # Manifold 抽象基类、Euclidean
├── sphere.py # Sphere(超球面)
├── torus.py # FlatTorus(平坦环面)
└── utils.py # geodesic 测地线工具函数
examples/ # 示例代码
├── 2d_flow_matching.ipynb # 2D 连续 Flow Matching
├── 2d_discrete_flow_matching.ipynb # 2D 离散 Flow Matching
├── 2d_riemannian_flow_matching_flat_torus.ipynb # 2D 黎曼 FM(环面)
├── 2d_riemannian_flow_matching_sphere.ipynb # 2D 黎曼 FM(球面)
├── 2d_cnf_maximum_likelihood.ipynb # 2D CNF 最大似然
├── standalone_flow_matching.ipynb # 独立连续 FM 示例
├── standalone_discrete_flow_matching.ipynb # 独立离散 FM 示例
├── image/ # 图像生成示例(CIFAR10 / ImageNet)
│ ├── train.py # 训练入口
│ ├── models/ # UNet / Discrete UNet 模型
│ └── training/ # 训练循环、数据变换、分布式等
└── text/ # 文本生成示例
├── train.py # 训练入口
├── model/ # Transformer + Rotary Embedding
├── data/ # 数据加载与 Tokenizer
└── logic/ # 训练/生成/评估逻辑模块依赖关系图
utils (基础工具)
├── manifolds (流形定义)
└── model_wrapper (模型封装)
│
path (概率路径) ←── scheduler (调度器)
│
├── loss (损失函数,依赖 path)
└── solver (求解器,依赖 utils.model_wrapper)3. 安装与环境配置
快速安装(pip)
pip install flow_matching开发环境(conda)
conda env create -f environment.yml
conda activate flow_matching
pip install -e . # 可编辑模式安装
pre-commit install # 安装代码规范检查钩子核心依赖
| 依赖 | 用途 |
|---|---|
torch (≥2.1) | 张量计算与自动微分 |
numpy | 数值计算 |
torchdiffeq | ODE 数值积分(Euler、Dopri5 等) |
4. 算法原理
4.1 连续 Flow Matching
Flow Matching 的核心思想是学习一个时间依赖的速度场 ,使得沿该速度场的 ODE 流能将源分布 (通常为高斯噪声)变换为目标分布 (数据分布)。
核心公式
给定源样本 和目标样本 ,定义仿射条件概率路径:
其中 和 由调度器(Scheduler)控制。条件速度场为:
训练目标是最小化模型预测速度与条件速度之间的 MSE:
对应代码
仿射路径的核心实现在 flow_matching/path/affine.py 的 AffineProbPath.sample() 中:
# AffineProbPath.sample() 核心逻辑
scheduler_output = self.scheduler(t) # 获取 α_t, σ_t 及其导数
alpha_t = expand_tensor_like(input_tensor=scheduler_output.alpha_t, expand_to=x_1)
sigma_t = expand_tensor_like(input_tensor=scheduler_output.sigma_t, expand_to=x_1)
d_alpha_t = expand_tensor_like(input_tensor=scheduler_output.d_alpha_t, expand_to=x_1)
d_sigma_t = expand_tensor_like(input_tensor=scheduler_output.d_sigma_t, expand_to=x_1)
# 构造 X_t = σ_t * X_0 + α_t * X_1
x_t = sigma_t * x_0 + alpha_t * x_1
# 条件速度 dX_t = dσ_t * X_0 + dα_t * X_1
dx_t = d_sigma_t * x_0 + d_alpha_t * x_1
return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)条件最优传输路径(CondOT)
最简单也最常用的调度器是条件最优传输调度器,定义为:
此时路径为直线插值 ,速度为常数 。
class CondOTScheduler(ConvexScheduler):
def __call__(self, t: Tensor) -> SchedulerOutput:
return SchedulerOutput(
alpha_t=t,
sigma_t=1 - t,
d_alpha_t=torch.ones_like(t),
d_sigma_t=-torch.ones_like(t),
)表示转换
AffineProbPath 提供了六种表示之间的相互转换方法,这在不同训练目标之间切换时非常有用:
| 方法 | 输入 → 输出 | 公式 |
|---|---|---|
target_to_velocity | ||
epsilon_to_velocity | ||
velocity_to_target | 上式的逆变换 | |
epsilon_to_target | ||
velocity_to_epsilon | 速度到噪声的转换 | |
target_to_epsilon |
4.2 离散 Flow Matching
离散 Flow Matching 将 Flow Matching 框架扩展到离散状态空间(如文本 token),使用连续时间马尔可夫链(CTMC)代替 ODE。
核心公式
在离散空间 上,混合概率路径定义为:
即在时间 ,每个坐标以概率 保持为源值 ,以概率 翻转为目标值 。
条件概率速度场为:
对应代码
离散路径采样在 flow_matching/path/mixture.py 中实现:
class MixtureDiscreteProbPath(ProbPath):
def sample(self, x_0, x_1, t) -> DiscretePathSample:
sigma_t = self.scheduler(t).sigma_t
sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1)
# 每个坐标独立地以概率 σ_t 保持为 X_0,否则翻转为 X_1
source_indices = torch.rand(size=x_1.shape, device=x_1.device) < sigma_t
x_t = torch.where(condition=source_indices, input=x_0, other=x_1)
return DiscretePathSample(x_t=x_t, x_1=x_1, x_0=x_0, t=t)后验到速度的转换:
def posterior_to_velocity(self, posterior_logits, x_t, t):
posterior = torch.softmax(posterior_logits, dim=-1)
x_t = F.one_hot(x_t, num_classes=vocabulary_size)
scheduler_output = self.scheduler(t)
kappa_t = scheduler_output.alpha_t
d_kappa_t = scheduler_output.d_alpha_t
# u_t = (dκ_t / (1 - κ_t)) * (posterior - x_t)
return (d_kappa_t / (1 - kappa_t)) * (posterior - x_t)广义 KL 损失
离散 Flow Matching 使用广义 KL 散度作为训练损失(flow_matching/loss/generalized_loss.py):
class MixturePathGeneralizedKL(_Loss):
def forward(self, logits, x_1, x_t, t):
# 提取 log p_{1|t}(x_1|x_t)
log_p_1t = torch.log_softmax(logits, dim=-1)
log_p_1t_x1 = torch.gather(log_p_1t, dim=-1, index=x_1.unsqueeze(-1))
# 提取 p_{1|t}(x_t|x_t)
p_1t = torch.exp(log_p_1t)
p_1t_xt = torch.gather(p_1t, dim=-1, index=x_t.unsqueeze(-1))
# 计算跳跃系数 dκ_t / (1 - κ_t)
scheduler_output = self.path.scheduler(t)
jump_coefficient = scheduler_output.d_alpha_t / (1 - scheduler_output.alpha_t)
delta_x1_xt = (x_t == x_1).to(log_p_1t.dtype)
loss = -jump_coefficient * (
p_1t_xt - delta_x1_xt + (1 - delta_x1_xt) * log_p_1t_x1
)
return torch.mean(loss) # 默认 reduction='mean'4.3 黎曼 Flow Matching
黎曼 Flow Matching 将 Flow Matching 扩展到非欧几里得流形上,使用测地线插值代替线性插值。
核心公式
在流形 上,测地线概率路径定义为:
其中 和 分别是流形上的指数映射和对数映射, 是调度器参数。
对应代码
测地线路径在 flow_matching/path/geodesic.py 中实现:
class GeodesicProbPath(ProbPath):
def __init__(self, scheduler: ConvexScheduler, manifold: Manifold):
self.scheduler = scheduler
self.manifold = manifold
def sample(self, x_0, x_1, t) -> PathSample:
def cond_u(x_0, x_1, t):
# 构造测地线路径函数
path = geodesic(self.manifold, x_0, x_1)
# 用 JVP(Jacobian-Vector Product)自动计算速度
x_t, dx_t = jvp(
lambda t: path(self.scheduler(t).alpha_t),
(t,),
(torch.ones_like(t).to(t),),
)
return x_t, dx_t
# 使用 vmap 对 batch 维度进行向量化
x_t, dx_t = vmap(cond_u)(x_0, x_1, t)
return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)测地线工具函数(flow_matching/utils/manifolds/utils.py):
def geodesic(manifold, start_point, end_point):
"""生成参数化的测地线曲线函数"""
shooting_tangent_vec = manifold.logmap(start_point, end_point)
def path(t):
tangent_vecs = torch.einsum("i,...k->...ik", t, shooting_tangent_vec)
points_at_time_t = manifold.expmap(start_point.unsqueeze(-2), tangent_vecs)
return points_at_time_t
return path支持的流形
| 流形 | 类名 | 空间 | 指数映射 | 对数映射 |
|---|---|---|---|---|
| 欧几里得空间 | Euclidean | |||
| 超球面 | Sphere | 球面指数映射 | 球面对数映射 | |
| 平坦环面 | FlatTorus |
球面的指数映射实现(flow_matching/utils/manifolds/sphere.py):
class Sphere(Manifold):
def expmap(self, x, u):
norm_u = u.norm(dim=-1, keepdim=True)
exp = x * torch.cos(norm_u) + u * torch.sin(norm_u) / norm_u
retr = self.projx(x + u)
cond = norm_u > self.EPS[norm_u.dtype]
return torch.where(cond, exp, retr) # 数值稳定:小范数时退化为投影
def logmap(self, x, y):
u = self.proju(x, y - x)
dist = self.dist(x, y, keepdim=True)
cond = dist.gt(self.EPS[x.dtype])
return torch.where(cond, u * dist / u.norm(dim=-1, keepdim=True).clamp_min(self.EPS[x.dtype]), u)
def projx(self, x):
return x / x.norm(dim=-1, keepdim=True) # 投影到单位球面
def proju(self, x, u):
return u - (x * u).sum(dim=-1, keepdim=True) * x # 投影到切平面5. 核心模块详解
5.1 概率路径 (path)
概率路径是 Flow Matching 的核心抽象,定义了如何在源分布和目标分布之间构造插值。
类继承关系
ProbPath (抽象基类)
├── AffineProbPath (仿射路径)
│ └── CondOTProbPath (条件最优传输路径)
├── GeodesicProbPath (测地线路径)
└── MixtureDiscreteProbPath (离散混合路径)PathSample 数据结构
所有路径的 sample() 方法返回 PathSample 或 DiscretePathSample:
@dataclass
class PathSample:
x_1: Tensor # 目标样本 X_1
x_0: Tensor # 源样本 X_0
t: Tensor # 时间 t
x_t: Tensor # 路径样本 X_t ~ p_t
dx_t: Tensor # 条件速度 dX_t/dt
@dataclass
class DiscretePathSample:
x_1: Tensor # 目标样本 X_1
x_0: Tensor # 源样本 X_0
t: Tensor # 时间 t
x_t: Tensor # 路径样本 X_t ~ p_t(无速度,因为是离散空间)5.2 调度器 (scheduler)
调度器控制概率路径的时间演化参数 和 。
调度器类型一览
| 调度器 | 特点 | ||
|---|---|---|---|
CondOTScheduler | 最简单的线性插值,条件最优传输 | ||
PolynomialConvexScheduler(n) | 多项式调度, 控制曲线形状 | ||
VPScheduler | 方差保持调度(VP-SDE 等价) | ||
LinearVPScheduler | 线性方差保持 | ||
CosineScheduler | 余弦调度 |
所有调度器都实现了 snr_inverse 方法,用于从信噪比 反推时间 ,这在调度器变换中至关重要。
调度器变换(Schedule Transform)
ScheduleTransformedModel 允许在训练后更换调度器而无需重新训练模型。其核心是尺度-时间(ST)变换:
变换后的速度场为:
class ScheduleTransformedModel(ModelWrapper):
def forward(self, x, t, **extras):
r = t
# 新调度器参数
r_out = self.new_scheduler(t=r)
# 通过 SNR 反推原始时间
t = self.original_scheduler.snr_inverse(r_out.alpha_t / r_out.sigma_t)
# 原始调度器参数
t_out = self.original_scheduler(t=t)
# 计算尺度因子和时间导数
s_r = r_out.sigma_t / t_out.sigma_t
dt_r = ... # 时间映射的导数
ds_r = ... # 尺度因子的导数
# 变换速度场
u_t = self.model(x=x / s_r, t=t, **extras)
u_r = ds_r * x / s_r + dt_r * s_r * u_t
return u_r5.3 损失函数 (loss)
MixturePathGeneralizedKL
专为离散 Flow Matching 设计的广义 KL 散度损失。假设模型以 x-prediction 方式训练(即模型输出 )。
接口:
loss_fn = MixturePathGeneralizedKL(path=my_discrete_path, reduction='mean')
loss = loss_fn(logits=model_output, x_1=target, x_t=path_sample, t=time)参数说明:
logits:模型输出的 logits,形状(batch, d, K),其中 K 为词表大小x_1:目标数据,形状(batch, d)x_t:路径采样点,形状(batch, d)t:时间,形状(batch,)
对于连续 Flow Matching,通常直接使用 torch.nn.MSELoss 计算速度匹配损失。
5.4 求解器 (solver)
求解器负责在推理阶段从源分布生成目标分布的样本。
类继承关系
Solver (抽象基类, nn.Module)
├── ODESolver (连续 ODE 求解器)
├── MixtureDiscreteEulerSolver (离散 Euler 求解器)
└── RiemannianODESolver (黎曼 ODE 求解器)ODESolver
基于 torchdiffeq 的通用 ODE 求解器,支持多种数值方法。
solver = ODESolver(velocity_model=my_model)
# 基本采样:从 X_0 ~ p_0 生成 X_1 ~ p_1
x_1 = solver.sample(
x_init=x_0, # 初始噪声
step_size=1/1000, # 步长
method="euler", # 积分方法
time_grid=torch.tensor([0.0, 1.0]), # 时间区间
)支持的积分方法:
| 方法 | 类型 | 说明 |
|---|---|---|
euler | 固定步长 | 一阶 Euler 方法,最快但精度最低 |
midpoint | 固定步长 | 二阶中点法 |
heun3 | 固定步长 | 三阶 Heun 方法 |
dopri5 | 自适应步长 | 五阶 Dormand-Prince,精度高但较慢 |
似然计算
ODESolver.compute_likelihood() 通过反向积分 ODE 并计算雅可比行列式的迹来计算精确的 log 似然:
x_0, log_likelihood = solver.compute_likelihood(
x_1=data_samples,
log_p0=lambda x: -0.5 * x.pow(2).sum(-1), # 标准高斯的 log 概率
step_size=1/1000,
time_grid=torch.tensor([1.0, 0.0]), # 必须从 1 积到 0
exact_divergence=False, # 使用 Hutchinson 估计器
)散度计算有两种模式:
exact_divergence=True:精确计算 ,计算量为 次反向传播exact_divergence=False:Hutchinson 估计器 ,仅需 1 次反向传播
MixtureDiscreteEulerSolver
离散空间的 CTMC 模拟器,实现了带无散度项的 Euler 步进:
solver = MixtureDiscreteEulerSolver(
model=my_model,
path=my_discrete_path,
vocabulary_size=256,
source_distribution_p=uniform_dist, # 可选:用于无散度项
)
x_1 = solver.sample(
x_init=x_0,
step_size=1/1000,
div_free=0.0, # 无散度项系数,0 表示不使用
time_grid=torch.tensor([0.0, 1.0]),
)每步的核心逻辑:
- 从模型采样
- 计算条件速度
- 计算跳跃强度
- 以概率 发生跳跃,跳跃目标按 归一化后采样
RiemannianODESolver
流形上的 ODE 求解器,支持 Euler、中点法和 RK4,每步都可选择性地将状态投影回流形、将速度投影到切平面:
solver = RiemannianODESolver(
manifold=Sphere(),
velocity_model=my_model,
)
x_1 = solver.sample(
x_init=x_0,
step_size=0.01,
method="rk4", # euler / midpoint / rk4
projx=True, # 每步投影到流形
proju=True, # 速度投影到切平面
)RK4 步进的流形版本(_rk4_step):
def _rk4_step(velocity_model, xt, t0, dt, manifold, projx=True, proju=True):
velocity_fn = lambda x, t: (
manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t)
)
projx_fn = lambda x: manifold.projx(x) if projx else x
k1 = velocity_fn(xt, t0)
k2 = velocity_fn(projx_fn(xt + dt * k1 / 3), t0 + dt / 3)
k3 = velocity_fn(projx_fn(xt + dt * (k2 - k1 / 3)), t0 + dt * 2 / 3)
k4 = velocity_fn(projx_fn(xt + dt * (k1 - k2 + k3)), t0 + dt)
return projx_fn(xt + (k1 + 3 * (k2 + k3) + k4) * dt * 0.125)5.5 工具模块 (utils)
ModelWrapper
所有速度场模型必须继承 ModelWrapper,统一接口为 forward(x, t, **extras):
class ModelWrapper(ABC, nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
return self.model(x=x, t=t, **extras)自定义模型示例:
class MyVelocityModel(ModelWrapper):
def __init__(self, net):
super().__init__(net)
def forward(self, x, t, **extras):
# 可在此添加时间编码、条件拼接等自定义逻辑
t_embed = self.time_embedding(t)
return self.model(torch.cat([x, t_embed], dim=-1))张量工具函数
# expand_tensor_like: 将 1D 张量扩展到与目标张量相同的形状
# 用途:将 (batch_size,) 的时间/调度器参数扩展到 (batch_size, C, H, W) 等
alpha_t = expand_tensor_like(input_tensor=scheduler_output.alpha_t, expand_to=x_1)
# unsqueeze_to_match: 自动添加维度使源张量与目标张量维度匹配
t = unsqueeze_to_match(source=t, target=x_t)
# gradient: 计算梯度的封装,用于似然计算中的散度估计
grad = gradient(output, x, create_graph=True)categorical 采样器
基于 torch.multinomial 的分类采样器,支持任意形状的概率张量:
def categorical(probs: Tensor) -> Tensor:
return torch.multinomial(
probs.flatten(0, -2), 1, replacement=True
).view(*probs.shape[:-1])6. 使用指南与代码示例
6.1 连续 Flow Matching 完整训练流程
import torch
from flow_matching.path import AffineProbPath
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.solver import ODESolver
from flow_matching.utils import ModelWrapper
# 1. 定义模型
class SimpleVelocityNet(ModelWrapper):
def __init__(self, dim):
net = torch.nn.Sequential(
torch.nn.Linear(dim + 1, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, dim),
)
super().__init__(net)
def forward(self, x, t, **extras):
t_expanded = t.unsqueeze(-1) if t.dim() == 1 else t
inp = torch.cat([x, t_expanded], dim=-1)
return self.model(inp)
# 2. 初始化路径和模型
dim = 2
path = AffineProbPath(scheduler=CondOTScheduler())
model = SimpleVelocityNet(dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
mse_loss = torch.nn.MSELoss()
# 3. 训练循环
for epoch in range(1000):
# 源分布:标准高斯
x_0 = torch.randn(256, dim)
# 目标分布:你的数据
x_1 = sample_from_data(batch_size=256)
# 随机时间
t = torch.rand(256)
# 采样条件路径
path_sample = path.sample(x_0=x_0, x_1=x_1, t=t)
# 计算速度匹配损失
predicted_velocity = model(path_sample.x_t, path_sample.t)
loss = mse_loss(predicted_velocity, path_sample.dx_t)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 4. 推理:生成样本
solver = ODESolver(velocity_model=model)
x_0 = torch.randn(100, dim)
x_1 = solver.sample(
x_init=x_0,
step_size=0.01,
method="midpoint",
time_grid=torch.tensor([0.0, 1.0]),
)6.2 离散 Flow Matching 训练流程
import torch
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.loss import MixturePathGeneralizedKL
from flow_matching.solver import MixtureDiscreteEulerSolver
# 1. 初始化
vocab_size = 256
scheduler = PolynomialConvexScheduler(n=1.0)
path = MixtureDiscreteProbPath(scheduler=scheduler)
loss_fn = MixturePathGeneralizedKL(path=path)
# 2. 训练循环
for x_1 in dataloader: # x_1: (batch, seq_len) 的整数 token
x_0 = torch.randint(0, vocab_size, x_1.shape) # 均匀随机源
t = torch.rand(x_1.shape[0])
path_sample = path.sample(x_0=x_0, x_1=x_1, t=t)
# 模型输出 logits: (batch, seq_len, vocab_size)
logits = model(path_sample.x_t, path_sample.t)
loss = loss_fn(logits=logits, x_1=x_1, x_t=path_sample.x_t, t=t)
loss.backward()
optimizer.step()
# 3. 推理
solver = MixtureDiscreteEulerSolver(
model=model, path=path, vocabulary_size=vocab_size
)
x_0 = torch.randint(0, vocab_size, (batch_size, seq_len))
x_1 = solver.sample(x_init=x_0, step_size=1/1000)6.3 黎曼 Flow Matching(球面)
import torch
from flow_matching.path import GeodesicProbPath
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.solver import RiemannianODESolver
from flow_matching.utils.manifolds import Sphere
# 1. 初始化
manifold = Sphere()
scheduler = CondOTScheduler()
path = GeodesicProbPath(scheduler=scheduler, manifold=manifold)
# 2. 训练
for x_1 in dataloader: # x_1: 球面上的数据点
x_0 = manifold.projx(torch.randn_like(x_1)) # 投影到球面
t = torch.rand(x_1.shape[0])
path_sample = path.sample(x_0=x_0, x_1=x_1, t=t)
predicted_velocity = model(path_sample.x_t, path_sample.t)
loss = mse_loss(predicted_velocity, path_sample.dx_t)
loss.backward()
# 3. 推理
solver = RiemannianODESolver(manifold=manifold, velocity_model=model)
x_0 = manifold.projx(torch.randn(100, 3))
x_1 = solver.sample(x_init=x_0, step_size=0.01, method="rk4")6.4 训练后更换调度器
from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
# 模型原本用 CondOT 调度器训练
original_scheduler = CondOTScheduler()
new_scheduler = CosineScheduler()
# 无需重新训练,直接包装模型
transformed_model = ScheduleTransformedModel(
velocity_model=trained_model,
original_scheduler=original_scheduler,
new_scheduler=new_scheduler,
)
# 用新调度器进行推理
solver = ODESolver(velocity_model=transformed_model)
x_1 = solver.sample(x_init=x_0, step_size=1/1000)6.5 计算精确 log 似然
import torch
from flow_matching.solver import ODESolver
solver = ODESolver(velocity_model=trained_model)
# 定义源分布的 log 概率(标准高斯)
log_p0 = lambda x: -0.5 * (x.pow(2).sum(-1) + x.shape[-1] * torch.log(torch.tensor(2 * torch.pi)))
# 计算数据点的 log 似然
x_0_recovered, log_likelihood = solver.compute_likelihood(
x_1=data_samples,
log_p0=log_p0,
step_size=1/1000,
time_grid=torch.tensor([1.0, 0.0]),
exact_divergence=False, # Hutchinson 估计器更高效
)
# log_likelihood: 每个样本的 log p_1(x_1)7. 依赖关系
核心库依赖
flow_matching
├── numpy # 数值计算基础
├── torch (≥2.1) # 张量计算、自动微分、神经网络
└── torchdiffeq # ODE 数值积分(odeint 接口)开发依赖
dev:
├── pre-commit # Git 提交前代码检查
├── black # 代码格式化
├── usort / ufmt # import 排序与格式化
├── flake8 # 代码风格检查
└── pydoclint # 文档字符串检查示例额外依赖
- 图像示例:
torchvision,submitit(分布式训练) - 文本示例:
hydra-core,wandb(实验管理) - Notebook:
matplotlib,jupyter,scikit-learn,tqdm
本文档基于
flow_matching v1.0.10源码生成。 论文引用:Lipman et al., “Flow Matching Guide and Code”, arXiv:2412.06264, 2024.