目录

第六章:推理与解码策略(TorchCode)

第六章:推理与解码策略

本章介绍语言模型生成文本时的三种核心解码算法:采样策略、束搜索和投机解码。


Top-k 和 Top-p 是控制语言模型生成多样性的采样策略。它们通过过滤低概率 token,在质量和多样性之间取得平衡。

1. Temperature 缩放: logits = logits / temperature
2. Top-k 过滤: 只保留概率最高的 k 个 token,其余设为 -inf
3. Top-p 过滤: 按概率降序排列,保留累积概率不超过 p 的 token
4. 从过滤后的分布中采样
参数效果极端值
temperature控制分布的"尖锐度"→0: 贪心(确定性);→∞: 均匀随机
top_k限制候选 token 数量k=1: 贪心;k=V: 无过滤
top_p限制累积概率质量p→0: 只选最高概率;p=1: 无过滤
  • temperature < 1:分布更尖锐,高概率 token 更突出 → 更确定、更保守
  • temperature > 1:分布更平坦,低概率 token 也有机会 → 更随机、更有创意
  • temperature = 1:原始分布
import torch

def sample_top_k_top_p(logits, top_k=0, top_p=1.0, temperature=1.0):
    # 1. Temperature 缩放
    logits = logits / temperature

    # 2. Top-k 过滤
    if top_k > 0:
        topk_vals, _ = logits.topk(top_k)
        threshold = topk_vals[-1]
        logits[logits < threshold] = float('-inf')

    # 3. Top-p (nucleus) 过滤
    if top_p < 1.0:
        sorted_logits, sorted_idx = logits.sort(descending=True)
        probs = torch.softmax(sorted_logits, dim=-1)
        cumsum = probs.cumsum(dim=-1)
        # 移除累积概率超过 p 的 token(保留第一个超过的)
        mask = cumsum - probs > top_p
        sorted_logits[mask] = float('-inf')
        # 恢复原始顺序
        logits = sorted_logits.scatter(0, sorted_idx, sorted_logits)

    # 4. 采样
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()

# 测试
logits = torch.tensor([1.0, 5.0, 2.0, 0.5])
print("top_k=1:", sample_top_k_top_p(logits.clone(), top_k=1))       # 总是 1
print("temp=0.01:", sample_top_k_top_p(logits.clone(), temperature=0.01))  # 几乎总是 1
  • 代码生成:temperature=0.2, top_p=0.95(低随机性)
  • 创意写作:temperature=0.8, top_p=0.95(高多样性)
  • 对话:temperature=0.7, top_k=50, top_p=0.9

Beam Search 是一种确定性的序列搜索算法,维护 beam_width 个最优候选序列,每步扩展所有候选并保留得分最高的。

初始化: beams = [(0.0, [start_token])]

每一步:
  candidates = []
  对每个 beam (score, sequence):
    获取下一步的 log 概率分布
    对 top-beam_width 个 token:
      candidates.append((score + log_prob, sequence + [token]))
  beams = top beam_width candidates by score

终止条件:
  最优 beam 以 eos_token 结尾,或达到 max_len
  • 贪心搜索:每步只选最优 token(beam_width=1)
  • Beam Search:保留多个候选,可能找到全局更优的序列
  • 代价:计算量 × beam_width
import torch

def beam_search(log_prob_fn, start_token, max_len, beam_width, eos_token):
    beams = [(0.0, [start_token])]

    for _ in range(max_len):
        candidates = []
        for score, seq in beams:
            if seq[-1] == eos_token:
                candidates.append((score, seq))
                continue
            log_probs = log_prob_fn(seq)  # (V,)
            topk_vals, topk_idx = log_probs.topk(beam_width)
            for val, idx in zip(topk_vals, topk_idx):
                candidates.append((score + val.item(), seq + [idx.item()]))

        # 保留 top beam_width
        candidates.sort(key=lambda x: x[0], reverse=True)
        beams = candidates[:beam_width]

        # 最优 beam 已结束
        if beams[0][1][-1] == eos_token:
            break

    return beams[0][1]

# 测试
def simple_fn(tokens):
    lp = torch.full((5,), -10.0)
    lp[min(len(tokens), 4)] = 0.0
    return lp

seq = beam_search(simple_fn, start_token=0, max_len=5, beam_width=2, eos_token=4)
print("序列:", seq)  # [0, 1, 2, 3, 4]
  • 机器翻译(需要高质量输出)
  • 语音识别
  • 不适合开放式生成(倾向于生成重复、无聊的文本)

利用一个小型"草稿模型"(draft model)快速生成多个候选 token,然后用大型"目标模型"(target model)并行验证。被接受的 token 无需重新计算,从而加速推理。

  • 大模型推理慢(受内存带宽限制),但一次验证 K 个 token 的成本与生成 1 个 token 相近
  • 小模型推理快,可以快速"猜测"接下来的 token
  • 如果猜对了,等于免费获得了多个 token
对每个位置 i = 0, ..., K-1:
  ratio = target_probs[i, token_i] / draft_probs[i, token_i]
  以概率 min(1, ratio) 接受
  如果拒绝:
    从 normalize(max(0, target - draft)) 采样一个修正 token
    追加到结果,停止
如果全部接受:
  返回所有 K 个 token

这个接受/拒绝方案保证最终的 token 分布与直接从目标模型采样完全一致(不是近似)。

import torch

def speculative_decode(target_probs, draft_probs, draft_tokens):
    K = len(draft_tokens)
    accepted = []

    for i in range(K):
        token = draft_tokens[i].item()
        ratio = target_probs[i, token] / draft_probs[i, token]
        accept_prob = min(1.0, ratio.item())

        if torch.rand(1).item() < accept_prob:
            accepted.append(token)
        else:
            # 从修正分布采样
            residual = torch.clamp(target_probs[i] - draft_probs[i], min=0)
            if residual.sum() > 0:
                residual = residual / residual.sum()
                correction = torch.multinomial(residual, 1).item()
            else:
                correction = torch.multinomial(target_probs[i], 1).item()
            accepted.append(correction)
            break

    return accepted

# 测试:完美草稿(target == draft)应全部接受
torch.manual_seed(0)
probs = torch.softmax(torch.randn(4, 10), dim=-1)
tokens = torch.tensor([2, 5, 1, 8])
result = speculative_decode(probs, probs, tokens)
print("完美草稿:", result)  # 应该是 [2, 5, 1, 8](全部接受)

# 随机草稿:可能部分拒绝
target = torch.softmax(torch.randn(4, 10), dim=-1)
draft = torch.softmax(torch.randn(4, 10), dim=-1)
result = speculative_decode(target, draft, tokens)
print("随机草稿:", result)  # 长度 1-4
  • 草稿模型越接近目标模型,接受率越高,加速越明显
  • 典型加速:2-3x(取决于草稿模型质量和 K 值)
  • 无质量损失(输出分布与目标模型完全一致)
  • LLM 推理加速(如 Llama 70B + Llama 7B 作为草稿)
  • 任何自回归生成场景

相关内容