第三章:注意力机制(TorchCode)
第三章:注意力机制
注意力机制是 Transformer 的核心。本章从基础的 Scaled Dot-Product Attention 出发,逐步展开到所有主流变体。
3.1 Scaled Dot-Product Attention(缩放点积注意力)
是什么
注意力机制的核心计算单元。给定查询(Q)、键(K)、值(V),计算每个查询对所有键的相关性权重,然后对值进行加权求和。
数学定义
步骤分解
- 计算相似度矩阵:,形状
(B, seq_q, seq_k) - 缩放:,防止点积值过大导致 softmax 饱和
- 归一化:对每行做 softmax,得到注意力权重
- 加权求和:
为什么要缩放
当 较大时, 的方差约为 。除以 使方差回到 1,避免 softmax 输入过大导致梯度消失。
代码示例
import torch
import math
def scaled_dot_product_attention(Q, K, V):
# Q: (B, seq_q, d_k), K: (B, seq_k, d_k), V: (B, seq_k, d_v)
d_k = Q.shape[-1]
scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k)
weights = torch.softmax(scores, dim=-1)
return weights @ V # (B, seq_q, d_v)
# 自注意力
Q = K = V = torch.randn(2, 6, 64)
out = scaled_dot_product_attention(Q, K, V)
print(out.shape) # (2, 6, 64)
# 交叉注意力(seq_q != seq_k)
Q2 = torch.randn(1, 3, 16)
K2 = torch.randn(1, 10, 16)
V2 = torch.randn(1, 10, 32)
out2 = scaled_dot_product_attention(Q2, K2, V2)
print(out2.shape) # (1, 3, 32)适用场景
- 所有 Transformer 变体的基础
- 自注意力(Q=K=V 来自同一输入)和交叉注意力(Q 和 K/V 来自不同输入)
3.2 Multi-Head Attention(多头注意力)
是什么
将输入投影到多个子空间(“头”),在每个子空间独立计算注意力,最后拼接并投影回原始维度。
数学定义
为什么需要多头
- 单头注意力只能学习一种注意力模式
- 多头允许模型同时关注不同位置的不同特征(如语法关系、语义关系、位置关系)
- 每个头的维度 ,总计算量与单头相同
结构说明
W_q, W_k, W_v:投影矩阵,形状(d_model, d_model)W_o:输出投影,形状(d_model, d_model)d_k = d_model // num_heads:每个头的维度
代码示例
import torch
import torch.nn as nn
import math
class MultiHeadAttention:
def __init__(self, d_model: int, num_heads: int):
self.num_heads = num_heads
self.dk = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, Q, K, V):
B, S_q, _ = Q.shape
S_k = K.shape[1]
# 1. 线性投影
q = self.W_q(Q) # (B, S_q, d_model)
k = self.W_k(K) # (B, S_k, d_model)
v = self.W_v(V) # (B, S_k, d_model)
# 2. 拆分为多头: (B, S, d_model) → (B, num_heads, S, dk)
q = q.view(B, S_q, self.num_heads, self.dk).transpose(1, 2)
k = k.view(B, S_k, self.num_heads, self.dk).transpose(1, 2)
v = v.view(B, S_k, self.num_heads, self.dk).transpose(1, 2)
# 3. 每个头独立计算注意力
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.dk)
weights = torch.softmax(scores, dim=-1)
attn = weights @ v # (B, num_heads, S_q, dk)
# 4. 拼接所有头: (B, S_q, d_model)
out = attn.transpose(1, 2).contiguous().view(B, S_q, -1)
# 5. 输出投影
return self.W_o(out)
# 测试
mha = MultiHeadAttention(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)
print(mha.forward(x, x, x).shape) # (2, 10, 64)3.3 Causal Self-Attention(因果自注意力)
是什么
在标准注意力基础上添加因果掩码,确保每个位置只能看到自己和之前的位置,不能"偷看"未来的 token。这是 GPT 等自回归语言模型的核心。
数学定义
将未来位置的分数设为 ,经过 softmax 后权重变为 0。
掩码实现
使用上三角矩阵(torch.triu)填充 :
import torch
import math
def causal_attention(Q, K, V):
d_k = Q.shape[-1]
scores = Q @ K.transpose(-1, -2) / math.sqrt(d_k)
# 创建上三角掩码(对角线以上为 -inf)
mask = torch.triu(torch.full_like(scores, float('-inf')), diagonal=1)
scores = scores + mask
weights = torch.softmax(scores, dim=-1)
return weights @ V
# 测试
Q = K = V = torch.randn(1, 4, 8)
out = causal_attention(Q, K, V)
# 第 0 个位置只能看到自己,所以输出 = V[0]
print(torch.allclose(out[:, 0], V[:, 0], atol=1e-5)) # True掩码矩阵可视化
位置: 0 1 2 3
0 [ 0 -∞ -∞ -∞ ] 位置 0 只看自己
1 [ 0 0 -∞ -∞ ] 位置 1 看 0,1
2 [ 0 0 0 -∞ ] 位置 2 看 0,1,2
3 [ 0 0 0 0 ] 位置 3 看所有3.4 Multi-Head Cross-Attention(多头交叉注意力)
是什么
Q 来自一个序列(如解码器),K 和 V 来自另一个序列(如编码器)。用于 encoder-decoder 架构中,让解码器"查询"编码器的信息。
与自注意力的区别
- 自注意力:Q = K = V 来自同一输入
- 交叉注意力:Q 来自解码器,K/V 来自编码器
- 交叉注意力不需要因果掩码(编码器的所有位置对解码器可见)
代码示例
import torch
import torch.nn as nn
import math
class MultiHeadCrossAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.dk = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x_q, x_kv):
B, S_q, _ = x_q.shape
S_kv = x_kv.shape[1]
q = self.W_q(x_q).view(B, S_q, self.num_heads, self.dk).transpose(1, 2)
k = self.W_k(x_kv).view(B, S_kv, self.num_heads, self.dk).transpose(1, 2)
v = self.W_v(x_kv).view(B, S_kv, self.num_heads, self.dk).transpose(1, 2)
scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
weights = torch.softmax(scores, dim=-1)
attn = (weights @ v).transpose(1, 2).contiguous().view(B, S_q, -1)
return self.W_o(attn)
# 测试:解码器 6 个 token 查询编码器 10 个 token
attn = MultiHeadCrossAttention(64, 4)
x_q = torch.randn(2, 6, 64) # 解码器
x_kv = torch.randn(2, 10, 64) # 编码器
print(attn(x_q, x_kv).shape) # (2, 6, 64)适用场景
- Transformer encoder-decoder(机器翻译、语音识别)
- 多模态模型(图像特征作为 KV,文本作为 Q)
3.5 Grouped Query Attention(GQA,分组查询注意力)
是什么
MHA 的高效变体。Q 保持完整的头数,但 K 和 V 使用更少的头(多个 Q 头共享同一组 K/V 头)。
为什么需要
- 标准 MHA 中 KV cache 的大小与头数成正比,是推理时的内存瓶颈
- GQA 减少 KV 头数,直接减少 KV cache 大小
- 当
num_kv_heads = 1时退化为 Multi-Query Attention(MQA) - 当
num_kv_heads = num_heads时等价于标准 MHA
实现关键
使用 repeat_interleave 将少量 KV 头扩展到与 Q 头数量匹配:
import torch
import torch.nn as nn
import math
class GroupQueryAttention:
def __init__(self, d_model, num_heads, num_kv_heads):
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.dk = d_model // num_heads
self.groups = num_heads // num_kv_heads # 每组 Q 头共享一个 KV 头
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, num_kv_heads * self.dk)
self.W_v = nn.Linear(d_model, num_kv_heads * self.dk)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x):
B, S, _ = x.shape
q = self.W_q(x).view(B, S, self.num_heads, self.dk).transpose(1, 2)
k = self.W_k(x).view(B, S, self.num_kv_heads, self.dk).transpose(1, 2)
v = self.W_v(x).view(B, S, self.num_kv_heads, self.dk).transpose(1, 2)
# 扩展 KV 头以匹配 Q 头数量
k = k.repeat_interleave(self.groups, dim=1) # (B, num_heads, S, dk)
v = v.repeat_interleave(self.groups, dim=1)
scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
weights = torch.softmax(scores, dim=-1)
attn = (weights @ v).transpose(1, 2).contiguous().view(B, S, -1)
return self.W_o(attn)
# 测试:8 个 Q 头,2 个 KV 头
gqa = GroupQueryAttention(d_model=32, num_heads=8, num_kv_heads=2)
x = torch.randn(2, 6, 32)
print(gqa.forward(x).shape) # (2, 6, 32)
print("W_k 参数量:", gqa.W_k.weight.shape) # (8, 32) 而非 (32, 32)适用场景
- LLaMA 2(num_kv_heads=8, num_heads=32)
- Mistral 7B(num_kv_heads=8, num_heads=32)
- 所有需要减少 KV cache 内存的场景
3.6 Sliding Window Attention(滑动窗口注意力)
是什么
每个位置只关注其周围固定大小窗口内的位置,而非全部位置。将注意力复杂度从 O(S²) 降低到 O(S·W)。
掩码规则
位置 只能关注位置 ,当且仅当 (窗口大小)。
边界情况
window_size = 0:每个位置只看自己,输出等于 Vwindow_size >= seq_len:等价于全注意力
代码示例
import torch
import math
def sliding_window_attention(Q, K, V, window_size):
d_k = Q.shape[-1]
S = Q.shape[1]
scores = Q @ K.transpose(-1, -2) / math.sqrt(d_k)
# 创建滑动窗口掩码
positions = torch.arange(S)
mask = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs() > window_size
scores = scores.masked_fill(mask.unsqueeze(0), float('-inf'))
weights = torch.softmax(scores, dim=-1)
return weights @ V
# 测试
Q = K = V = torch.randn(1, 8, 16)
out = sliding_window_attention(Q, K, V, window_size=2)
print(out.shape) # (1, 8, 16)
# window=0 时输出等于 V
out0 = sliding_window_attention(Q, K, V, window_size=0)
print(torch.allclose(out0, V, atol=1e-5)) # True适用场景
- Longformer(结合全局注意力 token)
- Mistral(结合 GQA)
- 处理长文档(>4k tokens)
3.7 Linear Attention(线性注意力)
是什么
用核函数特征映射替代 softmax,将注意力复杂度从 O(S²·D) 降低到 O(S·D²)。
核心思想
标准注意力: 需要计算 S×S 矩阵。
线性注意力利用结合律改变计算顺序:
先计算 (D×D 矩阵),再乘以 ,避免了 S×S 矩阵。
特征映射
使用 确保非负性(模拟 softmax 的非负权重)。
代码示例
import torch
import torch.nn.functional as F
def linear_attention(Q, K, V):
# 特征映射:elu(x) + 1 保证非负
Q_prime = F.elu(Q) + 1 # (B, S, D_k)
K_prime = F.elu(K) + 1 # (B, S, D_k)
# 关键:先算 K^T V(D_k × D_v),再乘 Q
KV = torch.bmm(K_prime.transpose(1, 2), V) # (B, D_k, D_v)
numerator = torch.bmm(Q_prime, KV) # (B, S, D_v)
# 归一化
K_sum = K_prime.sum(dim=1, keepdim=True) # (B, 1, D_k)
denominator = torch.bmm(Q_prime, K_sum.transpose(1, 2)) # (B, S, 1)
return numerator / (denominator + 1e-6)
# 测试
Q = torch.randn(1, 1024, 64)
K = torch.randn(1, 1024, 64)
V = torch.randn(1, 1024, 64)
out = linear_attention(Q, K, V)
print(out.shape) # (1, 1024, 64)优缺点
- 优点:长序列时计算效率高(O(S·D²) vs O(S²·D))
- 缺点:近似质量不如标准 softmax 注意力,在某些任务上性能下降
3.8 KV Cache Attention
是什么
在自回归生成(逐 token 生成)时,缓存已计算的 K 和 V,避免重复计算。
为什么需要
生成第 个 token 时,标准注意力需要对前 个 token 全部重新计算 K 和 V。KV Cache 将之前的 K/V 存起来,每步只需计算新 token 的 K/V 并追加。
工作流程
Prefill(首次): 输入 [t0, t1, t2, t3]
→ 计算完整因果注意力
→ 缓存 K_{0:3}, V_{0:3}
Decode(逐步): 输入 [t4]
→ 只计算 t4 的 Q, K, V
→ K_all = concat(K_cache, K_new)
→ V_all = concat(V_cache, V_new)
→ Q_new 对 K_all 做注意力
→ 更新缓存代码示例
import torch
import torch.nn as nn
import math
class KVCacheAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.num_heads = num_heads
self.dk = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, cache=None):
B, S_new, _ = x.shape
q = self.W_q(x).view(B, S_new, self.num_heads, self.dk).transpose(1, 2)
k = self.W_k(x).view(B, S_new, self.num_heads, self.dk).transpose(1, 2)
v = self.W_v(x).view(B, S_new, self.num_heads, self.dk).transpose(1, 2)
if cache is not None:
k = torch.cat([cache[0], k], dim=2) # 拼接历史 K
v = torch.cat([cache[1], v], dim=2) # 拼接历史 V
S_total = k.shape[2]
scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
# Prefill 时需要因果掩码;单 token decode 时不需要
if S_new > 1:
mask = torch.triu(torch.full((S_new, S_total), float('-inf'),
device=x.device), diagonal=S_total - S_new + 1)
scores = scores + mask
weights = torch.softmax(scores, dim=-1)
attn = (weights @ v).transpose(1, 2).contiguous().view(B, S_new, -1)
return self.W_o(attn), (k, v)
# 测试:增量解码与全量前向结果一致
torch.manual_seed(0)
attn = KVCacheAttention(64, 4)
x = torch.randn(1, 6, 64)
full_out, _ = attn(x)
out1, cache = attn(x[:, :4])
out2, cache = attn(x[:, 4:5], cache=cache)
out3, cache = attn(x[:, 5:6], cache=cache)
inc_out = torch.cat([out1, out2, out3], dim=1)
print("匹配:", torch.allclose(full_out, inc_out, atol=1e-5))3.9 Rotary Position Embedding(RoPE,旋转位置编码)
是什么
RoPE 通过旋转 Q 和 K 向量来编码位置信息,使得注意力分数自然地依赖于相对位置。
核心思想
将每个向量的相邻维度对视为二维平面上的点,按位置相关的角度旋转:
角度 ,其中 pos 是位置,i 是维度对的索引。
关键性质
旋转后 只依赖于 (相对位置),而非绝对位置。这是因为旋转矩阵的正交性:。
代码示例
import torch
import math
def apply_rope(q, k):
B, S, D = q.shape
assert D % 2 == 0
# 计算频率
positions = torch.arange(S, device=q.device).float()
dim_pairs = torch.arange(0, D, 2, device=q.device).float()
freqs = 1.0 / (10000.0 ** (dim_pairs / D))
angles = positions.unsqueeze(1) * freqs.unsqueeze(0) # (S, D/2)
cos_vals = torch.cos(angles).unsqueeze(0) # (1, S, D/2)
sin_vals = torch.sin(angles).unsqueeze(0)
# 拆分为偶数和奇数维度
q_even, q_odd = q[..., 0::2], q[..., 1::2]
k_even, k_odd = k[..., 0::2], k[..., 1::2]
# 旋转
q_rot = torch.stack([
q_even * cos_vals - q_odd * sin_vals,
q_even * sin_vals + q_odd * cos_vals
], dim=-1).flatten(-2)
k_rot = torch.stack([
k_even * cos_vals - k_odd * sin_vals,
k_even * sin_vals + k_odd * cos_vals
], dim=-1).flatten(-2)
return q_rot, k_rot
# 测试
q = torch.randn(1, 8, 16)
k = torch.randn(1, 8, 16)
qr, kr = apply_rope(q, k)
print("形状保持:", qr.shape == q.shape) # True
print("范数保持:", torch.allclose(q.norm(dim=-1), qr.norm(dim=-1), atol=1e-4)) # True适用场景
- LLaMA、GPT-NeoX、Mistral 等几乎所有现代 LLM
- 替代了传统的正弦位置编码和可学习位置编码
3.10 Flash Attention(分块注意力)
是什么
Flash Attention 通过分块计算和在线 softmax 算法,在不显式构建 S×S 注意力矩阵的情况下计算精确的注意力结果,大幅减少内存使用。
核心算法:Online Softmax
标准 softmax 需要两遍扫描(先求 max,再求 exp/sum)。Online softmax 在一遍扫描中完成,通过动态更新 running max 和 running sum:
对每个 KV 块:
1. 计算局部分数 S_block = Q_block @ K_block^T / sqrt(d)
2. 更新 running max: new_max = max(old_max, block_max)
3. 修正累加器: acc *= exp(old_max - new_max)
4. 累加新块: acc += exp(S_block - new_max) @ V_block
5. 更新 running sum
最终: output = acc / row_sum代码示例
import torch
import math
def flash_attention(Q, K, V, block_size=32):
B, S, D = Q.shape
scale = 1.0 / math.sqrt(D)
output = torch.zeros_like(Q)
row_max = torch.full((B, S, 1), float('-inf'), device=Q.device)
row_sum = torch.zeros((B, S, 1), device=Q.device)
for j_start in range(0, S, block_size):
j_end = min(j_start + block_size, S)
K_block = K[:, j_start:j_end]
V_block = V[:, j_start:j_end]
scores = torch.bmm(Q, K_block.transpose(1, 2)) * scale # (B, S, block)
block_max = scores.max(dim=-1, keepdim=True).values
new_max = torch.maximum(row_max, block_max)
# 修正已有累加器
correction = torch.exp(row_max - new_max)
output = output * correction
row_sum = row_sum * correction
# 累加新块
exp_scores = torch.exp(scores - new_max)
output = output + torch.bmm(exp_scores, V_block)
row_sum = row_sum + exp_scores.sum(dim=-1, keepdim=True)
row_max = new_max
return output / row_sum
# 验证与标准注意力结果一致
Q = K = V = torch.randn(1, 64, 32)
out_flash = flash_attention(Q, K, V, block_size=16)
scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(32)
out_ref = torch.bmm(torch.softmax(scores, dim=-1), V)
print("匹配:", torch.allclose(out_flash, out_ref, atol=1e-4))优势
- 内存从 O(S²) 降低到 O(S)(不存储完整注意力矩阵)
- 利用 GPU SRAM 的高带宽,减少 HBM 访问
- 结果与标准注意力完全一致(不是近似)
3.11 注意力变体总结
标准 Attention
│
┌──────────────┼──────────────┐
▼ ▼ ▼
Multi-Head Causal Cross-Attn
│ │
┌────┴────┐ ┌────┴────┐
▼ ▼ ▼ ▼
GQA Sliding KV Cache Flash
(省KV) Window (省计算) (省内存)
(省计算)
位置编码: RoPE(旋转编码,现代 LLM 标配)
线性化: Linear Attention(O(n) 但近似)