Attention 面试 Cheat Sheet
公式推导 + From-Scratch 代码 + 25 高频题(L1 必会 · L2 进阶 · L3 顶级 lab)
§0 TL;DR Cheat Sheet
一页拿下面试核心要点(详见后文 §2–§9 推导)。
- 公式:$\text{Attention}(Q,K,V) = \text{softmax}\!\left(\dfrac{QK^\top}{\sqrt{d_k}}\right) V$。
- 为什么除 √d_k:若 $q_i, k_i \sim \mathcal{N}(0,1)$ 独立,$q\cdot k$ 方差 $= d_k$;除 $\sqrt{d_k}$ 把方差拉回 1,避免 softmax 饱和。
- Multi-Head:把 $D$ 拆成 $H$ 个 head,每个 head 在不同 subspace 独立做 attention,concat 后 $W_o$ 投影。固定 $D$ 且 $d_k=D/H$ 时,标准 MHA 参数量 $\approx 4D^2$(不随 $H$ 变);MQA/GQA 下 K/V 投影变小。
- Self vs Cross:Self 的 Q/K/V 同源;Cross 的 Q 来自 query stream,K/V 来自 context stream(encoder output / image tokens / text embedding)。
- Causal mask vs Padding mask:前者用下三角阻断未来;后者用
[B,1,1,L_k]屏蔽 padding 列。 - 复杂度:$O(B H L^2 d_k)$ 时间,$O(B H L^2)$ score 显存——长序列瓶颈在二次项。
- 易踩坑:全 masked row → softmax NaN;FP16 下 $QK^\top$ 可能 overflow;attention weight ≠ 因果解释。
§1 Attention 直觉
Attention 的本质是 可学习的检索(learned retrieval):
- 每个 query("我现在需要什么信息?")
- 对所有 key("每个位置宣称自己能提供什么")计算相似度
- 用 softmax 归一化得到 权重分布
- 对所有 value("每个位置实际提供的内容")做加权求和
对比 RNN:RNN 把过去信息压缩进一个固定大小的 hidden state,长序列必丢信息;attention 在每一步都直接、全局、动态地检索过去所有位置,因此适合长程依赖。
"Q/K/V 是同一个向量经过三个不同投影"——这点要主动说明,因为面试常考新手会以为 Q/K/V 是三份不同输入。
§2 Scaled Dot-Product Attention
2.1 公式
$$\boxed{\;\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V\;}$$
形状:
- $Q \in \mathbb{R}^{L_q \times d_k}$, $K \in \mathbb{R}^{L_k \times d_k}$, $V \in \mathbb{R}^{L_k \times d_v}$
- Scores $QK^\top \in \mathbb{R}^{L_q \times L_k}$(每个 query 对所有 key 的相似度)
- Softmax over key 维度:每个 query 行的权重和为 1
- Output $\in \mathbb{R}^{L_q \times d_v}$
2.2 为什么除以 √d_k(必考题,要会推方差)
假设 $q, k \in \mathbb{R}^{d_k}$ 的每个分量独立同分布,$q_i, k_i \sim \mathcal{N}(0,1)$。考虑点积:
$$q \cdot k = \sum_{i=1}^{d_k} q_i k_i$$
由独立性,每项 $q_i k_i$ 均值 $= \mathbb{E}[q_i]\mathbb{E}[k_i] = 0$,方差 $= \mathbb{E}[q_i^2]\mathbb{E}[k_i^2] = 1$。所以:
$$\mathbb{E}[q\cdot k] = 0, \quad \text{Var}[q\cdot k] = d_k$$
当 $d_k$ 大(如 64、128),$q\cdot k$ 的典型量级是 $\sqrt{d_k}$,进入 softmax 后最大 logit 容易抢走绝大部分概率,softmax 进入饱和区,梯度量级显著缩小,训练收敛变慢甚至停滞。除以 $\sqrt{d_k}$ 把方差拉回 1,减轻饱和、改善梯度尺度。
即使除了 √d_k,FP16 下 QK^T 自己累加时也可能 overflow(fp16 max ≈ 65504)。生产实现用 fused SDPA / FlashAttention 或 fp32 accumulation 解决。torch.softmax 内部有 log-sum-exp 稳定化(减最大 logit 再 exp),但那是在 softmax 一步内做的,挡不住 matmul 累加的 overflow。
2.3 Mask 与 NaN 陷阱(💣 经典 bug,面试必问)
标准做法:把要屏蔽的位置 score 填成 $-\infty$,softmax 后那些位置概率 = 0。
但有个陷阱:如果某一行所有 key 都被 mask(如 query 0 在 cross-attn 中 context 全 padding;causal + 左 padding;某 query 后无任何合法 token),那一行 score 全是 $-\infty$,softmax 输出:
$$\text{softmax}([-\infty, -\infty, ..., -\infty]) = \text{NaN}$$
因为分子分母都是 $e^{-\infty} = 0$,$0/0 = $ NaN,污染整个 batch 的梯度。
# 检测全 mask 行
all_masked = (~mask).all(dim=-1, keepdim=True) # [..., L_q, 1]
# 临时给该行放开 (避免 NaN)
safe_mask = mask | all_masked
scores = scores.masked_fill(~safe_mask, float("-inf"))
# Softmax 正常计算
weights = F.softmax(scores, dim=-1)
# 把全 mask 行的输出强制设为 0 (否则会得到均匀分布)
weights = weights.masked_fill(all_masked, 0.0)
本实现 / F.scaled_dot_product_attention:True = keep
nn.MultiheadAttention 的 attn_mask / key_padding_mask:True = mask out(相反!)
面试写代码前先问面试官约定,或主动声明你的约定,否则容易被搞反。
2.4 代码(核心 20 行)
def scaled_dot_product_attention(Q, K, V, mask=None, dropout_p=0.0, training=True):
d_k = Q.size(-1)
scores = Q @ K.transpose(-2, -1) # [..., L_q, L_k]
scores = scores / math.sqrt(d_k) # ← 关键 scale
if mask is not None:
all_masked = (~mask).all(dim=-1, keepdim=True)
safe_mask = mask | all_masked
scores = scores.masked_fill(~safe_mask, float("-inf"))
else:
all_masked = None
weights = F.softmax(scores, dim=-1)
if all_masked is not None:
weights = weights.masked_fill(all_masked, 0.0) # NaN 防护
if dropout_p > 0.0 and training:
weights = F.dropout(weights, p=dropout_p)
return weights @ V, weights # output, weights
§3 Multi-Head Attention
3.1 公式
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_H) W_o$$
$$\text{head}_h = \text{Attention}(Q W_q^{(h)},\; K W_k^{(h)},\; V W_v^{(h)})$$
每个 head 的 $W_q^{(h)}, W_k^{(h)}, W_v^{(h)} \in \mathbb{R}^{D \times d_k}$,$d_k = D/H$。工程上把 H 个 head 的投影矩阵 concat 成一个 $D \times D$ 大矩阵,一次 matmul 跑完所有 head 的投影(GPU 友好):
Input X [B, L, D]
│
│ W_q, W_k, W_v ∈ R^{D×D} (每个 = concat of H 个 W^{(h)} ∈ R^{D×d_k})
↓
Q, K, V [B, L, D]
│
│ reshape [B, L, D] → [B, L, H, d_k] → transpose → [B, H, L, d_k]
↓
对每个 head 独立做 Scaled-Dot-Product Attention (batched matmul, 并行)
↓
heads [B, H, L_q, d_k]
│
│ transpose + reshape → [B, L_q, D] (concat heads)
↓
W_o ∈ R^{D×D} → Output [B, L_q, D]
3.2 为什么要 multi-head(不是单 head 也行吗?)
- 不同 subspace:每个 head 在自己的 $d_k$ 维子空间里学一种关系模式(语法、共指、远距依赖、局部 n-gram...)
- 表达力:单 head 只能学一种 attention 模式;H 个 head 在 inference 时并行给出 H 种不同的 weighted sum 结果
- 参数效率:$d_k = D/H$ 而不是 $D$,所以参数量不会随 H 线性增加
- 面试常问:head 越多越好吗?不。$d_k = D/H$ 太小(如 $d_k < 16$)会让每个 head 表达力受限;Mistral / LLaMA 用 head_dim ≈ 64-128 是 sweet spot
3.3 参数量与 FLOPs
| 组件 | 形状 | 参数量 |
|---|---|---|
| $W_q$ | $D \times D$ | $D^2$ |
| $W_k$ | $D \times D$ | $D^2$ |
| $W_v$ | $D \times D$ | $D^2$ |
| $W_o$ | $D \times D$ | $D^2$ |
| 总计 | $4D^2$(不随 $H$ 变) |
FLOPs(单次 self-attention forward,$L_q = L_k = L$):
- QKV projection: $3 \cdot 2 B L D^2 = 6 B L D^2$
- $QK^\top$: $2 B H L^2 d_k = 2 B L^2 D$
- Softmax weight × V: $2 B L^2 D$
- Output projection $W_o$: $2 B L D^2$
- 总计 $\approx 8 B L D^2 + 4 B L^2 D$——前者随 $L$ 线性,后者随 $L$ 平方(长序列瓶颈)
3.4 代码(核心 30 行)
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.0, bias=False):
super().__init__()
assert d_model % num_heads == 0
self.d_model, self.num_heads, self.d_k = d_model, num_heads, d_model // num_heads
# 合并 H 个 W^(h) 成一个 [D, D] 矩阵
self.W_q = nn.Linear(d_model, d_model, bias=bias)
self.W_k = nn.Linear(d_model, d_model, bias=bias)
self.W_v = nn.Linear(d_model, d_model, bias=bias)
self.W_o = nn.Linear(d_model, d_model, bias=bias)
self.dropout_p = dropout
def _split(self, x): # [B, L, D] → [B, H, L, d_k]
B, L, _ = x.shape
return x.view(B, L, self.num_heads, self.d_k).transpose(1, 2)
def _merge(self, x): # [B, H, L, d_k] → [B, L, D]
B, _, L, _ = x.shape
return x.transpose(1, 2).contiguous().view(B, L, self.d_model)
def forward(self, query, key, value, mask=None):
Q = self._split(self.W_q(query))
K = self._split(self.W_k(key))
V = self._split(self.W_v(value))
if mask is not None:
if mask.dim() == 2: mask = mask.unsqueeze(0).unsqueeze(0) # [1,1,L_q,L_k]
elif mask.dim() == 3: mask = mask.unsqueeze(1) # [B,1,L_q,L_k]
# dim=4: 已对齐
out, w = scaled_dot_product_attention(Q, K, V, mask=mask, dropout_p=self.dropout_p, training=self.training)
return self.W_o(self._merge(out)), w
§4 Self / Cross / Causal / Padding
4.1 Self vs Cross Attention(必考)
| Self-Attention | Cross-Attention | |
|---|---|---|
| Q 来源 | $X$ | $X_\text{decoder}$ / latent / learnable queries |
| K, V 来源 | $X$(同源) | $X_\text{encoder}$ / context / memory |
| $L_q$ vs $L_k$ | 相等 | 可以不同 |
| 典型 mask | causal (decoder) 或 padding (encoder) | K/V 端 padding mask(不用 causal) |
| 用途 | 内部位置相关性 | 从外部 memory 检索相关信息 |
| 例子 | BERT 各层;GPT 各层;ViT | Transformer Decoder 第二子层;DETR;Perceiver;Stable Diffusion (image Q × text K/V) |
4.2 Causal Mask(Decoder / GPT)
下三角矩阵(含对角线):第 $i$ 行可以看 $j \le i$ 的 key。
def causal_mask(L, device=None):
return torch.tril(torch.ones(L, L, dtype=torch.bool, device=device))
# L=4 →
# [[T F F F]
# [T T F F]
# [T T T F]
# [T T T T]]
4.3 Padding Mask(变长序列)
每个 sample 有效长度不同,padding token 不应被 attend:
def padding_mask(lengths, max_len=None):
if max_len is None: max_len = int(lengths.max())
idx = torch.arange(max_len, device=lengths.device).unsqueeze(0).expand(len(lengths), -1)
return idx < lengths.unsqueeze(1) # [B, L] True=valid, False=padding
# 用法:必须 unsqueeze 成 [B, 1, 1, L_k] 才能 broadcast 到 MHA 内部 [B, H, L_q, L_k]
pmask = padding_mask(lengths).unsqueeze(1).unsqueeze(1) # [B, 1, 1, L_k]
out, _ = mha(x, x, x, mask=pmask)
两个 mask 取 AND:combined = causal_mask & padding_mask_4d。注意 broadcast 维度对齐:causal 是 [L,L],padding 是 [B,1,1,L_k],AND 出来 [B,1,L,L]。
§5 复杂度分析
| Time | Memory | 瓶颈 | |
|---|---|---|---|
| RNN | $O(L \cdot D^2)$ | $O(D)$ | 顺序计算不可并行 |
| Self-Attention | $O(L^2 \cdot D)$ | $O(L^2 + L \cdot D)$ | $L^2$ score 矩阵(长序列) |
| Conv (kernel $k$) | $O(L \cdot k \cdot D^2)$ | $O(D)$ | 感受野有限 |
关键点:
- Self-attention 的 $L^2$ 项计算可以接受(GPU 并行),但 $L^2$ 显存(score 矩阵)是真正瓶颈——这是 Flash Attention 攻击的痛点
- LLM inference 时 prefill stage 是 $O(L^2)$;decode stage 用 KV cache 后每步是 $O(L)$(见 §6)
- 当 $L \approx D$ 时,attention 与 FFN 时间相当;当 $L \gg D$ 时,attention 占绝大部分时间
§6 KV Cache + MQA / GQA
6.1 KV Cache(autoregressive inference 关键优化)
问题:GPT 自回归生成时,每生成一个 token,把整个 prefix 重新过 forward——$t$ 步累计 $O(t^2)$ 重复计算。
解:把每层 $K^{(\ell)}, V^{(\ell)}$ 缓存。生成第 $t+1$ 个 token 时:
- 只对新 token 算 $q_{t+1}, k_{t+1}, v_{t+1}$(1 × D 大小)
- 将 $k_{t+1}, v_{t+1}$ append 到 cache
- 新 $q$ 对整个 cache 做 attention($O(t)$ 不是 $O(t^2)$)
KV cache 是 推理优化,训练时不能用——训练时所有位置同时做 attention,没有"逐个生成"的概念。
KV cache 显存(per sample):
$$\text{KV cache} = L_\text{ctx} \cdot n_\text{layers} \cdot \underbrace{2}_{K, V} \cdot H_\text{kv} \cdot d_\text{head} \cdot \text{bytes\_per\_elem}$$
注意:MQA/GQA 下 $H_\text{kv} \ll H$,cache 显著缩减。对 LLaMA-2-70B(GQA, $H_\text{kv}=8$)、$L_\text{ctx}=4096$、80 层、fp16:约 1.25 GB / sample——这就是 LLaMA-2 用 GQA 不用 MHA 的原因(vanilla MHA 会到 10 GB / sample)。
6.2 MQA / GQA(attack KV cache 显存)
| 变体 | Q heads | K/V heads | KV cache 缩减 | 用在哪 |
|---|---|---|---|---|
| MHA (Vanilla) | $H$ | $H$ | 1× | 原始 Transformer |
| MQA (Multi-Query) | $H$ | 1 | $H \times$ | PaLM, Falcon |
| GQA (Grouped-Query) | $H$ | $G$($1 < G < H$) | $H/G \times$ | LLaMA-2/3, Mistral |
核心:多个 Q head 共享一组 K/V。MQA 极端但质量略降;GQA 是折中(如 H=32, G=8),显存 / 带宽降 4 倍,质量基本不掉。
MQA/GQA 减少的是 KV cache 显存 + 显存带宽,不是 Q projection 计算(Q 头数不变)。面试常被反问"减了什么"。
§7 FlashAttention 核心 Trick
问题:标准 attention 需要物化 $L \times L$ 的 score 矩阵,HBM 读写 IO 是瓶颈(不是 FLOPs)。
FlashAttention 思路(IO-aware exact attention,不是近似):
- Block Tiling:把 $Q, K, V$ 切成 block,每次只把一个 $Q$ block 和一个 $K, V$ block 加载到 SRAM
- Online Softmax:边算边维护 running max $m$ 和 running sum $\ell$,避免一次性物化全部 scores
- Recompute on backward:反向时重算 attention 而不存 $L^2$ scores
效果:
- 避免 materialize 完整 $L \times L$ 的 scores / probs 矩阵到 HBM
- 论文给出的 HBM IO 复杂度约 $O(L^2 d^2 / M + Ld)$,对比标准 attention 的 $O(L^2 + Ld)$ HBM traffic——当 $L$ 大、$M$(SRAM)合适时 IO 减少显著
- 显存峰值从 $O(L^2)$ 降到 $O(L)$(不存中间 scores)
- 典型速度提升 2-4 倍(取决于 sequence length & GPU 架构)
- 数学上完全等价(exact attention,不是 sparse / linear approximation)
v1 (2022):online softmax + block tiling + recompute。v2 (2023):换内外循环 (Q-outer, KV-inner) + 更好 warp-level parallelism + 减少非 matmul FLOPs。v3 (2024):针对 H100 Hopper,使用 WGMMA / TMA / FP8 + asynchronous pipeline。面试一般问 v1/v2,会问到 online softmax 细节。
§8 Position Encoding (RoPE / ALiBi / Absolute)
| 方法 | 原理 | 外推性 | 用在哪 |
|---|---|---|---|
| Sinusoidal absolute | 固定 sin/cos 位置向量加到 input embedding | 位置编码本身可定义任意长度,但模型未必学到长度外的泛化 | 原始 Transformer (Vaswani 2017) |
| Learned absolute | 把位置当 token,学一个 embedding 表 | 差(位置 embedding 表是定长,硬限制) | BERT, GPT-2 |
| RoPE (Rotary) | 对 $Q, K$ 做位置相关的旋转:$q_m \to q_m e^{im\theta}$(复数视角)——位置相关项通过相对位移 $m-n$ 进入内积(内容向量仍影响分数) | 中等(自然包含相对位置;长度外推需 NTK-aware / YaRN) | LLaMA-1/2/3, Mistral, Qwen |
| ALiBi | 在 score 上加位置距离 bias:$\text{score}_{ij} - m \cdot \lvert i-j \rvert$ | 好(线性 bias 自然外推) | BLOOM, MPT |
8.1 Attention Sink(高级题)
训练好的 LLM 在 decode 时,注意力会异常集中在前 1-4 个 token(特别是 [BOS] / 第一个 token),即使内容无关。这种现象叫 attention sink。常见直觉解释:softmax 强制权重和为 1,当一个 query 实际不想 attend 任何 key 时,需要一个"垃圾位"来吸收概率质量;又因为 early tokens 对所有后续 token 都可见,训练中容易自然形成全局 sink。StreamingLLM (Xiao et al., ICLR 2024) 利用这个现象做长序列推理(保留 attention sink + 滑动窗口)。
§9 Attention in Diffusion(generative 方向必问)
对 diffusion 背景的候选人,几乎必问 attention 在生成模型里的用法。
9.1 Latent Diffusion (Stable Diffusion) 里的 Cross-Attention
Image latent (z_t) [B, C, H, W]
│
│ flatten to tokens [B, HW, D]
↓
Self-Attention (Q=K=V from image)
↓
Cross-Attention:
Q = image tokens [B, HW, D]
K, V = text embedding [B, L_text, D] ← 文本条件
↓
FFN → next layer
关键点:
- Text-to-image conditioning 通过 cross-attention 实现:image tokens 是 query,text embedding 是 key/value
- Classifier-Free Guidance (CFG):两次 forward(with text / without text)做差值。$\epsilon$-pred 时:$\epsilon_\text{CFG} = \epsilon_\text{uncond} + s (\epsilon_\text{cond} - \epsilon_\text{uncond})$;v-pred / x0-pred 时换成对应预测量,线性 guidance 形式类似
- SD / SDXL 的 U-Net 在多个 spatial resolution 的 transformer block 里都有 self-attn 与 cross-attn 交替
- DiT (Diffusion Transformer) 把 U-Net 换成 pure Transformer,conditioning 通过 AdaLN / cross-attn / token-concat 等方式
9.2 Attention 在 video diffusion
- Spatial attention:每帧内部(image patches 之间)
- Temporal attention:跨帧(同位置在不同时间步之间)
- Spatiotemporal / full attention:所有 frame × all positions——最昂贵,长视频不可行
- Long video 的 attention 是开放问题($L \sim 10^4$-$10^5$ token),常见路线:因式化(空间 + 时间交替)、sparse window、hierarchical pooling、chunked attention
§10 25 高频面试题
codex (gpt-5.5 xhigh) 作为顶级 lab 面试官视角列的,按难度分 3 档。每题点开看答案要点 + 易踩坑。
L1必会题(任何 ML 工程岗都会问)
Q1.Attention 公式是什么?
- $\text{softmax}(QK^\top / \sqrt{d_k}) V$
- Softmax over keys 维度
- 输出是 value 的加权和
把 softmax 维度写到 query 维。
Q2.为什么除以 √d_k?
- 若 $q_i, k_i$ 独立零均值单位方差
- Dot product 方差约 $d_k$
- 缩放后方差回到 1,避免 softmax 饱和
只说"防止数值太大",不给方差推导。
Q3.Q/K/V 分别代表什么?
- Q 是检索请求
- K 是匹配索引
- V 是被聚合内容
说 Q/K/V 是三份不同输入;self-attn 中它们同源但投影不同。
Q4.Multi-head 为什么有用?
- 不同子空间建模不同关系
- 多种位置/语义模式并行
- Concat 后再融合
说 head 越多一定越好。实际 $d_k$ 太小会限制表达力。
Q5.MHA 参数量随 head 数怎么变?
- 固定 $D$ 且 $d_k = D/H$(标准 MHA)
- $W_q + W_k + W_v + W_o$ 共 $4D^2$,不随 $H$ 变
- 但若用 MQA/GQA,K/V 投影矩阵会变小($H_\text{kv} < H$ 个 head)
- 这就是为什么"head 数是免费的"在标准 MHA 下成立,但在 MQA/GQA 下有显存收益
误以为 head 多参数也线性多 H 倍;或忘了 MQA/GQA 改变了 K/V 投影维度。
Q6.Self-attention 和 cross-attention 区别?
- Self: Q/K/V 同源
- Cross: Q 来自 target,K/V 来自 context
- Cross 常用于 encoder-decoder、diffusion text conditioning
只说"cross 有两个输入",不说明 Q 与 KV 来源。
Q7.写 causal mask 怎么做?
torch.tril(torch.ones(L, L, dtype=torch.bool))- 明确说 True=keep 还是 True=mask(API 间不一致)
- Broadcast 到
[B, H, L, L]或让框架隐式 broadcast
上下三角写反;忘记 broadcast 维度对齐。
Q8.Padding mask mask 的是哪一维?
- 通常 mask key/value 列(让 padding 位置概率为 0)
- Shape 可为
[B, 1, 1, L_k]对齐 head + query 维 - 注意:mask key 列不足以让 padded query 输出为 0;padded query 行通常用 loss ignore / output zeroing / packed sequence 等手段单独处理
以为 padding mask 一手包办——它只防止"看到 padding",但 padded query 自己的输出还需要外部处理。
Q9.Attention 复杂度?
- 时间 $O(B H L_q L_k d_k) = O(B L^2 D)$
- Score memory $O(B H L_q L_k)$
- 长序列瓶颈是二次项
只说 $O(n^2)$,漏 head 和 hidden 维。
Q10.Attention dropout 放在哪里?
- 放在 softmax weights 之后、与 V matmul 之前
- Training 才启用,eval 时关闭
- Dropout 后权重行和不一定是 1(期望意义上为 1)
Sanity check 时还要求 dropout 后 row-sum = 1(错的)。
L2进阶题(research-oriented 岗位)
Q11.手推 softmax 的 Jacobian。
- $y_i = \dfrac{e^{x_i}}{\sum_j e^{x_j}}$
- $\dfrac{\partial y_i}{\partial x_j} = y_i (\delta_{ij} - y_j)$
- 矩阵形式:$J = \text{diag}(y) - yy^\top$
只写对角项,漏交叉项 $-y_i y_j$。
Q12.用 -∞ 做 mask 有什么坑?
- 正常情况 masked 位置 softmax 概率为 0 ✓
- 全 masked row → softmax 输出 NaN($0/0$)
- 修复路径:先避免 all-
-inf行(临时放开),softmax 后把该行 weights 与 output 强制清 0,并确保该 query 不进入 loss / 残差累积 - Fused kernel / API 对 sentinel 数值有约束;fp16 下用一个 dtype-safe 大负数(如
finfo(dtype).min)更稳
以为 -inf 永远安全;或只在 softmax 后清 0 而不防 NaN。
Q13.Log-sum-exp trick 是什么?
- softmax 前先减 max(logits),等价不改变概率
- 防止 $e^{x_i}$ overflow(fp32 max ≈ 3.4e38,但 $e^{100}$ 已经溢出)
- $\log \sum_j e^{x_j} = m + \log \sum_j e^{x_j - m}$ 其中 $m = \max_j x_j$
忘了 $QK^\top$ overflow 可能发生在 softmax 之前(matmul 累加阶段)。
Q14.PyTorch nn.MultiheadAttention 的 in_proj_weight 顺序?
- Shape
[3D, D] - 顺序:Q, K, V(cat dim=0)
- Linear weight 是
[out, in],所以cat([W_q.weight, W_k.weight, W_v.weight], dim=0)
拼成 K/Q/V 或转置 weight。
Q15.attn_mask 和 key_padding_mask 区别?
attn_mask控制 query-key pair 级别(一般是 causal)key_padding_mask控制 key token 整体可见性(一般是 padding)- 两者 bool 语义:
nn.MultiheadAttention是 True = mask out;F.scaled_dot_product_attention的 bool mask 是 True = keep(相反!) - 同时用时,在 mask-out 语义下合并是 OR(任一为 True 就屏蔽);在 keep 语义下是 AND(两者都 True 才保留)
不查 API 文档直接套用 True/False;或把 AND/OR 搞反。
Q16.Cross-attention 中 L_q 和 L_k 能否不同?
- 可以——这正是 cross-attention 的常态
- Scores shape 是 $[L_q, L_k]$
- Mask 必须对齐 key 维度
默认 cross-attn 必须等长。
Q17.为什么需要 output projection W_o?
- 融合不同 head 的输出
- 映射回 $d_\text{model}$ 与残差相加
- 给模型学习 head 间组合(不是简单 concat)
以为 concat 后已经结束。
Q18.Pre-norm vs post-norm 对 attention block 的影响?
- Pre-norm:
x + Attn(LN(x)),深层训练更稳定,gradient 沿残差路径相对保持 - Post-norm:
LN(x + Attn(x)),Vaswani 原论文用,超深时需 warmup / careful init - 多数 decoder-only LLM 用 pre-norm(常配 RMSNorm 变体),但具体架构有例外
把 norm 位置当纯工程细节,或说"现代 LLM 都用 pre-norm" 太绝对。
Q19.Attention weight 等于"模型解释"吗?
- 可视化有参考价值(注意聚焦位置)
- 但 不等于因果解释
- Value 路径和后续层都会改变实际贡献
- Jain & Wallace "Attention is not Explanation" (2019)
直接把高 attention 权重当成"模型理由"。
Q20.Mixed-precision attention 注意什么?
- fp32 accumulation:matmul 累加 / softmax 关键步骤在 fp32 完成,再 cast 回低精度
- Softmax max-subtraction(log-sum-exp)防 exp overflow,PyTorch
F.softmax内部已做 - Mask sentinel:fp16 下用
torch.finfo(dtype).min而非字面 -inf - BF16 vs FP16:BF16 动态范围与 fp32 相近,更适合 attention;fp16 表数范围窄,QK^T 易 overflow
- Fused kernels(FlashAttention,
F.scaled_dot_product_attention)内置 kernel-level 稳定化,比手写 naive 安全
FP16 下直接手写 naive attention 不做 fp32 accumulation。
L3高级变体(顶级 lab / diffusion 方向)
Q21.KV cache 如何优化自回归解码?
- 解码第 $t+1$ 步时,只为新 token 算 $Q$(1×D)
- 复用历史 $K, V$(已经在 cache 里),append 新 $k_{t+1}, v_{t+1}$
- 每步 attention 从 $O(t^2)$ 变 $O(t)$,整段生成从 $O(L^3)$ 变 $O(L^2)$
- Per-sample 显存:$L_\text{ctx} \cdot n_\text{layers} \cdot 2 \cdot H_\text{kv} \cdot d_\text{head} \cdot \text{bytes}$(MQA/GQA 下 $H_\text{kv} \ll H$)
说 KV cache 减少训练成本——错。它只用于 autoregressive inference。另:cache 是 KV heads 数量,不是 Q heads。
Q22.MQA 和 GQA 解决什么?
- MQA:多个 Q head 共享一组 K/V(K/V 只有 1 个 head)
- GQA:折中,K/V 有 $G$ 组($1 < G < H$)
- 主要收益:decode 时 KV cache 显存 + 显存带宽(大幅降低)
- 同时也减少 K/V projection 的参数和计算(K/V 投影矩阵变小),但不减少 Q / O projection
- 质量影响:通常 GQA 质量损失小于 MQA,具体取决于模型规模和训练方式(LLaMA-2 70B / LLaMA-3 / Mistral / Qwen-2 都用 GQA)
以为减少了 Q projection;或说"GQA 质量基本不掉"过于绝对。
Q23.FlashAttention 核心 trick?
- Block tiling:把 $Q, K, V$ 切成 SRAM-sized block,分批 load
- Online softmax:增量维护 running max $m$ 与 running sum $\ell$,避免 materialize 完整 $L \times L$ scores / probs 矩阵到 HBM
- Recompute on backward:反向时根据 saved $m, \ell$ 重算 scores,不存中间结果
- 关键:IO-aware exact attention(数学等价,不是近似)
- HBM IO 复杂度约 $O(L^2 d^2 / M + Ld)$,对比标准 attention 的 $O(L^2 + Ld)$ HBM traffic——长序列下显著减少 IO(不是 FLOPs)
说它是近似 attention(如 Performer / Linformer)——错,FlashAttn 是 exact;或把 IO 复杂度和 FLOPs 复杂度混淆。
Q24.RoPE, ALiBi, absolute position 的区别?什么是 attention sink?
- Absolute:位置向量加到 input embedding 上(Vaswani sinusoidal / GPT-2 learned)
- RoPE:对 $Q, K$ 做位置相关的旋转,保留相对位置信息($q_m^\top k_n$ 只依赖 $m-n$)
- ALiBi:在 score 上加距离 bias $-m |i-j|$,自然外推
- Attention sink:训练好的 LLM 会让前 1-4 个 token(特别是 [BOS])获得异常高的 attention,即使内容无关——softmax 强制和为 1,模型需要"垃圾位"。StreamingLLM 利用此现象做长序列推理。
把 attention sink 当成 padding / CLS token 的正常 attention 行为。
Q25.Attention 在 diffusion / latent diffusion 里怎么用?
- U-Net latent tokens 作 Q,text embedding 作 K/V,做 cross-attention 注入文本条件
- Self-attention 在每个 spatial resolution 内做(image patches × image patches)
- CFG (Classifier-Free Guidance):两次 forward,差值放大 conditional 信号
- DiT (Diffusion Transformer):把 U-Net 换成 pure Transformer,conditioning 通过 AdaLN / cross-attn / token-concat
- Video diffusion:空间 attn + 时间 attn + 时空 attn 的组合(长 video 是开放问题,$L \sim 10^5$)
说 diffusion 只靠卷积;或者只在 DiT 里才有 attention(错,U-Net 里也有大量)。
§A 附录:完整 from-scratch 代码骨架
参考 from-scratch 实现包含:
scaled_dot_product_attention()—— 含 NaN 防护MultiHeadAttention—— 标准 MHA,支持 4 种 mask 形状SelfAttention/CrossAttention—— thin wrapper,调用语义清晰causal_mask()/padding_mask()/combine_masks()- 9 个 sanity check(self / causal / padding / cross / wrappers / nn.MHA 对齐 / NaN防护 / d_model%H / return_weights=False)
实跑 sanity check 输出(PyTorch 2.x,单机 GPU):
[a] self-attn out=(2, 5, 16) weights=(2, 4, 5, 5) weights row-sum=1 ✓
[b] causal mask: upper triangle ~ 0 ✓
[c] padding mask: pad-key columns ~ 0 in sample-1 ✓
[d] cross-attn out=(2, 7, 16) weights=(2, 4, 7, 5) ✓
[e] SelfAttention(causal) ✓ CrossAttention(context-pad) ✓
[f] vs nn.MultiheadAttention: |Δout|=0.00e+00 |Δweights|=0.00e+00 ✓
[g] all-masked row: no NaN, weights row = 0 ✓
[h] d_model not divisible by num_heads -> ValueError ✓
[i] return_weights=False -> weights is None ✓
代码经独立 reviewer 静态检查 + PyTorch 实跑 sanity check,与 nn.MultiheadAttention diff = 0。