第六章:推理与解码策略(TorchCode)
系列 - TorchCode 系列
目录
第六章:推理与解码策略
本章介绍语言模型生成文本时的三种核心解码算法:采样策略、束搜索和投机解码。
6.1 Top-k / Top-p (Nucleus) Sampling
是什么
Top-k 和 Top-p 是控制语言模型生成多样性的采样策略。它们通过过滤低概率 token,在质量和多样性之间取得平衡。
算法步骤
1. Temperature 缩放: logits = logits / temperature
2. Top-k 过滤: 只保留概率最高的 k 个 token,其余设为 -inf
3. Top-p 过滤: 按概率降序排列,保留累积概率不超过 p 的 token
4. 从过滤后的分布中采样各参数的作用
| 参数 | 效果 | 极端值 |
|---|---|---|
temperature | 控制分布的"尖锐度" | →0: 贪心(确定性);→∞: 均匀随机 |
top_k | 限制候选 token 数量 | k=1: 贪心;k=V: 无过滤 |
top_p | 限制累积概率质量 | p→0: 只选最高概率;p=1: 无过滤 |
Temperature 的直觉
- temperature < 1:分布更尖锐,高概率 token 更突出 → 更确定、更保守
- temperature > 1:分布更平坦,低概率 token 也有机会 → 更随机、更有创意
- temperature = 1:原始分布
代码示例
import torch
def sample_top_k_top_p(logits, top_k=0, top_p=1.0, temperature=1.0):
# 1. Temperature 缩放
logits = logits / temperature
# 2. Top-k 过滤
if top_k > 0:
topk_vals, _ = logits.topk(top_k)
threshold = topk_vals[-1]
logits[logits < threshold] = float('-inf')
# 3. Top-p (nucleus) 过滤
if top_p < 1.0:
sorted_logits, sorted_idx = logits.sort(descending=True)
probs = torch.softmax(sorted_logits, dim=-1)
cumsum = probs.cumsum(dim=-1)
# 移除累积概率超过 p 的 token(保留第一个超过的)
mask = cumsum - probs > top_p
sorted_logits[mask] = float('-inf')
# 恢复原始顺序
logits = sorted_logits.scatter(0, sorted_idx, sorted_logits)
# 4. 采样
probs = torch.softmax(logits, dim=-1)
return torch.multinomial(probs, 1).item()
# 测试
logits = torch.tensor([1.0, 5.0, 2.0, 0.5])
print("top_k=1:", sample_top_k_top_p(logits.clone(), top_k=1)) # 总是 1
print("temp=0.01:", sample_top_k_top_p(logits.clone(), temperature=0.01)) # 几乎总是 1实践建议
- 代码生成:temperature=0.2, top_p=0.95(低随机性)
- 创意写作:temperature=0.8, top_p=0.95(高多样性)
- 对话:temperature=0.7, top_k=50, top_p=0.9
6.2 Beam Search(束搜索)
是什么
Beam Search 是一种确定性的序列搜索算法,维护 beam_width 个最优候选序列,每步扩展所有候选并保留得分最高的。
算法
初始化: beams = [(0.0, [start_token])]
每一步:
candidates = []
对每个 beam (score, sequence):
获取下一步的 log 概率分布
对 top-beam_width 个 token:
candidates.append((score + log_prob, sequence + [token]))
beams = top beam_width candidates by score
终止条件:
最优 beam 以 eos_token 结尾,或达到 max_len与贪心搜索的区别
- 贪心搜索:每步只选最优 token(beam_width=1)
- Beam Search:保留多个候选,可能找到全局更优的序列
- 代价:计算量 × beam_width
代码示例
import torch
def beam_search(log_prob_fn, start_token, max_len, beam_width, eos_token):
beams = [(0.0, [start_token])]
for _ in range(max_len):
candidates = []
for score, seq in beams:
if seq[-1] == eos_token:
candidates.append((score, seq))
continue
log_probs = log_prob_fn(seq) # (V,)
topk_vals, topk_idx = log_probs.topk(beam_width)
for val, idx in zip(topk_vals, topk_idx):
candidates.append((score + val.item(), seq + [idx.item()]))
# 保留 top beam_width
candidates.sort(key=lambda x: x[0], reverse=True)
beams = candidates[:beam_width]
# 最优 beam 已结束
if beams[0][1][-1] == eos_token:
break
return beams[0][1]
# 测试
def simple_fn(tokens):
lp = torch.full((5,), -10.0)
lp[min(len(tokens), 4)] = 0.0
return lp
seq = beam_search(simple_fn, start_token=0, max_len=5, beam_width=2, eos_token=4)
print("序列:", seq) # [0, 1, 2, 3, 4]适用场景
- 机器翻译(需要高质量输出)
- 语音识别
- 不适合开放式生成(倾向于生成重复、无聊的文本)
6.3 Speculative Decoding(投机解码)
是什么
利用一个小型"草稿模型"(draft model)快速生成多个候选 token,然后用大型"目标模型"(target model)并行验证。被接受的 token 无需重新计算,从而加速推理。
核心思想
- 大模型推理慢(受内存带宽限制),但一次验证 K 个 token 的成本与生成 1 个 token 相近
- 小模型推理快,可以快速"猜测"接下来的 token
- 如果猜对了,等于免费获得了多个 token
接受/拒绝算法
对每个位置 i = 0, ..., K-1:
ratio = target_probs[i, token_i] / draft_probs[i, token_i]
以概率 min(1, ratio) 接受
如果拒绝:
从 normalize(max(0, target - draft)) 采样一个修正 token
追加到结果,停止
如果全部接受:
返回所有 K 个 token数学保证
这个接受/拒绝方案保证最终的 token 分布与直接从目标模型采样完全一致(不是近似)。
代码示例
import torch
def speculative_decode(target_probs, draft_probs, draft_tokens):
K = len(draft_tokens)
accepted = []
for i in range(K):
token = draft_tokens[i].item()
ratio = target_probs[i, token] / draft_probs[i, token]
accept_prob = min(1.0, ratio.item())
if torch.rand(1).item() < accept_prob:
accepted.append(token)
else:
# 从修正分布采样
residual = torch.clamp(target_probs[i] - draft_probs[i], min=0)
if residual.sum() > 0:
residual = residual / residual.sum()
correction = torch.multinomial(residual, 1).item()
else:
correction = torch.multinomial(target_probs[i], 1).item()
accepted.append(correction)
break
return accepted
# 测试:完美草稿(target == draft)应全部接受
torch.manual_seed(0)
probs = torch.softmax(torch.randn(4, 10), dim=-1)
tokens = torch.tensor([2, 5, 1, 8])
result = speculative_decode(probs, probs, tokens)
print("完美草稿:", result) # 应该是 [2, 5, 1, 8](全部接受)
# 随机草稿:可能部分拒绝
target = torch.softmax(torch.randn(4, 10), dim=-1)
draft = torch.softmax(torch.randn(4, 10), dim=-1)
result = speculative_decode(target, draft, tokens)
print("随机草稿:", result) # 长度 1-4加速效果
- 草稿模型越接近目标模型,接受率越高,加速越明显
- 典型加速:2-3x(取决于草稿模型质量和 K 值)
- 无质量损失(输出分布与目标模型完全一致)
适用场景
- LLM 推理加速(如 Llama 70B + Llama 7B 作为草稿)
- 任何自回归生成场景