Flow Matching Guide and Code 第5章解读:FlatTorus Riemannian Flow Matching 训练逻辑技术文档
系列 -
目录
本文档详细说明 examples/2d_riemannian_flow_matching_flat_torus.ipynb 中训练块(超参数设置 + 单步训练循环)的代码逻辑、数学原理与函数调用链。
1. 总览
训练目标:在平坦环面 上学习一个速度场 ,使得从先验 沿 ODE 积分到 时,得到数据分布 。
- 概率路径:测地线路径,, 由调度器给出(本示例中
CondOTScheduler取 )。 - 损失:Flow matching 的 L2 损失,拟合条件速度 。
2. 代码结构概览
2.0 模块与对象层次
| 层级 | 名称 | 说明 |
|---|---|---|
| 流形 | manifold | FlatTorus(),定义 上的 projx / proju / expmap / logmap |
| 调度器 | scheduler | CondOTScheduler(),提供 、 等 |
| 概率路径 | path | GeodesicProbPath(scheduler, manifold),封装测地线路径与 sample(t, x_0, x_1) |
| 速度场 | vf | ProjectToTangent(MLP, manifold),训练与推理时统一入口 vf(x, t) |
| 优化器 | optim | torch.optim.Adam(vf.parameters(), lr),只更新 vf 参数 |
训练循环入口:单步顺序为
采样端点 → wrap 上流形 → 采样 → path.sample 得 → vf(x_t, t) 与 dx_t 算 L2 损失 → backward / step。
数据流:inf_train_gen / randn_like → wrap(manifold, ·) → path.sample → PathSample(x_t, dx_t, ...) → vf(x_t, t) 与 dx_t 求差、均方 → loss。
3. 超参数与对象初始化
3.1 超参数
| 变量 | 含义 | 本示例取值 |
|---|---|---|
lr | 学习率 | 0.001 |
batch_size | 每步采样的 对数 | 4096 |
iterations | 总迭代步数 | 5001 |
print_every | 打印间隔 | 1000 |
manifold | 流形实例 | FlatTorus(),表示 |
dim | 流形维度 | 2 |
hidden_dim | MLP 隐藏层维度 | 512 |
3.2 速度场模型 vf
vf = ProjectToTangent(MLP(...), manifold=manifold)
vf.to(device)- MLP:输入 (流形坐标 + 时间),输出 的向量(欧氏速度)。
- ProjectToTangent:对输入 做
manifold.projx(x),再对 MLP 输出做manifold.proju(x, v),保证输出是 处切空间中的向量。 - 调用链(推理/训练时):
vf(x, t)→ProjectToTangent.forward(x, t)- →
x' = manifold.projx(x)(环面:) - →
v = MLP(x', t) - →
v' = manifold.proju(x', v)(FlatTorus 上恒等) - → 返回
v'
3.3 概率路径 path
path = GeodesicProbPath(scheduler=CondOTScheduler(), manifold=manifold)- GeodesicProbPath:基于流形测地线的概率路径;需要
ConvexScheduler提供 。 - CondOTScheduler:,,,。
- path.sample(t, x_0, x_1) 的数学与实现见下文 §5。
3.4 优化器
optim = torch.optim.Adam(vf.parameters(), lr=lr)仅优化 vf(即 ProjectToTangent 内的 MLP)的参数。
4. 单步训练循环:逻辑与数据流
每一步迭代完成以下流程(顺序与代码一致)。
4.1 清空梯度
optim.zero_grad()为当前步的 loss.backward() 做准备。
4.2 从耦合 采样
x_1:数据端
- 调用
inf_train_gen(batch_size=batch_size, device=device) - 在平面区域上生成棋盘格状样本,形状
(batch_size, 2),数值约在 或类似范围。
- 调用
x_0:先验端
x_0 = torch.randn_like(x_1).to(device)- 即 (与 x_1 同 shape)。
4.3 将端点投影到流形
x_1 = wrap(manifold, x_1)
x_0 = wrap(manifold, x_0)- wrap(manifold, samples) 实现:
center = zeros_like(samples)return manifold.expmap(center, samples)
- 对 FlatTorus:
expmap(0, u) = u % (2π),因此x_1、x_0被映射到 ,保证路径两端都在流形上。
调用关系:wrap → manifold.expmap(0, samples) → FlatTorus 上等价于 samples % (2π)。
4.4 采样时间
t = torch.rand(x_1.shape[0]).to(device)- 每个样本一个 ,shape
(batch_size,)。
4.5 沿概率路径采样
path_sample = path.sample(t=t, x_0=x_0, x_1=x_1)- path:
GeodesicProbPath(scheduler=CondOTScheduler(), manifold=manifold) - path.sample(x_0, x_1, t)(即
GeodesicProbPath.sample(x_0, x_1, t))返回PathSample,包含:x_t:路径在时间 上的点,流形坐标dx_t:路径在 处对时间的导数(条件目标速度)x_0,x_1,t(同输入)
内部函数调用链(见 §5):
expand_tensor_like(t, x_1[..., 0:1]),使t与 batch 维度对齐。- 对每个 (通过
vmap):geodesic(manifold, x_0, x_1)→ 得到可调用对象path(τ),满足
,。scheduler(t)→SchedulerOutput(alpha_t=t, sigma_t=1-t, d_alpha_t=1, d_sigma_t=-1)(CondOT 下 )。path(alpha_t)=path(t)= 测地线在 处的点 。jvp:对 在 处求导(CondOT 下即对 在 求导),得到 与 (即dx_t)。
- 将
vmap结果 reshape 成与x_1相同 shape,构造PathSample(x_t, dx_t, x_1, x_0, t)并返回。
4.6 计算 Flow Matching L2 损失
loss = torch.pow(vf(path_sample.x_t, path_sample.t) - path_sample.dx_t, 2).mean()- path_sample.x_t:流形上点,坐标在 (与 FlatTorus 定义一致)。
- path_sample.t:与
x_t一一对应的时间,shape 与 batch 兼容。 - vf(path_sample.x_t, path_sample.t):
- 调用
ProjectToTangent.forward(x_t, t) - 内部:
projx(x_t)→ MLP →proju(x_t, v),输出切空间中的预测速度,shape 与x_t一致。
- 调用
- path_sample.dx_t:测地线在 处的真实条件速度 (目标)。
- 损失:,即 L2 拟合速度场。
4.7 反向传播与参数更新
loss.backward()
optim.step()- 梯度只流入
vf(MLP + ProjectToTangent 的 projx/proju 若可微则参与,FlatTorus 的 projx/proju 可微)。
4.8 日志
每隔 print_every 步打印当前迭代数、平均每步耗时和当前 loss.item()。
5. 关键函数调用说明
5.1 函数签名与角色速查
| 函数 / 方法 | 签名(要点) | 调用者 | 返回值 / 作用 |
|---|---|---|---|
inf_train_gen | (batch_size, device) | 训练循环 | 数据端样本 x_1,形状 (B, 2) |
wrap | (manifold, samples) | 训练循环 | 投影到流形后的样本, |
path.sample | (t, x_0, x_1) | 训练循环 | PathSample(x_t, dx_t, x_0, x_1, t) |
vf | (x, t) | 训练循环、推理 | 切空间速度向量,与 x 同 shape |
geodesic | (manifold, x_0, x_1) | GeodesicProbPath.sample 内 | 可调用 path(τ), |
scheduler(t) | (t) | GeodesicProbPath.sample 内 | SchedulerOutput(alpha_t, sigma_t, ...) |
manifold.projx | (x) | ProjectToTangent.forward | 投影到流形,FlatTorus 为 x % (2π) |
manifold.proju | (x, v) | ProjectToTangent.forward | 投影到切空间,FlatTorus 恒等 |
manifold.expmap | (x, u) | wrap、geodesic 返回的 path | 指数映射,FlatTorus 为 (x+u) % (2π) |
manifold.logmap | (x, y) | geodesic 内部 | 对数映射,FlatTorus 为 atan2(sin(y-x), cos(y-x)) |
5.2 geodesic(manifold, x_0, x_1)(flow_matching.utils.manifolds.utils)
- 作用:构造从 到 的测地线,参数化在 。
- 数学:
- (从 指向 的切向量)
- FlatTorus:
logmap(x, y) = atan2(sin(y-x), cos(y-x))(最短方向,考虑周期)expmap(x, u) = (x + u) % (2π)
- 返回:可调用对象
path(τ),输入τ的 shape 与 batch 兼容,输出流形上点的 shape 与x_0/x_1一致。
5.3 CondOTScheduler.__call__(t)
- 输入:
t,shape(batch_size,)或可广播。 - 输出:
SchedulerOutput:alpha_t = tsigma_t = 1 - td_alpha_t = 1d_sigma_t = -1
- 本示例中只用到
alpha_t,用于测地线参数:。
5.4 GeodesicProbPath.sample(x_0, x_1, t) 内部
- assert_sample_shape:检查
x_0, x_1, t的 batch 等维度兼容性。 - t 经
expand_tensor_like与x_1[..., 0:1]对齐,便于与x_0, x_1一起做vmap。 - cond_u(x_0, x_1, t)(单样本逻辑,再被 vmap):
path = geodesic(manifold, x_0, x_1)alpha_t = scheduler(t).alpha_t(CondOT 下即为t)x_t, dx_t = jvp(lambda t: path(scheduler(t).alpha_t), (t,), (1,))- 对 在 处求导,得到 和 ;由于 时 ,这里得到的就是 。
- vmap(cond_u)(x_0, x_1, t):对整批做上述计算。
- reshape_as(x_1):保证
x_t、dx_t的 shape 与x_1一致。 - 返回:
PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)。
5.5 vf(x, t) = ProjectToTangent.forward(x, t)
x = manifold.projx(x):FlatTorus 上为x % (2π),保证 在 。v = vecfield(x, t):MLP 在欧氏空间预测速度向量。v = manifold.proju(x, v):FlatTorus 上切空间即 ,故恒等。- 返回的
v即模型给出的切空间速度,与path_sample.dx_t的坐标与物理意义一致。
5.6 函数调用关系总览
训练循环 (每步)
├── inf_train_gen(batch_size, device) → x_1
├── torch.randn_like(x_1) → x_0
├── wrap(manifold, x_1), wrap(manifold, x_0)
│ └── manifold.expmap(0, samples)
├── torch.rand(...) → t
├── path.sample(t, x_0, x_1)
│ ├── expand_tensor_like(t, x_1)
│ └── vmap(cond_u)(x_0, x_1, t)
│ ├── geodesic(manifold, x_0, x_1) → path(τ),内部用 logmap/expmap
│ ├── scheduler(t) → alpha_t, sigma_t, ...
│ └── jvp(path(scheduler(·).alpha_t), t, 1) → x_t, dx_t
├── vf(path_sample.x_t, path_sample.t)
│ ├── manifold.projx(x)
│ ├── MLP(x, t)
│ └── manifold.proju(x, v)
└── loss = mean((vf(...) - dx_t)²); loss.backward(); optim.step()| 调用方向 | 说明 |
|---|---|
训练循环 → wrap | 将端点投影到 |
训练循环 → path.sample | 得到路径上点与条件速度 |
训练循环 → vf | 预测速度,与 dx_t 求 L2 损失 |
path.sample → geodesic | 构造单条测地线可调用对象 |
path.sample → scheduler | 取 与路径参数化 |
path.sample → jvp | 对路径在 处求导得 |
vf → manifold.projx / proju | 保证输入输出在流形与切空间上 |
6. 数据与坐标约定
流形坐标:FlatTorus 全程使用 。
inf_train_gen生成的数据经wrap后落入该域;x_0经wrap后也在该域;path.sample得到的x_t、dx_t以及vf的输入/输出均在该坐标下。
时间:, 对应先验端, 对应数据端;CondOT 下测地线参数 ,故 即为从 到 的测地线在 处的点。
7. 小结:单步训练的数据与调用顺序
inf_train_gen(batch_size, device) → x_1 (平面棋盘格)
torch.randn_like(x_1) → x_0 (高斯)
wrap(manifold, x_1), wrap(manifold, x_0) → x_1, x_0 ∈ [0,2π)²
torch.rand(...) → t ∈ [0,1]
path.sample(t, x_0, x_1)
→ geodesic(manifold, x_0, x_1) → path(τ)
→ scheduler(t).alpha_t → α_t = t
→ jvp(path(α_t), t, 1) → x_t, dx_t
→ PathSample(x_t, dx_t, x_1, x_0, t)
vf(path_sample.x_t, path_sample.t)
→ projx(x_t) → MLP(x_t, t) → proju(x_t, v) → v_θ(x_t, t)
loss = mean(|v_θ(x_t,t) - dx_t|²)
loss.backward(); optim.step()以上即平坦环面示例中训练块的完整逻辑与函数调用说明。