目录

Flow Matching Guide and Code 第5章解读:FlatTorus Riemannian Flow Matching 训练逻辑技术文档

本文档详细说明 examples/2d_riemannian_flow_matching_flat_torus.ipynb训练块(超参数设置 + 单步训练循环)的代码逻辑、数学原理与函数调用链。


训练目标:在平坦环面 M=[0,2π)2M = [0, 2\pi)^2 上学习一个速度场 vθ(x,t)v_\theta(x, t),使得从先验 p0p_0 沿 ODE x˙=vθ(x,t)\dot{x} = v_\theta(x, t) 积分到 t=1t=1 时,得到数据分布 p1p_1

  • 概率路径:测地线路径,Xt=expX0(αtlogX0(X1))X_t = \exp_{X_0}(\alpha_t \cdot \log_{X_0}(X_1))αt\alpha_t 由调度器给出(本示例中 CondOTSchedulerαt=t\alpha_t = t)。
  • 损失:Flow matching 的 L2 损失,拟合条件速度 X˙t\dot{X}_t

层级名称说明
流形manifoldFlatTorus(),定义 [0,2π)2[0,2\pi)^2 上的 projx / proju / expmap / logmap
调度器schedulerCondOTScheduler(),提供 αt=t\alpha_t=tσt=1t\sigma_t=1-t
概率路径pathGeodesicProbPath(scheduler, manifold),封装测地线路径与 sample(t, x_0, x_1)
速度场vfProjectToTangent(MLP, manifold),训练与推理时统一入口 vf(x, t)
优化器optimtorch.optim.Adam(vf.parameters(), lr),只更新 vf 参数

训练循环入口:单步顺序为
采样端点 → wrap 上流形 → 采样 ttpath.sample(xt,x˙t)(x_t, \dot{x}_t)vf(x_t, t)dx_t 算 L2 损失 → backward / step

数据流inf_train_gen / randn_likewrap(manifold, ·)path.samplePathSample(x_t, dx_t, ...)vf(x_t, t)dx_t 求差、均方 → loss


变量含义本示例取值
lr学习率0.001
batch_size每步采样的 (X0,X1)(X_0, X_1) 对数4096
iterations总迭代步数5001
print_every打印间隔1000
manifold流形实例FlatTorus(),表示 [0,2π)2[0, 2\pi)^2
dim流形维度2
hidden_dimMLP 隐藏层维度512
vf = ProjectToTangent(MLP(...), manifold=manifold)
vf.to(device)
  • MLP:输入 (x,t)(x, t)(流形坐标 + 时间),输出 Rdim\mathbb{R}^{\mathrm{dim}} 的向量(欧氏速度)。
  • ProjectToTangent:对输入 xxmanifold.projx(x),再对 MLP 输出做 manifold.proju(x, v),保证输出是 xx 处切空间中的向量。
  • 调用链(推理/训练时):
    • vf(x, t)ProjectToTangent.forward(x, t)
    • x' = manifold.projx(x)(环面:x=xmod2πx' = x \bmod 2\pi
    • v = MLP(x', t)
    • v' = manifold.proju(x', v)(FlatTorus 上恒等)
    • → 返回 v'
path = GeodesicProbPath(scheduler=CondOTScheduler(), manifold=manifold)
  • GeodesicProbPath:基于流形测地线的概率路径;需要 ConvexScheduler 提供 αt\alpha_t
  • CondOTSchedulerαt=t\alpha_t = tσt=1t\sigma_t = 1 - tα˙t=1\dot{\alpha}_t = 1σ˙t=1\dot{\sigma}_t = -1
  • path.sample(t, x_0, x_1) 的数学与实现见下文 §5。
optim = torch.optim.Adam(vf.parameters(), lr=lr)

仅优化 vf(即 ProjectToTangent 内的 MLP)的参数。


每一步迭代完成以下流程(顺序与代码一致)。

optim.zero_grad()

为当前步的 loss.backward() 做准备。

  • x_1:数据端

    • 调用 inf_train_gen(batch_size=batch_size, device=device)
    • 在平面区域上生成棋盘格状样本,形状 (batch_size, 2),数值约在 [2,2]2[-2,2]^2 或类似范围。
  • x_0:先验端

    • x_0 = torch.randn_like(x_1).to(device)
    • X0N(0,I)X_0 \sim \mathcal{N}(0, I)(与 x_1 同 shape)。
x_1 = wrap(manifold, x_1)
x_0 = wrap(manifold, x_0)
  • wrap(manifold, samples) 实现:
    • center = zeros_like(samples)
    • return manifold.expmap(center, samples)
  • FlatTorusexpmap(0, u) = u % (2π),因此
    • x_1x_0 被映射到 [0,2π)2[0, 2\pi)^2,保证路径两端都在流形上。

调用关系
wrapmanifold.expmap(0, samples) → FlatTorus 上等价于 samples % (2π)

t = torch.rand(x_1.shape[0]).to(device)
  • 每个样本一个 t[0,1]t \in [0,1],shape (batch_size,)
path_sample = path.sample(t=t, x_0=x_0, x_1=x_1)
  • pathGeodesicProbPath(scheduler=CondOTScheduler(), manifold=manifold)
  • path.sample(x_0, x_1, t)(即 GeodesicProbPath.sample(x_0, x_1, t))返回 PathSample,包含:
    • x_t:路径在时间 tt 上的点,流形坐标 [0,2π)2[0,2\pi)^2
    • dx_t:路径在 tt 处对时间的导数(条件目标速度)
    • x_0, x_1, t(同输入)

内部函数调用链(见 §5):

  1. expand_tensor_like(t, x_1[..., 0:1]),使 t 与 batch 维度对齐。
  2. 对每个 (x0(i),x1(i),t(i))(x_0^{(i)}, x_1^{(i)}, t^{(i)})(通过 vmap):
    • geodesic(manifold, x_0, x_1) → 得到可调用对象 path(τ),满足
      path(τ)=expx0(τlogx0(x1))path(\tau) = \exp_{x_0}(\tau \cdot \log_{x_0}(x_1))τ[0,1]\tau \in [0,1]
    • scheduler(t)SchedulerOutput(alpha_t=t, sigma_t=1-t, d_alpha_t=1, d_sigma_t=-1)(CondOT 下 αt=t\alpha_t = t)。
    • path(alpha_t) = path(t) = 测地线在 tt 处的点 XtX_t
    • jvp:对 τpath(scheduler(τ).alpha_t)\tau \mapsto path(\mathrm{scheduler}(\tau).\mathrm{alpha\_t})τ=t\tau=t 处求导(CondOT 下即对 path(τ)path(\tau)τ=t\tau=t 求导),得到 XtX_ttXt=X˙t\frac{\partial}{\partial t}X_t = \dot{X}_t(即 dx_t)。
  3. vmap 结果 reshape 成与 x_1 相同 shape,构造 PathSample(x_t, dx_t, x_1, x_0, t) 并返回。
loss = torch.pow(vf(path_sample.x_t, path_sample.t) - path_sample.dx_t, 2).mean()
  • path_sample.x_t:流形上点,坐标在 [0,2π)2[0, 2\pi)^2(与 FlatTorus 定义一致)。
  • path_sample.t:与 x_t 一一对应的时间,shape 与 batch 兼容。
  • vf(path_sample.x_t, path_sample.t)
    • 调用 ProjectToTangent.forward(x_t, t)
    • 内部:projx(x_t) → MLP → proju(x_t, v),输出切空间中的预测速度,shape 与 x_t 一致。
  • path_sample.dx_t:测地线在 tt 处的真实条件速度 X˙t\dot{X}_t(目标)。
  • 损失L=E[vθ(Xt,t)X˙t2]\mathcal{L} = \mathbb{E}\bigl[\|v_\theta(X_t, t) - \dot{X}_t\|^2\bigr],即 L2 拟合速度场。
loss.backward()
optim.step()
  • 梯度只流入 vf(MLP + ProjectToTangent 的 projx/proju 若可微则参与,FlatTorus 的 projx/proju 可微)。

每隔 print_every 步打印当前迭代数、平均每步耗时和当前 loss.item()


函数 / 方法签名(要点)调用者返回值 / 作用
inf_train_gen(batch_size, device)训练循环数据端样本 x_1,形状 (B, 2)
wrap(manifold, samples)训练循环投影到流形后的样本,[0,2π)2[0,2\pi)^2
path.sample(t, x_0, x_1)训练循环PathSample(x_t, dx_t, x_0, x_1, t)
vf(x, t)训练循环、推理切空间速度向量,与 x 同 shape
geodesic(manifold, x_0, x_1)GeodesicProbPath.sample可调用 path(τ)τ[0,1]\tau\in[0,1]
scheduler(t)(t)GeodesicProbPath.sampleSchedulerOutput(alpha_t, sigma_t, ...)
manifold.projx(x)ProjectToTangent.forward投影到流形,FlatTorus 为 x % (2π)
manifold.proju(x, v)ProjectToTangent.forward投影到切空间,FlatTorus 恒等
manifold.expmap(x, u)wrapgeodesic 返回的 path指数映射,FlatTorus 为 (x+u) % (2π)
manifold.logmap(x, y)geodesic 内部对数映射,FlatTorus 为 atan2(sin(y-x), cos(y-x))
  • 作用:构造从 x0x_0x1x_1 的测地线,参数化在 τ[0,1]\tau \in [0,1]
  • 数学
    • u=logmap(x0,x1)u = \mathrm{logmap}(x_0, x_1)(从 x0x_0 指向 x1x_1 的切向量)
    • path(τ)=expmap(x0,τu)path(\tau) = \mathrm{expmap}(x_0, \tau \cdot u)
  • FlatTorus
    • logmap(x, y) = atan2(sin(y-x), cos(y-x))(最短方向,考虑周期)
    • expmap(x, u) = (x + u) % (2π)
  • 返回:可调用对象 path(τ),输入 τ 的 shape 与 batch 兼容,输出流形上点的 shape 与 x_0/x_1 一致。
  • 输入t,shape (batch_size,) 或可广播。
  • 输出SchedulerOutput
    • alpha_t = t
    • sigma_t = 1 - t
    • d_alpha_t = 1
    • d_sigma_t = -1
  • 本示例中只用到 alpha_t,用于测地线参数:Xt=path(αt)=path(t)X_t = path(\alpha_t) = path(t)
  • assert_sample_shape:检查 x_0, x_1, t 的 batch 等维度兼容性。
  • texpand_tensor_likex_1[..., 0:1] 对齐,便于与 x_0, x_1 一起做 vmap
  • cond_u(x_0, x_1, t)(单样本逻辑,再被 vmap):
    1. path = geodesic(manifold, x_0, x_1)
    2. alpha_t = scheduler(t).alpha_t(CondOT 下即为 t
    3. x_t, dx_t = jvp(lambda t: path(scheduler(t).alpha_t), (t,), (1,))
      • τpath(τ)\tau \mapsto path(\tau)τ=t\tau = t 处求导,得到 path(t)path(t)ddτpath(τ)τ=t\frac{d}{d\tau}path(\tau)\big|_{\tau=t};由于 τ=t\tau=tddtαt=1\frac{d}{dt}\alpha_t = 1,这里得到的就是 X˙t\dot{X}_t
  • vmap(cond_u)(x_0, x_1, t):对整批做上述计算。
  • reshape_as(x_1):保证 x_tdx_t 的 shape 与 x_1 一致。
  • 返回PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)
  • x = manifold.projx(x):FlatTorus 上为 x % (2π),保证 xx[0,2π)2[0, 2\pi)^2
  • v = vecfield(x, t):MLP 在欧氏空间预测速度向量。
  • v = manifold.proju(x, v):FlatTorus 上切空间即 R2\mathbb{R}^2,故恒等。
  • 返回的 v 即模型给出的切空间速度,与 path_sample.dx_t 的坐标与物理意义一致。
训练循环 (每步)
├── inf_train_gen(batch_size, device)     → x_1
├── torch.randn_like(x_1)                 → x_0
├── wrap(manifold, x_1), wrap(manifold, x_0)
│   └── manifold.expmap(0, samples)
├── torch.rand(...)                       → t
├── path.sample(t, x_0, x_1)
│   ├── expand_tensor_like(t, x_1)
│   └── vmap(cond_u)(x_0, x_1, t)
│       ├── geodesic(manifold, x_0, x_1)  → path(τ),内部用 logmap/expmap
│       ├── scheduler(t)                  → alpha_t, sigma_t, ...
│       └── jvp(path(scheduler(·).alpha_t), t, 1) → x_t, dx_t
├── vf(path_sample.x_t, path_sample.t)
│   ├── manifold.projx(x)
│   ├── MLP(x, t)
│   └── manifold.proju(x, v)
└── loss = mean((vf(...) - dx_t)²); loss.backward(); optim.step()
调用方向说明
训练循环 → wrap将端点投影到 [0,2π)2[0,2\pi)^2
训练循环 → path.sample得到路径上点与条件速度 (xt,x˙t)(x_t, \dot{x}_t)
训练循环 → vf预测速度,与 dx_t 求 L2 损失
path.samplegeodesic构造单条测地线可调用对象
path.sampleschedulerαt\alpha_t 与路径参数化
path.samplejvp对路径在 tt 处求导得 X˙t\dot{X}_t
vfmanifold.projx / proju保证输入输出在流形与切空间上

  • 流形坐标:FlatTorus 全程使用 [0,2π)2[0, 2\pi)^2

    • inf_train_gen 生成的数据经 wrap 后落入该域;
    • x_0wrap 后也在该域;
    • path.sample 得到的 x_tdx_t 以及 vf 的输入/输出均在该坐标下。
  • 时间t[0,1]t \in [0,1]t=0t=0 对应先验端,t=1t=1 对应数据端;CondOT 下测地线参数 αt=t\alpha_t = t,故 XtX_t 即为从 x0x_0x1x_1 的测地线在 tt 处的点。


inf_train_gen(batch_size, device)  →  x_1 (平面棋盘格)
torch.randn_like(x_1)              →  x_0 (高斯)
wrap(manifold, x_1), wrap(manifold, x_0)  →  x_1, x_0 ∈ [0,2π)²
torch.rand(...)                    →  t ∈ [0,1]

path.sample(t, x_0, x_1)
  → geodesic(manifold, x_0, x_1)   →  path(τ)
  → scheduler(t).alpha_t           →  α_t = t
  → jvp(path(α_t), t, 1)           →  x_t, dx_t
  → PathSample(x_t, dx_t, x_1, x_0, t)

vf(path_sample.x_t, path_sample.t)
  → projx(x_t) → MLP(x_t, t) → proju(x_t, v)  →  v_θ(x_t, t)

loss = mean(|v_θ(x_t,t) - dx_t|²)
loss.backward(); optim.step()

以上即平坦环面示例中训练块的完整逻辑与函数调用说明。

相关内容