
Chapter 3: Attention Mechanisms (TorchCode)


Attention is the core of the Transformer. This chapter starts from basic Scaled Dot-Product Attention and works step by step through the major variants.


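Scaled Dot-Product Attention is the core computational unit of attention. Given queries (Q), keys (K), and values (V), it computes how relevant each query is to every key, then uses those weights to take a weighted sum of the values.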

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
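  1. Compute the similarity matrix: $S = QK^T$, with shape (B, seq_q, seq_k)
  2. Scale: $S = S / \sqrt{d_k}$, which keeps large dot products from saturating the softmax
  3. Normalize: apply softmax to each row to get the attention weights
  4. Weighted sum: $\text{output} = \text{weights} \times V$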

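If the components of $q$ and $k$ are independent with zero mean and unit variance, each entry of $QK^T$ has variance $d_k$, which is large when $d_k$ is large. Dividing by $\sqrt{d_k}$ brings the variance back to 1, so the softmax inputs do not grow so large that softmax saturates and its gradients vanish.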
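A quick numerical check of the variance argument, as a sketch that assumes i.i.d. standard-normal entries for q and k (d_k = 256 and the 10,000 samples are arbitrary choices):

```python
import torch

torch.manual_seed(0)
d_k = 256
q = torch.randn(10_000, d_k)  # 10,000 query vectors, entries ~ N(0, 1)
k = torch.randn(10_000, d_k)  # 10,000 key vectors, entries ~ N(0, 1)

dots = (q * k).sum(dim=-1)    # one dot product per (q, k) pair
scaled = dots / d_k ** 0.5    # scaled as in the formula above

print(f"var(q.k)           = {dots.var().item():.1f}")    # close to d_k = 256
print(f"var(q.k / sqrt(d)) = {scaled.var().item():.3f}")  # close to 1
```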

import torch
import math

def scaled_dot_product_attention(Q, K, V):
    # Q: (B, seq_q, d_k), K: (B, seq_k, d_k), V: (B, seq_k, d_v)
    d_k = Q.shape[-1]
    scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)
    return weights @ V  # (B, seq_q, d_v)

# Self-attention
Q = K = V = torch.randn(2, 6, 64)
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # (2, 6, 64)

# Cross-attention (seq_q != seq_k)
Q2 = torch.randn(1, 3, 16)
K2 = torch.randn(1, 10, 16)
V2 = torch.randn(1, 10, 32)
out2 = scaled_dot_product_attention(Q2, K2, V2)
print(out2.shape)  # (1, 3, 32)
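  • The foundation of every Transformer variant
  • Covers both self-attention (Q, K and V are all derived from the same input) and cross-attention (Q comes from one input, K/V from another)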
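The hand-written function can be sanity-checked against PyTorch's built-in torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0 and later), whose default scale is also 1/sqrt(d_k). A sketch with 4-D (batch, heads, seq, dim) inputs and a value dimension that differs from d_k:

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Same computation as the hand-written version above
    d_k = Q.shape[-1]
    scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k)
    return torch.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
Q = torch.randn(2, 4, 3, 16)   # (B, heads, seq_q, d_k)
K = torch.randn(2, 4, 10, 16)  # (B, heads, seq_k, d_k)
V = torch.randn(2, 4, 10, 32)  # (B, heads, seq_k, d_v)

ours = scaled_dot_product_attention(Q, K, V)
ref = F.scaled_dot_product_attention(Q, K, V)  # default scale: 1 / sqrt(Q.shape[-1])
print(ours.shape, torch.allclose(ours, ref, atol=1e-5))  # torch.Size([2, 4, 3, 32]) True
```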

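Multi-Head Attention projects the input into several subspaces ("heads"), computes attention independently in each one, and finally concatenates the results and projects them back to the original dimension.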

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
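  • A single head produces only one attention distribution per query, so it can capture only one attention pattern at a time
  • Multiple heads let the model attend to different features at different positions simultaneously (e.g. syntactic, semantic, or positional relationships)
  • Each head has dimension $d_k = d_{model} / h$, so the total computation is about the same as one head of full width $d_{model}$
  • W_q, W_k, W_v: projection matrices, shape (d_model, d_model)
  • W_o: output projection, shape (d_model, d_model)
  • d_k = d_model // num_heads: the dimension of each head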
import torch
import torch.nn as nn
import math

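class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.dk = d_model // num_heads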
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V):
        B, S_q, _ = Q.shape
        S_k = K.shape[1]

        # 1. Linear projections
        q = self.W_q(Q)  # (B, S_q, d_model)
        k = self.W_k(K)  # (B, S_k, d_model)
        v = self.W_v(V)  # (B, S_k, d_model)

        # 2. Split into heads: (B, S, d_model) → (B, num_heads, S, dk)
        q = q.view(B, S_q, self.num_heads, self.dk).transpose(1, 2)
        k = k.view(B, S_k, self.num_heads, self.dk).transpose(1, 2)
        v = v.view(B, S_k, self.num_heads, self.dk).transpose(1, 2)

        # 3. Compute attention independently in each head
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.dk)
        weights = torch.softmax(scores, dim=-1)
        attn = weights @ v  # (B, num_heads, S_q, dk)

        # 4. Concatenate all heads: (B, S_q, d_model)
        out = attn.transpose(1, 2).contiguous().view(B, S_q, -1)

        # 5. Output projection
        return self.W_o(out)

# Test
mha = MultiHeadAttention(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)
print(mha.forward(x, x, x).shape)  # (2, 10, 64)
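The head split in step 2 and the merge in step 4 are pure reshapes, so no information is lost, and choosing d_k = d_model / h keeps the score computation as cheap as a single full-width head. A small sketch checking both points (the sizes are arbitrary):

```python
import torch

B, S, d_model, num_heads = 2, 10, 64, 8
d_k = d_model // num_heads
x = torch.randn(B, S, d_model)

# Split: (B, S, d_model) -> (B, num_heads, S, d_k), as in MultiHeadAttention.forward
heads = x.view(B, S, num_heads, d_k).transpose(1, 2)
# Merge: (B, num_heads, S, d_k) -> (B, S, d_model)
merged = heads.transpose(1, 2).contiguous().view(B, S, d_model)
print(torch.equal(merged, x))  # True: split followed by merge is lossless

# Multiply-adds for QK^T: h heads of width d_k vs. one head of width d_model
multi_head_macs = num_heads * S * S * d_k
single_head_macs = S * S * d_model
print(multi_head_macs, single_head_macs)  # 6400 6400
```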

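Causal attention adds a causal mask on top of standard attention so that each position can see only itself and earlier positions, never "peeking" at future tokens. It is the core of autoregressive language models such as GPT.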

$$\text{scores}_{ij} = \begin{cases} \frac{Q_i \cdot K_j}{\sqrt{d_k}} & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}$$

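Setting the scores of future positions to $-\infty$ makes their weights 0 after softmax.

Implementation: build the mask from an upper-triangular matrix (torch.triu) filled with $-\infty$.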

import torch
import math

def causal_attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-1, -2) / math.sqrt(d_k)
    # Upper-triangular mask: -inf above the diagonal, 0 on and below it
    mask = torch.triu(torch.full_like(scores, float('-inf')), diagonal=1)
    scores = scores + mask
    weights = torch.softmax(scores, dim=-1)
    return weights @ V

# Test
Q = K = V = torch.randn(1, 4, 8)
out = causal_attention(Q, K, V)
# Position 0 can only see itself, so its output equals V[0]
print(torch.allclose(out[:, 0], V[:, 0], atol=1e-5))  # True
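Additive mask for a length-4 sequence (rows are query positions, columns are key positions):
pos:   0  1  2  3
  0  [ 0 -∞ -∞ -∞ ]   position 0 sees only itself
  1  [ 0  0 -∞ -∞ ]   position 1 sees 0, 1
  2  [ 0  0  0 -∞ ]   position 2 sees 0, 1, 2
  3  [ 0  0  0  0 ]   position 3 sees all positions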
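A sketch of two checks on the causal mask. The boolean-mask helper causal_attention_bool is hypothetical (an alternative to the additive -inf mask above, using masked_fill). Changing only later tokens must leave earlier outputs unchanged, and the result should match PyTorch's built-in is_causal=True path (PyTorch 2.0+):

```python
import math
import torch
import torch.nn.functional as F

def causal_attention_bool(Q, K, V):
    # Boolean mask: True marks future positions (j > i), which are blocked
    S_q, S_k = Q.shape[-2], K.shape[-2]
    scores = Q @ K.transpose(-1, -2) / math.sqrt(Q.shape[-1])
    future = torch.ones(S_q, S_k, dtype=torch.bool, device=Q.device).triu(diagonal=1)
    scores = scores.masked_fill(future, float('-inf'))
    return torch.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
x = torch.randn(1, 2, 5, 8)            # (B, heads, seq, dim)
y = x.clone()
y[..., 3:, :] = torch.randn(1, 2, 2, 8)  # change only the "future" positions 3 and 4

out_x = causal_attention_bool(x, x, x)
out_y = causal_attention_bool(y, y, y)
print(torch.allclose(out_x[..., :3, :], out_y[..., :3, :]))  # True: positions 0-2 unaffected

ref = F.scaled_dot_product_attention(x, x, x, is_causal=True)
print(torch.allclose(out_x, ref, atol=1e-5))  # True
```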

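Cross-attention: Q comes from one sequence (e.g. the decoder), while K and V come from another (e.g. the encoder). It is used in encoder-decoder architectures to let the decoder "query" information from the encoder.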

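  • Self-attention: Q, K and V are all derived from the same input
  • Cross-attention: Q comes from the decoder, K/V come from the encoder
  • Cross-attention needs no causal mask (every encoder position is visible to the decoder)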
import torch
import torch.nn as nn
import math

class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.dk = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x_q, x_kv):
        B, S_q, _ = x_q.shape
        S_kv = x_kv.shape[1]

        q = self.W_q(x_q).view(B, S_q, self.num_heads, self.dk).transpose(1, 2)
        k = self.W_k(x_kv).view(B, S_kv, self.num_heads, self.dk).transpose(1, 2)
        v = self.W_v(x_kv).view(B, S_kv, self.num_heads, self.dk).transpose(1, 2)

        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        weights = torch.softmax(scores, dim=-1)
        attn = (weights @ v).transpose(1, 2).contiguous().view(B, S_q, -1)
        return self.W_o(attn)

# Test: 6 decoder tokens query 10 encoder tokens
attn = MultiHeadCrossAttention(64, 4)
x_q = torch.randn(2, 6, 64)    # decoder
x_kv = torch.randn(2, 10, 64)  # encoder
print(attn(x_q, x_kv).shape)   # (2, 6, 64)
  • Transformer encoder-decoder models (machine translation, speech recognition)
  • Multimodal models (image features as K/V, text as Q)
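For comparison, PyTorch's built-in nn.MultiheadAttention performs the same kind of cross-attention when the query input differs from the key/value input. A sketch reusing the test shapes above; it has its own learned projections, so its outputs will not match MultiHeadCrossAttention numerically:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# batch_first=True makes inputs and outputs (B, S, d_model)
cross_attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)

x_q = torch.randn(2, 6, 64)    # decoder states (queries)
x_kv = torch.randn(2, 10, 64)  # encoder states (keys and values)

out, weights = cross_attn(x_q, x_kv, x_kv)  # arguments: query, key, value
print(out.shape)      # torch.Size([2, 6, 64])
print(weights.shape)  # torch.Size([2, 6, 10]), averaged over heads by default
print(torch.allclose(weights.sum(dim=-1), torch.ones(2, 6)))  # each query's weights sum to 1
```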

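Grouped-Query Attention (GQA) is an efficient variant of MHA. Q keeps the full number of heads, but K and V use fewer heads, with several Q heads sharing the same K/V head.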

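  • In standard MHA the KV cache grows in proportion to the number of heads, and it is a major memory bottleneck during inference
  • GQA reduces the number of KV heads, which shrinks the KV cache by the same factor
  • With num_kv_heads = 1 it reduces to Multi-Query Attention (MQA)
  • With num_kv_heads = num_heads it is equivalent to standard MHA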
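To make the memory argument concrete, here is the KV cache size for one sequence as simple arithmetic. The configuration (32 layers, head_dim 128, 4096 tokens, fp16 at 2 bytes per element) is a hypothetical example for illustration, not a figure from this chapter:

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    # K and V (factor 2), each of shape (batch, num_kv_heads, seq_len, head_dim) per layer
    return 2 * num_layers * batch * num_kv_heads * seq_len * head_dim * bytes_per_elem

# Hypothetical config with 32 query heads
cfg = dict(num_layers=32, head_dim=128, seq_len=4096)
mha = kv_cache_bytes(num_kv_heads=32, **cfg)  # MHA: one KV head per Q head
gqa = kv_cache_bytes(num_kv_heads=8, **cfg)   # GQA: 4 Q heads share each KV head
mqa = kv_cache_bytes(num_kv_heads=1, **cfg)   # MQA: a single shared KV head

for name, n in [("MHA", mha), ("GQA", gqa), ("MQA", mqa)]:
    print(f"{name}: {n / 2**20:.0f} MiB")
# MHA: 2048 MiB
# GQA: 512 MiB
# MQA: 64 MiB
```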

repeat_interleave expands the small set of KV heads to match the number of Q heads:

import torch
import torch.nn as nn
import math

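class GroupQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_kv_heads):
        super().__init__()
        assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.dk = d_model // num_heads
        self.groups = num_heads // num_kv_heads  # number of Q heads that share each KV head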

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.dk)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.dk)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, S, _ = x.shape
        q = self.W_q(x).view(B, S, self.num_heads, self.dk).transpose(1, 2)
        k = self.W_k(x).view(B, S, self.num_kv_heads, self.dk).transpose(1, 2)
        v = self.W_v(x).view(B, S, self.num_kv_heads, self.dk).transpose(1, 2)

        # Repeat each KV head so the KV head count matches the Q head count
        k = k.repeat_interleave(self.groups, dim=1)  # (B, num_heads, S, dk)
        v = v.repeat_interleave(self.groups, dim=1)

        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        weights = torch.softmax(scores, dim=-1)
        attn = (weights @ v).transpose(1, 2).contiguous().view(B, S, -1)
        return self.W_o(attn)

# Test: 8 Q heads, 2 KV heads
gqa = GroupQueryAttention(d_model=32, num_heads=8, num_kv_heads=2)
x = torch.randn(2, 6, 32)
print(gqa.forward(x).shape)  # (2, 6, 32)
print("W_k weight shape:", gqa.W_k.weight.shape)  # (8, 32) instead of (32, 32)
  • LLaMA 2 70B (num_kv_heads=8, num_heads=64); the 7B and 13B models use standard MHA
  • Mistral 7B (num_kv_heads=8, num_heads=32)
  • Any setting where KV cache memory needs to be reduced
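In GroupQueryAttention above, repeat_interleave(self.groups, dim=1) decides which Q heads share a KV head: Q head h reads KV head h // groups. The tiny sketch below uses integer labels in place of real KV tensors to show how this differs from repeat, which tiles the whole head axis instead:

```python
import torch

num_kv_heads, groups = 2, 4  # 8 query heads, as in the test above
kv = torch.arange(num_kv_heads).view(1, num_kv_heads, 1, 1)  # label each KV head by its index

# repeat_interleave: consecutive Q heads share a KV head
print(kv.repeat_interleave(groups, dim=1).flatten().tolist())  # [0, 0, 0, 0, 1, 1, 1, 1]
# repeat: tiles the head axis, giving a different head-to-group mapping
print(kv.repeat(1, groups, 1, 1).flatten().tolist())           # [0, 1, 0, 1, 0, 1, 0, 1]
```

Either mapping can be trained, but inference must use the same mapping the weights were trained with, so the two are not interchangeable when loading pretrained weights.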

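Sliding window attention: each position attends only to positions inside a fixed-size window around it rather than to all positions, which reduces attention complexity from O(S²) to O(S·W). (The masked implementation below still computes the full S×S score matrix; realizing the O(S·W) cost requires a kernel that only computes scores inside the window.)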

Position $i$ may attend to position $j$ if and only if $|i - j| \le w$, where $w$ is the window size.

  • window_size = 0: each position sees only itself, so the output equals V
  • window_size >= seq_len - 1: equivalent to full attention
import torch
import math

def sliding_window_attention(Q, K, V, window_size):
    d_k = Q.shape[-1]
    S = Q.shape[1]
    scores = Q @ K.transpose(-1, -2) / math.sqrt(d_k)

    # Sliding-window mask: True where |i - j| > window_size (outside the window)
    positions = torch.arange(S, device=Q.device)
    mask = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs() > window_size
    scores = scores.masked_fill(mask.unsqueeze(0), float('-inf'))

    weights = torch.softmax(scores, dim=-1)
    return weights @ V

# Test
Q = K = V = torch.randn(1, 8, 16)
out = sliding_window_attention(Q, K, V, window_size=2)
print(out.shape)  # (1, 8, 16)

# With window_size=0 the output equals V
out0 = sliding_window_attention(Q, K, V, window_size=0)
print(torch.allclose(out0, V, atol=1e-5))  # True
  • Longformer (combined with global-attention tokens)
  • Mistral (combined with GQA; uses a causal sliding window)
  • Processing long documents (>4k tokens)
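The mask in sliding_window_attention is symmetric (bidirectional). A decoder needs the causal form, in which position i sees only itself and the w positions before it, i.e. 0 ≤ i − j ≤ w. A minimal sketch under that assumption; both helper names are hypothetical:

```python
import math
import torch

def causal_sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    """Boolean mask, True = blocked. Query i may attend to key j iff 0 <= i - j <= window."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions, shape (S, 1)
    j = torch.arange(seq_len).unsqueeze(0)  # key positions, shape (1, S)
    return (j > i) | (i - j > window)       # future keys, or keys too far in the past

def causal_sliding_window_attention(Q, K, V, window: int):
    scores = Q @ K.transpose(-1, -2) / math.sqrt(Q.shape[-1])
    mask = causal_sliding_window_mask(Q.shape[-2], window).to(Q.device)
    scores = scores.masked_fill(mask, float('-inf'))  # diagonal is never masked, so no row is all -inf
    return torch.softmax(scores, dim=-1) @ V

# Visible positions (1 = may attend) for 5 tokens and window = 2
for row in (~causal_sliding_window_mask(5, 2)).int().tolist():
    print(row)
# [1, 0, 0, 0, 0]
# [1, 1, 0, 0, 0]
# [1, 1, 1, 0, 0]
# [0, 1, 1, 1, 0]
# [0, 0, 1, 1, 1]

Q = K = V = torch.randn(1, 8, 16)
print(causal_sliding_window_attention(Q, K, V, window=2).shape)  # torch.Size([1, 8, 16])
```

With window >= seq_len - 1 the mask reduces to the plain causal mask.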

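Linear attention replaces softmax with a kernel feature map $\phi$, reducing attention complexity from O(S²·D) to O(S·D²).

Standard attention, $\text{softmax}(QK^T)V$, has to compute an S×S matrix.

Linear attention uses the associativity of matrix multiplication to change the order of computation:

$$\text{LinearAttn}(Q,K,V) = \frac{\phi(Q)\left(\phi(K)^T V\right)}{\phi(Q) \sum_{j} \phi(K_j)^T}$$

where the division is applied row by row. Computing $\phi(K)^T V$ (a D×D matrix) first and then multiplying by $\phi(Q)$ avoids the S×S matrix entirely.

The feature map $\phi(x) = \text{elu}(x) + 1$ keeps every entry positive, mimicking the non-negative weights of softmax.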

import torch
import torch.nn.functional as F

def linear_attention(Q, K, V):
    # Feature map: elu(x) + 1 is always positive
    Q_prime = F.elu(Q) + 1  # (B, S, D_k)
    K_prime = F.elu(K) + 1  # (B, S, D_k)

    # Key step: compute K^T V first (D_k × D_v), then multiply by Q
    KV = torch.bmm(K_prime.transpose(1, 2), V)  # (B, D_k, D_v)
    numerator = torch.bmm(Q_prime, KV)  # (B, S, D_v)

    # Normalization
    K_sum = K_prime.sum(dim=1, keepdim=True)  # (B, 1, D_k)
    denominator = torch.bmm(Q_prime, K_sum.transpose(1, 2))  # (B, S, 1)

    return numerator / (denominator + 1e-6)

# Test
Q = torch.randn(1, 1024, 64)
K = torch.randn(1, 1024, 64)
V = torch.randn(1, 1024, 64)
out = linear_attention(Q, K, V)
print(out.shape)  # (1, 1024, 64)
  • Pros: efficient for long sequences (O(S·D²) vs O(S²·D))
  • Cons: it does not reproduce softmax attention exactly, and quality drops on some tasks
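The reordering argument can be checked directly: building the S×S kernel matrix $\phi(Q)\phi(K)^T$ explicitly and normalizing each row should give the same output as the $\phi(K)^T V$-first order. A sketch; quadratic_kernel_attention is a hypothetical reference helper used only for checking:

```python
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V, eps=1e-6):
    # O(S·D²) order: phi(K)^T V first, as in the function above
    Qp, Kp = F.elu(Q) + 1, F.elu(K) + 1
    KV = Kp.transpose(1, 2) @ V                           # (B, D_k, D_v)
    Z = Qp @ Kp.sum(dim=1, keepdim=True).transpose(1, 2)  # (B, S, 1)
    return (Qp @ KV) / (Z + eps)

def quadratic_kernel_attention(Q, K, V, eps=1e-6):
    # O(S²·D) order: build the full S×S similarity matrix explicitly
    Qp, Kp = F.elu(Q) + 1, F.elu(K) + 1
    A = Qp @ Kp.transpose(1, 2)                           # (B, S, S), all entries > 0
    return (A @ V) / (A.sum(dim=-1, keepdim=True) + eps)

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 128, 16) for _ in range(3))
fast = linear_attention(Q, K, V)
slow = quadratic_kernel_attention(Q, K, V)
print(torch.allclose(fast, slow, atol=1e-4))  # True: same result, different multiplication order
```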

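KV Cache: during autoregressive generation (one token at a time), cache the K and V that have already been computed so they are not recomputed.

Without a cache, generating token $t$ means recomputing K and V for all $t$ preceding tokens. The KV cache stores the earlier K/V, so each step only computes K/V for the new token and appends them.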

Prefill (first pass): input [t0, t1, t2, t3]
  → compute full causal attention
  → cache K_{0:3}, V_{0:3}

Decode (step by step): input [t4]
  → compute Q, K, V for t4 only
  → K_all = concat(K_cache, K_new)
  → V_all = concat(V_cache, V_new)
  → Q_new attends over K_all / V_all
  → update the cache
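Counting how many token positions get their K/V projections computed shows where the savings come from. A sketch with a hypothetical helper; the first pair of calls mirrors the test below (a 4-token prefill followed by 2 single-token decode steps):

```python
def kv_positions_projected(prompt_len: int, decode_steps: int, use_cache: bool) -> int:
    """Count token positions whose K/V projections are computed (hypothetical helper)."""
    if use_cache:
        # Prefill projects the prompt once; each decode step projects only the new token
        return prompt_len + decode_steps
    # Without a cache, step k re-projects the whole prefix of length prompt_len + k
    return sum(prompt_len + k for k in range(decode_steps + 1))

print(kv_positions_projected(4, 2, use_cache=False))      # 15 = 4 + 5 + 6
print(kv_positions_projected(4, 2, use_cache=True))       # 6 = 4 + 1 + 1
print(kv_positions_projected(512, 512, use_cache=False))  # 393984
print(kv_positions_projected(512, 512, use_cache=True))   # 1024
```

The uncached count grows quadratically with the number of generated tokens, the cached count only linearly; the attention computation itself still reads the whole (cached) prefix at every step.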
import torch
import torch.nn as nn
import math

class KVCacheAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.dk = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, cache=None):
        B, S_new, _ = x.shape
        q = self.W_q(x).view(B, S_new, self.num_heads, self.dk).transpose(1, 2)
        k = self.W_k(x).view(B, S_new, self.num_heads, self.dk).transpose(1, 2)
        v = self.W_v(x).view(B, S_new, self.num_heads, self.dk).transpose(1, 2)

        if cache is not None:
            k = torch.cat([cache[0], k], dim=2)  # append to the cached K
            v = torch.cat([cache[1], v], dim=2)  # append to the cached V

        S_total = k.shape[2]
        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)

        # Prefill (S_new > 1) needs a causal mask; a single-token decode step may see every cached position
        if S_new > 1:
            mask = torch.triu(torch.full((S_new, S_total), float('-inf'),
                              device=x.device), diagonal=S_total - S_new + 1)
            scores = scores + mask

        weights = torch.softmax(scores, dim=-1)
        attn = (weights @ v).transpose(1, 2).contiguous().view(B, S_new, -1)
        return self.W_o(attn), (k, v)

# Test: incremental decoding matches a full forward pass
torch.manual_seed(0)
attn = KVCacheAttention(64, 4)
x = torch.randn(1, 6, 64)

full_out, _ = attn(x)
out1, cache = attn(x[:, :4])
out2, cache = attn(x[:, 4:5], cache=cache)
out3, cache = attn(x[:, 5:6], cache=cache)
inc_out = torch.cat([out1, out2, out3], dim=1)
print("Match:", torch.allclose(full_out, inc_out, atol=1e-5))

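RoPE (Rotary Position Embedding) encodes position information by rotating the Q and K vectors, so that attention scores naturally depend on relative position.

Treat each pair of adjacent dimensions of a vector as a point in a 2-D plane and rotate it by a position-dependent angle:

$$[x_0, x_1] \rightarrow [x_0 \cos\theta - x_1 \sin\theta,\; x_0 \sin\theta + x_1 \cos\theta]$$

The angle is $\theta = \text{pos} / 10000^{2i/D}$, where pos is the position and $i$ is the index of the dimension pair.

After rotation, $\text{dot}(q_{rot}[i], k_{rot}[j])$ depends on the positions only through the offset $i - j$ (the relative position), not on the absolute positions. Writing $R_i$ for the rotation at position $i$: rotation matrices are orthogonal ($R_i^T = R_i^{-1} = R_{-i}$) and compose by adding angles, so $q_{rot}[i]^T k_{rot}[j] = q^T R_i^T R_j k = q^T R_{j-i} k$.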

import torch
import math

def apply_rope(q, k):
    B, S, D = q.shape
    assert D % 2 == 0

    # Compute the rotation frequencies
    positions = torch.arange(S, device=q.device).float()
    dim_pairs = torch.arange(0, D, 2, device=q.device).float()
    freqs = 1.0 / (10000.0 ** (dim_pairs / D))
    angles = positions.unsqueeze(1) * freqs.unsqueeze(0)  # (S, D/2)

    cos_vals = torch.cos(angles).unsqueeze(0)  # (1, S, D/2)
    sin_vals = torch.sin(angles).unsqueeze(0)

    # Split into even and odd dimensions
    q_even, q_odd = q[..., 0::2], q[..., 1::2]
    k_even, k_odd = k[..., 0::2], k[..., 1::2]

    # Rotate each (even, odd) pair
    q_rot = torch.stack([
        q_even * cos_vals - q_odd * sin_vals,
        q_even * sin_vals + q_odd * cos_vals
    ], dim=-1).flatten(-2)

    k_rot = torch.stack([
        k_even * cos_vals - k_odd * sin_vals,
        k_even * sin_vals + k_odd * cos_vals
    ], dim=-1).flatten(-2)

    return q_rot, k_rot

# Test
q = torch.randn(1, 8, 16)
k = torch.randn(1, 8, 16)
qr, kr = apply_rope(q, k)
print("Shape preserved:", qr.shape == q.shape)  # True
print("Norm preserved:", torch.allclose(q.norm(dim=-1), qr.norm(dim=-1), atol=1e-4))  # True
  • LLaMA, GPT-NeoX, Mistral, and most other modern LLMs
  • Replaces the traditional sinusoidal and learned positional encodings
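A numerical check of the relative-position property, as a sketch: rope_at is a hypothetical single-vector helper using the same frequencies as apply_rope above, with float64 used only to keep rounding error small. Shifting both positions by the same amount should leave the score unchanged:

```python
import torch

def rope_at(x: torch.Tensor, pos: int) -> torch.Tensor:
    """Rotate a 1-D vector x of even length D as if it sat at position `pos`."""
    D = x.shape[-1]
    freqs = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float64) / D))  # same as apply_rope
    angles = pos * freqs                                                          # (D/2,)
    cos, sin = torch.cos(angles), torch.sin(angles)
    x_even, x_odd = x[0::2], x[1::2]
    return torch.stack([x_even * cos - x_odd * sin,
                        x_even * sin + x_odd * cos], dim=-1).flatten()

torch.manual_seed(0)
q = torch.randn(64, dtype=torch.float64)
k = torch.randn(64, dtype=torch.float64)

def score(m: int, n: int) -> float:
    """Unscaled attention logit between q at position m and k at position n."""
    return torch.dot(rope_at(q, m), rope_at(k, n)).item()

# The offset m - n is 2 in all three cases, so the scores agree up to rounding
print(score(3, 1), score(10, 8), score(100, 98))
```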

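Flash Attention uses tiling and an online softmax algorithm to compute exact attention without explicitly building the S×S attention matrix, which greatly reduces memory usage.

A standard numerically stable softmax needs two passes over the scores (first find the max, then compute exp and the sum). Online softmax finishes in a single pass by updating a running max and a running sum as each block arrives:

For each KV block:
  1. Compute the local scores: S_block = Q_block @ K_block^T / sqrt(d)
  2. Update the running max: new_max = max(old_max, block_max)
  3. Rescale the accumulator: acc *= exp(old_max - new_max)
  4. Accumulate the new block: acc += exp(S_block - new_max) @ V_block
  5. Update the running sum: row_sum = row_sum * exp(old_max - new_max) + sum(exp(S_block - new_max))
Finally: output = acc / row_sum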
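Reduced to a single query row with scalar scores, steps 2, 3 and 5 are the online log-sum-exp recurrence. A sketch in plain Python (both function names are hypothetical) comparing it with the two-pass version:

```python
import math

def online_logsumexp(xs):
    """One pass over xs: running max m and running sum s of exp(x - m)."""
    m, s = float("-inf"), 0.0
    for x in xs:
        new_m = max(m, x)
        s = s * math.exp(m - new_m) + math.exp(x - new_m)  # rescale old sum, add new term
        m = new_m
    return m + math.log(s)

def two_pass_logsumexp(xs):
    m = max(xs)                                            # pass 1: max
    return m + math.log(sum(math.exp(x - m) for x in xs))  # pass 2: sum

xs = [1000.0, 1001.0, 999.0, 3.0]  # math.exp(1000.0) alone would overflow
lse = online_logsumexp(xs)
print(abs(lse - two_pass_logsumexp(xs)) < 1e-9)  # True
softmax = [math.exp(x - lse) for x in xs]        # normalize with the final statistics
print(abs(sum(softmax) - 1.0) < 1e-12)           # True
```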
import torch
import math

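def flash_attention(Q, K, V, block_size=32):
    # Simplified sketch: tiles only over K/V (the real kernel also tiles Q and runs fused on the GPU)
    B, S_q, D = Q.shape
    S_k = K.shape[1]
    scale = 1.0 / math.sqrt(D)
    # The accumulator takes V's feature size, which may differ from Q's
    output = torch.zeros(B, S_q, V.shape[-1], device=Q.device, dtype=Q.dtype)
    row_max = torch.full((B, S_q, 1), float('-inf'), device=Q.device, dtype=Q.dtype)  # running max
    row_sum = torch.zeros((B, S_q, 1), device=Q.device, dtype=Q.dtype)  # running softmax denominator

    for j_start in range(0, S_k, block_size):
        j_end = min(j_start + block_size, S_k)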
        K_block = K[:, j_start:j_end]
        V_block = V[:, j_start:j_end]

        scores = torch.bmm(Q, K_block.transpose(1, 2)) * scale  # (B, S, block)
        block_max = scores.max(dim=-1, keepdim=True).values

        new_max = torch.maximum(row_max, block_max)
        # Rescale the existing accumulators to the new max
        correction = torch.exp(row_max - new_max)
        output = output * correction
        row_sum = row_sum * correction

        # Accumulate the new block
        exp_scores = torch.exp(scores - new_max)
        output = output + torch.bmm(exp_scores, V_block)
        row_sum = row_sum + exp_scores.sum(dim=-1, keepdim=True)
        row_max = new_max

    return output / row_sum

# Verify against standard attention
Q = K = V = torch.randn(1, 64, 32)
out_flash = flash_attention(Q, K, V, block_size=16)
scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(32)
out_ref = torch.bmm(torch.softmax(scores, dim=-1), V)
print("Match:", torch.allclose(out_flash, out_ref, atol=1e-4))
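  • Extra memory drops from O(S²) to O(S), because the full attention matrix is never stored
  • The real GPU kernel keeps tiles in fast on-chip SRAM, reducing reads and writes to HBM (the Python code above only illustrates the math)
  • The result is mathematically identical to standard attention (not an approximation); numerically it matches up to floating-point rounding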

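                Standard Attention
          ┌──────────────┼──────────────┐
          ▼              ▼              ▼
      Multi-Head     Causal        Cross-Attn
          │              │
     ┌────┴────┐    ┌────┴────┐
     ▼         ▼    ▼         ▼
   GQA    Sliding  KV Cache  Flash
          Window

GQA:            saves KV cache memory
Sliding Window: saves compute
KV Cache:       saves compute
Flash:          saves memory

Positional encoding: RoPE (rotary encoding, standard in modern LLMs)
Linearization:       Linear Attention (O(n), but approximate)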
