Interview Prep · Inference Systems / Speculative Decoding

KV Cache + Speculative Decoding 面试 Cheat Sheet

公式推导 + From-Scratch 代码 + 25 高频题(L1 必会 · L2 进阶 · L3 顶级 lab)

Source: docs/tutorials/kv_cache_speculative_decoding_tutorial.md SHA256: 688b47c5bdc9 Rendered: 2026-05-19 05:42 UTC

§0 TL;DR Cheat Sheet

8 句话搞定 KV cache + Speculative Decoding

一页拿下面试核心要点(详见后文 §2–§9 推导)。

  1. KV cache 公式:单 sample 显存 $= 2 \cdot L_\text{ctx} \cdot N_\text{layers} \cdot N_\text{kv\_heads} \cdot d_\text{head} \cdot \text{bytes}$,"2" 来自 K+V。LLaMA-3-70B(GQA, $H_\text{kv}=8$)4K context fp16 ≈ 1.25 GB/sample——这就是为什么不用 MHA。
  2. Prefill vs Decode 不对称:prefill 处理整段 prompt($O(L^2)$ FLOPs,compute-bound);decode 一次生成 1 token(每 step $O(L)$ FLOPs,但要读全部 KV,memory-bandwidth-bound)。这条不对称解释了一切现代 inference 系统设计。
  3. PagedAttention(Kwon et al., SOSP 2023, vLLM):把 KV cache 切成 page,用 block table 解 fragmentation;显存利用率从 ~70% 提升到 ~96%。
  4. Continuous batching(Orca, Yu et al., OSDI 2022):iteration-level scheduling,请求完成不等待整 batch,配合 PagedAttention 是 vLLM 的两根支柱。
  5. MQA → GQA → MLA:MQA(Shazeer 2019)极端共享 K/V,质量略掉;GQA(Ainslie et al., EMNLP 2023)$G$ 组折中;MLA(DeepSeek-V2, May 2024)low-rank latent $c_t^{KV}$ + decoupled RoPE——RoPE 不能直接吸进 latent compression,必须留一个独立小维度 $d_\text{head}^R$ 携带位置。
  6. Speculative Decoding 核心(Leviathan et al., ICML 2023; Chen et al., 2023):用小 draft 模型 $q$ 提议 $K$ 个 token,target 模型 $p$ 一次 forward 并行验证;rejection sampling 保证输出分布与 $p$ 完全等价(exact,不是近似)。
  7. 接受概率公式:单 token 接受率 $\alpha = \mathbb{E}_{x \sim q}[\min(1, p(x)/q(x))]$;期望生成 token 数 $E[\tau] = \dfrac{1-\alpha^{K+1}}{1-\alpha}$($K$ 是 draft 长度,含最后 bonus token)。
  8. Medusa / EAGLE / Lookahead:Medusa(Cai et al., ICML 2024)多头 + 静态 tree attention;EAGLE/2/3(Li et al., 2024-2025)特征级 draft + 动态 tree;Lookahead Decoding(Fu et al., ICML 2024)Jacobi iteration——都是同一接受率框架下的不同 drafter

§1 直觉

1.1 为什么 inference 系统这么"反直觉"

训练时大家关心 FLOPs:模型多大、batch 多大、什么时候 OOM。但部署一个 70B 模型时,瓶颈往往不是算力,而是 HBM 带宽显存 ——这俩都被 KV cache 吃掉。

一条核心 mental model:

Modern LLM inference is bandwidth-bound during decode and memory-bound during long-context prefill, not compute-bound.

KV cache 是把"重复计算"换成"存储 + 带宽"的经典权衡。一旦把整段对话历史的 K/V 都缓存起来,每生成一个新 token 只需要:

但代价是:每生成一个 token,整条 KV cache 都要从 HBM 读到 SRAM——这就是为什么 8 张 H100 + LLaMA-3-70B 跑 batch=1 的 decode 远低于理论 FLOPs 利用率(往往 1-5%)。

Speculative decoding 攻击的正是这个不对称:既然 decode 是 bandwidth-bound 而 GPU 上还有富余算力,何不一次 forward 算 $K$ 个候选 token,反正 weight 只读一次?

1.2 与训练时 attention 的区别

阶段输入KV cache 行为瓶颈
Training$[B, L, D]$ 全序列不需要——所有位置同时算FLOPs (compute)
Prefill (inference)$[B, L_\text{prompt}, D]$ 整段 prompt入 cache,覆盖 $L_\text{prompt}$ 个位置FLOPs($L^2$ attention)
Decode (inference)$[B, 1, D]$ 单 token + append 1 个位置HBM 带宽(每 step 必须读全部 cache + weights)

面试常被反问"训练能不能用 KV cache"——不能:训练时所有位置一次性算,没有"已有 K/V 等着 append"这件事。把 KV cache 用在训练上是新手错误。

§2 KV Cache 显存核算

2.1 精确公式

单 sample(batch=1),fp16:

$$\boxed{\;\text{KV cache}_\text{bytes} = 2 \cdot L_\text{ctx} \cdot N_\text{layers} \cdot N_\text{kv\_heads} \cdot d_\text{head} \cdot \text{bytes\_per\_elem}\;}$$

各因子含义:

常踩坑:不要乘 $H$(Q heads 数量)

KV cache 只跟 K/V heads 走,不跟 Q heads 走。MQA 把 K/V heads 砍成 1 时,Q 仍是 $H$ 个 head,所以 Q projection 计算量不变。

2.2 几个具体模型的 cache 大小(fp16, $L_\text{ctx}=4096$)

代入 §2.1 公式 $2 \cdot L_\text{ctx} \cdot N_\text{layers} \cdot N_\text{kv\_heads} \cdot d_\text{head} \cdot 2\text{B}$ 直接算:

模型$N_\text{layers}$$N_\text{kv\_heads}$$d_\text{head}$cache/sample备注
LLaMA-2-7B (MHA)32321282.0 GB全 MHA,cache 大
LLaMA-2-70B (假设 MHA)806412810.0 GB这就是为什么 70B 用 GQA
LLaMA-2-70B (GQA)8081281.25 GBGQA 把 $H$=64 砍成 $G$=8
LLaMA-3-70B (GQA)8081281.25 GB同 LLaMA-2-70B
DeepSeek-V2 (MLA)60$d_c$=512 + $d_r$=64~0.27 GBMLA 公式:$N_\text{layers} \cdot L_\text{ctx} \cdot (d_c+d_r) \cdot 2\text{B}$

DeepSeek-V2 cache 用 $d_c=512$(latent dim, K/V 共享同一个 latent 向量)+ decoupled RoPE 分量 $d_r=64$(所有 head 共享),按 $N_\text{layers} \cdot L_\text{ctx} \cdot (d_c + d_r) \cdot \text{bytes}$(没有那个 "$\times 2$",因为 K/V 不再各存一份),fp16 下 $60 \cdot 4096 \cdot 576 \cdot 2 \approx$ 0.27 GB / sample——相比同规模 MHA 缩减一个数量级。

2.3 Batch 维度

实际服务里,KV cache 还要乘 active batch size。一个 70B + 4K context + GQA 部署在 8×A100 80GB 上:

§3 Prefill vs Decode 不对称

3.1 两个阶段的 FLOPs / 带宽差异

设 prompt 长度 $L$,模型 hidden $D$,FFN 中间 $4D$,layer 数 $N$。每层 attention + FFN 大约:

$$\text{FLOPs}_\text{layer} \approx \underbrace{6BLD^2}_\text{QKV proj (3 mat)} + \underbrace{4BL^2 D}_\text{attention (QK + AV)} + \underbrace{2BLD^2}_\text{O proj} + \underbrace{16 BLD^2}_\text{FFN (up + down)}$$

Prefill 阶段 $L$ 取 $L_\text{prompt}$:所有项都是 $\Omega(L D^2)$ 或 $\Omega(L^2 D)$,compute-bound,能跑满 GPU。

Decode 阶段 $L=1$(每 step 只算 1 个 token),但 attention 那项变 $4 B L_\text{ctx} D$(QK 和 AV 各 $2 B L_\text{ctx} D$,因 $L_q=1, L_k=L_\text{ctx}$):

所以 小到中等 batch 的 decode 是 memory-bandwidth-bound(batch 大到 GPU 算力被打满之前都是这样)——这一行背下来就够拿 80% 的面试分。

3.2 Chunked Prefill(Sarathi-Serve)

关键洞察

prompt 越长,prefill 一次 forward 把 GPU 算力全占住,正在 decode 的请求会被卡住(head-of-line blocking)。

Sarathi-Serve(Agrawal et al., OSDI 2024)把长 prefill 切成等大小 chunk,每个 iteration 调度一个 prefill chunk + 一批 decode token 一起跑:

       传统:prompt 4096 → 一次 prefill 占满 GPU → decode 卡 100+ ms
       Sarathi: prompt 切 4 个 chunk × 1024 → 每 iteration 跟 decode coalesce

Stall-free schedule:保证每个 iteration 有 decode + prefill chunk 混跑,decode latency 抖动消失。论文实测 Mistral-7B 单 A100 服务能力相比 vLLM 提升 2.6×,Yi-34B 在 2×A100 上 3.7×。

3.3 Continuous Batching(Orca)

传统 static batching:等整个 batch 都跑完才放新请求进来——短请求被长请求拖死。Orca(Yu et al., OSDI 2022)把调度粒度从 request 改成 iteration:每次 forward 都检查一遍 batch 里有没有完成的(EOS),完成的踢出去、空位让新请求进来。

vLLM = Orca 的 continuous batching + PagedAttention 内存管理

这两条加在一起把 LLM serving 吞吐拉到了之前的 24× 左右。

3.4 Prefix Caching

如果多个请求共享 prompt 前缀(如 system prompt、few-shot prompt),完全可以复用同一份 KV cache

§4 KV Cache 优化路线

4.1 路线总览

路线核心 idea减少了什么代表
共享 head(MQA/GQA)多个 Q head 共享一组 K/VKV head 数 $H \to G$ 或 1LLaMA-2/3, Mistral, PaLM
低秩压缩(MLA)投影到低维 latent,cache latent 而非 K/Vhead 维度有效减小 $d_\text{head} \to d_c/H$DeepSeek-V2/V3
量化(KIVI/KVQuant)把 cache 元素 fp16 → int4/int2bytes per elementKIVI, KVQuant, FP8 KV
稀疏化 / eviction只保留"重要"位置的 K/V$L_\text{ctx}$ 有效缩短H2O, StreamingLLM, TriForce
内存管理(PagedAttention)不减小总量,但消除碎片fragmentation overheadvLLM

4.2 MQA / GQA 公式回顾

MHA:每个 head 独立 $W_k^{(h)}, W_v^{(h)} \in \mathbb{R}^{D \times d_\text{head}}$;共 $H \cdot 2 \cdot D \cdot d_\text{head} = 2 D^2$ 参数(K+V)。

MQA(Shazeer 2019):$H$ 个 Q head 共享 1 组 K/V。$W_k, W_v \in \mathbb{R}^{D \times d_\text{head}}$,K+V 参数 $= 2 D d_\text{head}$,减小 $H$ 倍。前向时 K, V 各 broadcast 到 $H$ 个 head 上做 attention。

GQA(Ainslie 2023):$H$ 个 Q head 分成 $G$ 组,每组共享一组 K/V。K+V 参数 $= 2 G D d_\text{head}$。LLaMA-2-70B 用 $H=64, G=8 \Rightarrow$ KV cache 缩小 8×。

MQA 训练不稳定的现象

从头训练 MQA 模型相比 MHA 经常出现质量小幅下降甚至训练不稳定。GQA 论文给的实践:先用 MHA 训完,再 "uptrain" 成 GQA——把 $H$ 组 K/V 沿 head 维 mean-pool 成 $G$ 组初始化,再小 budget(5% 原始训练 compute)finetune 一下。这是为什么 LLaMA-2 70B 能 0-shot 切到 GQA。

4.3 MLA:low-rank latent K/V(DeepSeek-V2 核心创新)

一句话总结 MLA

把 K/V 投影成一个共享的低维 latent $c_t^{KV} \in \mathbb{R}^{d_c}$($d_c \ll H d_\text{head}$),只 cache latent;每次 attention 时线性还原回各 head 的 K/V。

4.3.1 Compression / decompression

输入 hidden state $h_t \in \mathbb{R}^D$。MLA 引入:

$$c_t^{KV} = W^{DKV} h_t \in \mathbb{R}^{d_c}, \quad d_c \ll H d_\text{head}$$

然后只 cache $c_t^{KV}$。生成第 $i$ 个 head 的 K 和 V 时:

$$k_t^{C, (i)} = W^{UK, (i)} c_t^{KV}, \quad v_t^{(i)} = W^{UV, (i)} c_t^{KV}$$

其中 $W^{UK, (i)}, W^{UV, (i)} \in \mathbb{R}^{d_\text{head} \times d_c}$ 是 head-specific up-projection。

类似地,Q 也低秩压缩(这步可选,主要为减训练显存而非 inference):

$$c_t^Q = W^{DQ} h_t, \quad q_t^{C, (i)} = W^{UQ, (i)} c_t^Q$$

4.3.2 Cache size 对比

方案cache 每 token 的元素数
MHA$2 \cdot H \cdot d_\text{head}$
GQA$2 \cdot G \cdot d_\text{head}$
MQA$2 \cdot d_\text{head}$
MLA(裸 latent 部分)$d_c$(单个 vector,不乘 2——因为 K 和 V 共享同一个 latent)

DeepSeek-V2 取 $d_c = 4 d_\text{head}$,相比 MHA($2 H d_\text{head}$)压缩比约 $H/2$ 倍——配 $H=128$ 大概 64×。

4.3.3 Inference 等价变换(absorb trick)

朴素实现里,每 step 要先把 $c_t^{KV}$ 升回 $k_t, v_t$ 再算 attention,那"省 cache"的意义就没了(还得做升投影)。MLA 的妙处在矩阵吸收

attention 分数(忽略 RoPE 部分,只看 content):

$$q_t^{(i)\top} k_s^{(i)} = (W^{UQ, (i)} c_t^Q)^\top (W^{UK, (i)} c_s^{KV}) = c_t^{Q\top} \underbrace{(W^{UQ, (i)\top} W^{UK, (i)})}_\text{常数矩阵 \tilde W^{QK,(i)}} c_s^{KV}$$

$\tilde W^{QK, (i)} \in \mathbb{R}^{d_c' \times d_c}$ 在推理时是固定的,可以在加载模型时预乘一次。这样:

这就是为什么 MLA cache 那么小但推理 FLOPs 没爆炸:矩阵吸收把"省 cache"和"省 compute"解耦了

4.4 MLA 的 RoPE 难题——为什么必须 decoupled

4.4.1 问题:RoPE 破坏 absorb

RoPE 把位置信息以旋转矩阵 $R_t \in \mathbb{R}^{d_\text{head} \times d_\text{head}}$ 形式注入 Q 和 K:

$$q_t^{\text{RoPE}, (i)} = R_t q_t^{(i)}, \quad k_s^{\text{RoPE}, (i)} = R_s k_s^{(i)}$$

attention 分数变成:

$$q_t^{\text{RoPE}, (i)\top} k_s^{\text{RoPE}, (i)} = q_t^{(i)\top} R_t^\top R_s k_s^{(i)} = q_t^{(i)\top} R_{s-t} k_s^{(i)}$$

(用了 $R_t^\top R_s = R_{s-t}$,这是 RoPE 的精髓——相对位置只依赖 $s-t$。)

现在把 latent 形式塞进去:

$$q_t^{\text{RoPE}, (i)\top} k_s^{\text{RoPE}, (i)} = c_t^{Q\top} \underbrace{W^{UQ, (i)\top} R_{s-t} W^{UK, (i)}}_\text{不是常数!依赖 (s-t)} c_s^{KV}$$

中间那块矩阵随 $s-t$ 变化——意味着不能再"预吸收成一个常数矩阵"了。每一对 $(t, s)$ 都得现算 $R_{s-t}$,absorb trick 直接报废,回到 MHA 同等 compute 量级。

4.4.2 解法:把 RoPE 分量独立出来

DeepSeek-V2 的解法:给 RoPE 一个独立的小维度通道

具体来说,引入两个新投影 $W^{KR} \in \mathbb{R}^{D \times d_r}$ 和 $W^{QR, (i)} \in \mathbb{R}^{d_c' \times d_r}$(per-head)。$k_t^R$ 在所有 head 间共享

$$k_t^R = \text{RoPE}_t(W^{KR} h_t), \quad q_t^{R, (i)} = \text{RoPE}_t(W^{QR, (i)} c_t^Q)$$

完整 K/Q 是两段 concat:

$$k_t^{(i)} = [k_t^{C, (i)}; k_t^R], \quad q_t^{(i)} = [q_t^{C, (i)}; q_t^{R, (i)}], \quad k_t^{(i)}, q_t^{(i)} \in \mathbb{R}^{d_\text{head} + d_r}$$

attention 分数变成两部分相加:

$$q_t^{(i)\top} k_s^{(i)} = \underbrace{q_t^{C, (i)\top} k_s^{C, (i)}}_\text{latent, absorb} + \underbrace{q_t^{R, (i)\top} k_s^{R}}_\text{RoPE, 标准 dot}$$

为什么 RoPE 通道共享所有 head

一个独立的 RoPE 维度 $k_t^R$ 给所有 head 共用,cache 只多 $d_r$ 个元素(典型 $d_r = d_\text{head}/2 = 64$)。这是 MLA 设计上"省 cache 的最后一公里"。

4.4.3 总 cache 公式

$$\boxed{\;\text{MLA cache}_\text{per token} = \underbrace{d_c}_\text{latent K/V 共享} + \underbrace{d_r}_\text{RoPE K (head 间共享)}\;}$$

DeepSeek-V2:$d_c = 512, d_r = 64$,每 token 576 个 fp16 元素。对比 LLaMA-3-70B(GQA, $H_\text{kv}=8, d_\text{head}=128$)每 token $2 \cdot 8 \cdot 128 = 2048$ 个元素——MLA 约为 GQA 的 1/3.5;对比同规模 MHA($2 \cdot 64 \cdot 128 = 16384$)约 1/28。DeepSeek-V2 论文报告相比其内部 MHA baseline 93.3% KV 缩减(不同模型规模和 head 数下数字会不同;这里 1/28 是另一组参数下的估算)。

面试常犯错

说"MLA 就是 GQA 的极端版":错。GQA 仍然 cache 完整 K 和 V,只是 head 数变少;MLA cache 的是 latent,K/V 是 inference 时从 latent 还原出来的。两者数学上不同(MLA 改了 attention 结构,GQA 没改)。

4.5 KV 量化(KIVI / KVQuant / FP8)

KV cache 量化路线把每个 cache element 从 fp16(2 bytes)压到更少:

方法量化粒度精度损失备注
FP8 KVper-tensor / per-channel FP8几乎无损H100 原生支持,工业级常用
KIVI (Liu et al., ICML 2024)K per-channel, V per-token 2-bit<1 PPLtuning-free,asymmetric quant
KVQuant (Hooper et al., NeurIPS 2024)per-channel 4-bit + outlier 处理极小论文显示 10M context 可行
KIVI 的关键洞察

K 和 V 的 outlier 分布不一样。K 在 channel 维有显著 outlier(少数 channel 数值大),用 per-channel quant 能吸掉;V 没有 channel-level outlier 但有 token-level 异质,用 per-token quant 更合适。Asymmetric scheme 把这俩别开来处理是 KIVI 的核心贡献。

§5 PagedAttention(vLLM 内存管理)

5.1 问题:KV cache 的 fragmentation

传统 attention 实现把每个 request 的 KV cache 当成连续大 tensor $[L_\text{max}, n_\text{layers}, 2, H_\text{kv}, d_\text{head}]$。问题:

5.2 解法:虚拟内存式分页

PagedAttention(Kwon et al., SOSP 2023)从操作系统 paging 借鉴:

  1. 把 KV cache 切成等大小 page(如每 page 16 个 token)
  2. 每个 request 维护一张 block table:logical block idx → physical block idx
  3. 物理 page 池子在全局,需要时分配;释放时回收
  4. attention kernel 改成 paged attention:按 block table 间接寻址(gather)

效果:

PagedAttention 的代价

间接寻址会引入 ~1-5% kernel overhead(block lookup + scattered HBM access)。但批量大了之后带来的吞吐增益完全 dominate。CUDA Graph 不易兼容(每次 block table 变化要重 capture),所以 vLLM 用 piecewise CUDA Graph。

5.3 Block Table 数据结构(简略)

Request A:  logical_blocks = [0, 1, 2, 3]   →   physical = [12, 7, 34, 19]
Request B:  logical_blocks = [0, 1, 2]      →   physical = [12, 7, 5]    ← prefix 共享!

Request A 和 B 共享前 32 个 token(2 blocks × 16 tokens);vLLM 内部维护每个 physical block 的 ref count,A 想写新内容时如果 ref > 1 触发 COW(复制 + 改 mapping)。

§6 KV Cache 实现代码

6.1 Naive append + autoregressive decode

import math
import torch
import torch.nn.functional as F
from torch import nn

class NaiveCachedAttention(nn.Module):
    """单层 attention with KV cache (MHA / 学习用,不要部署)."""
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.H, self.d = num_heads, d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def _split(self, x):  # [B, L, D] → [B, H, L, d]
        B, L, _ = x.shape
        return x.view(B, L, self.H, self.d).transpose(1, 2)

    def forward(self, x, cache=None):
        """
        x:    [B, L_new, D]  新输入(prefill 时 L_new=L_prompt,decode 时 L_new=1)
        cache: dict with 'k','v' of shape [B, H, L_past, d] or None
        Returns: out [B, L_new, D], new_cache
        """
        B, L_new, D = x.shape
        q = self._split(self.W_q(x))                   # [B, H, L_new, d]
        k = self._split(self.W_k(x))                   # [B, H, L_new, d]
        v = self._split(self.W_v(x))                   # [B, H, L_new, d]

        # ── KV cache append ────────────────────────────────────────
        if cache is not None:
            k = torch.cat([cache['k'], k], dim=2)      # [B, H, L_total, d]
            v = torch.cat([cache['v'], v], dim=2)
        new_cache = {'k': k, 'v': v}

        # ── causal attention (decode 时 L_new=1 → causal 自动满足) ──
        L_total = k.size(2)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d)
        if L_new > 1:                                  # prefill 才需要 causal mask
            causal = torch.tril(torch.ones(L_new, L_total, dtype=torch.bool,
                                          device=x.device), diagonal=L_total - L_new)
            scores = scores.masked_fill(~causal, float('-inf'))
        w = F.softmax(scores, dim=-1)
        out = (w @ v).transpose(1, 2).contiguous().view(B, L_new, D)
        return self.W_o(out), new_cache


# ── 自回归生成 loop(简化版;真模型要加 sampling/stop/多层)──────────
@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens, embed, lm_head):
    cache = None
    out, cache = model(embed(prompt_ids), cache=cache)        # prefill
    next_tok = lm_head(out[:, -1:]).argmax(-1)
    generated = [next_tok]
    for _ in range(max_new_tokens - 1):
        out, cache = model(embed(next_tok), cache=cache)      # decode L_new=1
        next_tok = lm_head(out[:, -1:]).argmax(-1)
        generated.append(next_tok)
    return torch.cat(generated, dim=1)

注意:每步 torch.cat触发一次新显存分配 + memcpy;生产里换成预分配 buffer + index assignment 或 PagedAttention。

6.2 PagedAttention 数据结构 sketch

class PagedKVCache:
    """简化 page table (不含 CUDA kernel,演示数据结构 + COW)."""
    def __init__(self, n_layers, n_kv_heads, head_dim, page_size=16,
                 n_pages=1024, dtype=torch.float16, device='cuda'):
        self.page_size = page_size
        # 全局 page 池:[n_pages, page_size, n_layers, 2 (K,V), n_kv_heads, head_dim]
        self.pool = torch.empty(n_pages, page_size, n_layers, 2,
                                n_kv_heads, head_dim, dtype=dtype, device=device)
        self.free_list = list(range(n_pages))
        self.ref_count = [0] * n_pages
        self.block_table = {}                          # req_id → list of physical page ids

    def allocate(self, req_id, n_tokens):
        n = (n_tokens + self.page_size - 1) // self.page_size
        assert len(self.free_list) >= n, "OOM"
        physical = [self.free_list.pop() for _ in range(n)]
        for pid in physical: self.ref_count[pid] = 1
        self.block_table[req_id] = physical

    def append_token(self, req_id, pos, layer, k_new, v_new):
        """k_new, v_new: [n_kv_heads, head_dim]"""
        page_idx, slot = pos // self.page_size, pos % self.page_size
        if page_idx >= len(self.block_table[req_id]):
            new_pid = self.free_list.pop()
            self.ref_count[new_pid] = 1
            self.block_table[req_id].append(new_pid)
        pid = self.block_table[req_id][page_idx]
        if self.ref_count[pid] > 1:                    # COW
            new_pid = self.free_list.pop()
            self.pool[new_pid] = self.pool[pid]
            self.ref_count[pid] -= 1; self.ref_count[new_pid] = 1
            self.block_table[req_id][page_idx] = new_pid
            pid = new_pid
        self.pool[pid, slot, layer, 0] = k_new
        self.pool[pid, slot, layer, 1] = v_new

    def free(self, req_id):
        for pid in self.block_table[req_id]:
            self.ref_count[pid] -= 1
            if self.ref_count[pid] == 0: self.free_list.append(pid)
        del self.block_table[req_id]

    def share_prefix(self, src, dst, n_tokens):
        """beam search / parallel sampling:复用前 n_tokens 的 page。"""
        n = n_tokens // self.page_size
        prefix = self.block_table[src][:n]
        for pid in prefix: self.ref_count[pid] += 1
        self.block_table[dst] = list(prefix)

生产实现还需 fused paged attention kernel(per-block gather + flash attention 思路)和 device 上的 block table layout。

6.3 MQA / GQA / MLA 在 forward 里的差异

class MQA_GQA_Attention(nn.Module):
    """MHA / MQA / GQA 通用版本 (num_kv_heads ≤ num_heads)."""
    def __init__(self, d_model, num_heads, num_kv_heads):
        super().__init__()
        assert num_heads % num_kv_heads == 0, "H 必须能被 H_kv 整除"
        self.H, self.H_kv = num_heads, num_kv_heads
        self.d = d_model // num_heads
        self.group = num_heads // num_kv_heads        # 每个 KV head 服务的 Q head 数
        self.W_q = nn.Linear(d_model, num_heads * self.d, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d, bias=False)   # ← 小了
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d, bias=False)   # ← 小了
        self.W_o = nn.Linear(num_heads * self.d, d_model, bias=False)

    def forward(self, x, cache=None):
        B, L_new, _ = x.shape
        q = self.W_q(x).view(B, L_new, self.H, self.d).transpose(1, 2)
        k = self.W_k(x).view(B, L_new, self.H_kv, self.d).transpose(1, 2)
        v = self.W_v(x).view(B, L_new, self.H_kv, self.d).transpose(1, 2)

        if cache is not None:
            k = torch.cat([cache['k'], k], dim=2)
            v = torch.cat([cache['v'], v], dim=2)
        new_cache = {'k': k, 'v': v}

        # ── 关键:把 K/V broadcast 到 Q heads ────────────────────────
        k = k.repeat_interleave(self.group, dim=1)    # [B, H, L_total, d]
        v = v.repeat_interleave(self.group, dim=1)
        # repeat_interleave 是显式 broadcast;生产用 torch 的隐式 broadcast 或 fused kernel

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d)
        L_total = k.size(2)
        if L_new > 1:
            causal = torch.tril(torch.ones(L_new, L_total, dtype=torch.bool,
                                          device=x.device), diagonal=L_total - L_new)
            scores = scores.masked_fill(~causal, float('-inf'))
        w = F.softmax(scores, dim=-1)
        out = (w @ v).transpose(1, 2).contiguous().view(B, L_new, -1)
        return self.W_o(out), new_cache


class MLAAttention(nn.Module):
    """MLA with decoupled RoPE (DeepSeek-V2 简化版)."""
    def __init__(self, d_model, num_heads, d_head, d_c, d_r):
        super().__init__()
        self.H = num_heads
        self.d_head = d_head          # content head dim
        self.d_c = d_c                # latent dim (KV 共享)
        self.d_r = d_r                # RoPE head dim (per query head)
        # KV 端:先压成 latent,再升回 K_C, V
        self.W_DKV = nn.Linear(d_model, d_c, bias=False)
        self.W_UK = nn.Linear(d_c, num_heads * d_head, bias=False)
        self.W_UV = nn.Linear(d_c, num_heads * d_head, bias=False)
        # RoPE 通道:K 共享 (head 间共享)
        self.W_KR = nn.Linear(d_model, d_r, bias=False)
        # Q 端:先压成 latent (省训练显存),再升回 Q_C 和 Q_R
        d_c_q = 4 * d_head            # 论文里 d_c_q ≠ d_c
        self.W_DQ = nn.Linear(d_model, d_c_q, bias=False)
        self.W_UQ = nn.Linear(d_c_q, num_heads * (d_head + d_r), bias=False)
        self.W_o = nn.Linear(num_heads * d_head, d_model, bias=False)

    def _rope(self, x, positions, l_axis=-2, base=10000.0):
        """标准 RoPE 2D 旋转。L 在 l_axis 维,d_r 在最后一维;positions: [L]。"""
        assert x.size(-1) % 2 == 0
        D, L = x.size(-1), x.size(l_axis)
        assert positions.shape == (L,)
        i = torch.arange(D // 2, device=x.device, dtype=x.dtype)
        theta = base ** (-2.0 * i / D)                              # [D/2]
        angles = positions.to(x.dtype).unsqueeze(-1) * theta        # [L, D/2]
        cos, sin = angles.cos(), angles.sin()
        ndim = x.dim()
        cos_shape = [1] * ndim
        cos_shape[l_axis % ndim] = L
        cos_shape[-1] = D // 2
        cos_b, sin_b = cos.view(cos_shape), sin.view(cos_shape)
        x2 = x.view(*x.shape[:-1], D // 2, 2)
        x_real, x_imag = x2.unbind(-1)
        rot_real = x_real * cos_b - x_imag * sin_b
        rot_imag = x_real * sin_b + x_imag * cos_b
        return torch.stack([rot_real, rot_imag], dim=-1).flatten(-2)

    def forward(self, x, positions, cache=None):
        B, L_new, _ = x.shape
        # ── KV 端:算 latent + RoPE-K ────────────────────────────────
        c_kv = self.W_DKV(x)                           # [B, L_new, d_c]
        k_r = self._rope(self.W_KR(x), positions)      # [B, L_new, d_r]  (head 间共享)
        if cache is not None:
            c_kv = torch.cat([cache['c_kv'], c_kv], dim=1)
            k_r  = torch.cat([cache['k_r'],  k_r],  dim=1)
        new_cache = {'c_kv': c_kv, 'k_r': k_r}        # ← 只缓存 latent+RoPE,不缓存完整 K/V

        # 升回 K_C, V (这步推理时可被 absorb 进 Q/O 投影;演示先朴素算)
        k_c = self.W_UK(c_kv).view(B, -1, self.H, self.d_head).transpose(1,2)  # [B,H,L_tot,d_head]
        v   = self.W_UV(c_kv).view(B, -1, self.H, self.d_head).transpose(1,2)

        # ── Q 端:拆 content 和 RoPE 两段 ────────────────────────────
        c_q = self.W_DQ(x)                             # [B, L_new, d_c_q]
        q_full = self.W_UQ(c_q).view(B, L_new, self.H, self.d_head + self.d_r)
        q_c, q_r_raw = q_full.split([self.d_head, self.d_r], dim=-1)
        q_c = q_c.transpose(1, 2)                      # [B, H, L_new, d_head]
        q_r = self._rope(q_r_raw, positions, l_axis=1).transpose(1, 2)  # [B, H, L_new, d_r]

        # ── attention score 两部分相加 ───────────────────────────────
        L_tot = c_kv.size(1)
        scale = math.sqrt(self.d_head + self.d_r)
        scores_c = q_c @ k_c.transpose(-2, -1)         # content 部分 [B, H, L_new, L_tot]
        # k_r 是所有 head 共享的,broadcast:
        scores_r = q_r @ k_r.transpose(-2, -1).unsqueeze(1)  # [B, H, L_new, L_tot]
        scores = (scores_c + scores_r) / scale

        if L_new > 1:
            causal = torch.tril(torch.ones(L_new, L_tot, dtype=torch.bool,
                                          device=x.device), diagonal=L_tot - L_new)
            scores = scores.masked_fill(~causal, float('-inf'))
        w = F.softmax(scores, dim=-1)
        out = (w @ v).transpose(1, 2).contiguous().view(B, L_new, -1)
        return self.W_o(out), new_cache
演示代码 vs 生产代码

上面 MLA 实现"朴素地"在每 step 升回完整 K/V,并没真省 compute;生产实现要做 §4.3.3 的 absorb trick——在加载模型时把 $W^{UQ\top} W^{UK}$ 和 $W^{UV} W^O$ 预乘一次,inference 时直接对 latent 算分。

§7 Speculative Decoding 核心机制

7.1 设定

每个 spec step:

  1. 用 $q$ 自回归地 draft $K$ 个 token: $\tilde x_1, \tilde x_2, \dots, \tilde x_K$
  2. 把 prefix + $K$ 个 draft 一次性送进 $p$,并行算出 $p(x_{t+i} | x_{\le t+i-1})$ for $i=1..K$(外加 $i=K+1$ 的 logits 做 bonus token)
  3. 对每个 draft 位置做 rejection sampling:以概率 $\min(1, p(\tilde x_i) / q(\tilde x_i))$ 接受
  4. 第一个被拒的位置 $j$:从修正后的残差分布 $p'$ 重新采一个;位置 $j+1, \dots, K$ 全丢
  5. 如果全部接受($j = K+1$),还能从 target 的最后一组 logits 免费采一个 bonus token

7.2 接受概率 $\alpha$ 的推导(必考)

核心定理(Leviathan et al. 2023, Chen et al. 2023)

Rejection sampling 使整个 spec step 的输出分布与从 $p$ 单步采样完全等价

推导:设我们在某一位置,draft 给出 $\tilde x \sim q(\cdot)$。

接受规则:以概率 $r(\tilde x) = \min(1, p(\tilde x)/q(\tilde x))$ 接受。

被接受的 token 出现概率:

$$\Pr[\text{accept} \land X = x] = q(x) \cdot r(x) = q(x) \cdot \min\!\left(1, \frac{p(x)}{q(x)}\right) = \min(q(x), p(x))$$

被拒概率:

$$\beta = 1 - \alpha = \sum_x q(x) - \sum_x \min(q(x), p(x)) = \sum_x \max(0, q(x) - p(x))$$

整体接受率:

$$\boxed{\;\alpha = \sum_x \min(q(x), p(x)) = 1 - \tfrac{1}{2}\|p - q\|_1\;}$$

最后一步用恒等式 $\sum_x \min(p, q) = 1 - \tfrac{1}{2} \sum_x |p - q|$(注意 $\sum_x p = \sum_x q = 1$)。

推论:$p$ 和 $q$ 越接近,TV distance 越小,$\alpha$ 越接近 1。

7.3 残差分布 $p'$(接受失败时怎么采)

被拒后,我们要从 $p$ "排除掉被接受 mass" 的剩余采一个新 token:

$$p'(x) = \frac{\max(0, p(x) - q(x))}{\sum_x \max(0, p(x) - q(x))} = \frac{\max(0, p(x) - q(x))}{1 - \alpha}$$

等价性证明(关键,面试要会推):考察某 token $x$ 在一个 spec step 内被输出的总概率:

$$\Pr[X = x] = \underbrace{q(x) \min(1, p(x)/q(x))}_\text{accept path} + \underbrace{(1-\alpha) \cdot p'(x)}_\text{reject path}$$

加和:$\min(p, q) + \max(0, p - q) = p$. ✅

所以每个位置的输出分布严格等于 $p$——这是 spec decoding "exact" 的数学根据。

7.4 期望加速:$E[\tau]$ 公式

设每个 draft 位置接受概率独立同分布(实际略相关,但论文常用此近似)。$K$ 个 draft + 1 个 bonus:

期望 token 数:

$$E[\tau] = \sum_{j=0}^{K-1} \alpha^j (1-\alpha) (j+1) + \alpha^K (K+1)$$

化简(几何级数标准技巧):

$$\boxed{\;E[\tau] = \frac{1 - \alpha^{K+1}}{1 - \alpha}\;}$$

极限分析

加速比的真实公式

还要扣 draft 模型自己的 forward 开销。设 $c = T_q / T_p$(draft 单 step 时间 / target 单 step 时间,典型 0.05-0.15):

$$\text{speedup} = \frac{E[\tau]}{1 + Kc}$$

分子是平均接受 token 数;分母 1 是一次 target verify,$Kc$ 是 $K$ 次 draft forward。$c$ 太大(draft 太大)会吃掉收益,所以 draft 选小很重要。

7.5 Temperature / Top-p 下的 sampling 等价性

Rejection sampling 等价性只要求两件事:(1) target 端用经 sampler 处理后的分布 $\tilde p$ 替换 $p$ 来算接受率和残差;(2) draft proposal 分布 $\tilde q$ 是任何合法概率分布即可。数学上不强制 draft 用与 target 相同的 sampler——draft 完全 greedy 也合法,只是 $\tilde q$ 跟 $\tilde p$ 偏离大、$\alpha$ 暴跌。实践常把 draft 也用同一组 temperature/top-p 让 $\tilde q$ 贴近 $\tilde p$。

错误等价方案

"draft 用同样 sampling 后只比 token 一致" 不对——这样会丢分布等价性。正确做法是按 §7.3 的 rejection 公式,比较概率而不是 token 一致

7.6 代码:speculative decoding loop

下面给出 single-batch 演示版。约定:两个模型暴露 forward(input_ids, cache);cache 对象有 length 属性 + truncate(L) 方法(生产里 PagedAttention 用 block table 改指针实现 $O(1)$ rollback)。核心不变式:每轮迭代开始时,cache.length == seq.size(1) - 1(即 cache 里有除最后 1 个 token 之外的所有 prefix)。

import torch

@torch.no_grad()
def speculative_decode(target, draft, prompt_ids, max_new_tokens, K=4, temperature=1.0):
    """
    target, draft: callable(input_ids, cache) → (logits [1,L_new,V], new_cache)。
    数学上 exact 等价于直接从 target 采样(Leviathan/Chen 2023)。
    """
    seq = prompt_ids.clone()                                   # [1, L_prompt]
    L_prompt = seq.size(1)
    # Prefill 前 L-1 个 token;最后 1 个 token 留作首轮 verify input。
    _, target_cache = target(seq[:, :-1], cache=None)
    _, draft_cache  = draft(seq[:, :-1],  cache=None)

    while seq.size(1) - L_prompt < max_new_tokens:
        last_tok = seq[:, -1:]                                  # [1, 1],尚未入 cache
        draft_chk, target_chk = draft_cache.length, target_cache.length

        # ── 1. Draft:依次喂 last_tok, d_1, ..., d_{K-1};采样 d_1..d_K ──
        cur = last_tok
        draft_tokens, draft_probs = [], []
        for _ in range(K):
            logits, draft_cache = draft(cur, cache=draft_cache)
            probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
            tok = torch.multinomial(probs, 1)
            draft_tokens.append(tok); draft_probs.append(probs)
            cur = tok
        draft_tokens = torch.cat(draft_tokens, dim=1)            # [1, K]

        # ── 2. Target:一次 forward 看 [last_tok, d_1..d_K],输出 K+1 个分布 ──
        target_in = torch.cat([last_tok, draft_tokens], dim=1)   # [1, K+1]
        target_logits, target_cache = target(target_in, cache=target_cache)
        # target_logits[:, i, :] 用来 verify d_{i+1}(i<K)或 bonus 采样(i=K)

        # ── 3. Rejection sampling 逐位置 ──
        accepted = 0; rejected = False
        for i in range(K):
            p_i = torch.softmax(target_logits[:, i, :] / temperature, dim=-1)
            q_i = draft_probs[i]
            x = draft_tokens[:, i:i+1]
            p_x = p_i.gather(-1, x).squeeze(-1)
            q_x = q_i.gather(-1, x).squeeze(-1)
            r = (p_x / (q_x + 1e-9)).clamp(max=1.0)
            if torch.rand_like(r).item() < r.item():             # accept
                accepted += 1
            else:                                                 # reject
                p_prime = (p_i - q_i).clamp(min=0.0)              # 残差分布 p'
                p_prime = p_prime / (p_prime.sum(-1, keepdim=True) + 1e-9)
                new_tok = torch.multinomial(p_prime, 1)
                seq = torch.cat([seq, draft_tokens[:, :accepted], new_tok], dim=1)
                # rollback:保留 last_tok + accepted 个 draft;new_tok 暂不入 cache
                draft_cache.truncate(draft_chk + 1 + accepted)
                target_cache.truncate(target_chk + 1 + accepted)
                rejected = True
                break

        if not rejected:                                          # 全接受 → bonus token
            p_bonus = torch.softmax(target_logits[:, K, :] / temperature, dim=-1)
            bonus = torch.multinomial(p_bonus, 1)
            seq = torch.cat([seq, draft_tokens, bonus], dim=1)
            # draft_cache 之前只见过 d_1..d_{K-1};补喂 d_K 保持不变式
            _, draft_cache = draft(draft_tokens[:, -1:], cache=draft_cache)
            # target_cache 已含 last_tok + d_1..d_K,对应不变式

    return seq[:, :L_prompt + max_new_tokens]                     # overshoot 截断
生产实现要点

(1) Cache rollback 必须真正回退 KV cache 写入位置;vLLM 用 PagedAttention 的 block table 改指针实现 $O(1)$。(2) 数值稳定:$r=p/q$ 在 $q \to 0$ 时爆炸,用 fp32 计算并 clamp。(3) 不变式:每轮开始时 cache 含 seq 除最后 1 个 token 之外的所有 prefix——所有 rollback / 补喂逻辑都为维护这一条。

§8 Speculative Decoding 主要变体

8.1 变体总览

方法Draft 来源多 token 结构是否需 draft 模型训练代表论文
Vanilla SD独立小模型线性 chain否(用现成小 LM)Leviathan 2023, Chen 2023
SpecInfer多个 draft 一起静态 treeMiao 2024 (ASPLOS)
Medusa在 target 上加 N 个 head静态 tree是(finetune heads)Cai 2024 (ICML)
EAGLE-1feature-level autoregression + 小 modeltree是(draft 头小)Li 2024 (ICML)
EAGLE-2同 EAGLE-1dynamic treeLi 2024
EAGLE-3多层 feature fusion + training-time test动态 treeLi 2025
Hydrasequential draft heads静态 treeAnkner 2024
Lookahead DecodingJacobi 迭代 + n-gram pool自己 verifyFu 2024 (ICML)
RESTretrieval (datastore)静态 treeHe 2024 (NAACL)
Self-Speculativetarget 自己跳层线性否(用 target 部分层)Zhang 2024
TriForce分层(small LM + sparse target)hierarchicalSun 2024
MagicDecsmall draft + sparse KV线性Sadhukhan 2024

8.2 Medusa:多头并行

Medusa(Cai et al., ICML 2024)的核心:在 target 模型的最后 hidden state 上加 $N$ 个并行的 prediction heads,第 $k$ 个 head 直接预测"未来第 $k+1$ 个 token"。

Tree attention 的 mask

把 tree 节点按 BFS 排成线性序列 $[t_0, t_1, \dots, t_M]$,节点 $i$ 的 ancestor 集合 $\mathcal A(i)$(含自己),mask $M[i, j] = 1$ iff $j \in \mathcal A(i)$。这样每个节点只看到从 root 到自己的路径,logits 正确。

8.3 EAGLE 系列:特征级 autoregression

EAGLE(Li et al., ICML 2024)核心洞察:target 上一层的 hidden feature $h_{t-1}$ 比 token 包含更多信息——draft 阶段在 feature space 做 autoregression 比在 token space 更准。

EAGLE-2(Li 2024):把静态 tree 换成 dynamic tree——每步用 draft 给出的概率挑当前最 promising 的几条路径展开。

EAGLE-3(Li et al., 2025):

  1. 抛弃 feature regression,直接 token prediction(feature 误差累积是瓶颈)
  2. 多层 feature fusion 而非只用 top layer
  3. Training-Time Test(TTT):训练时模拟 inference 时的 draft chain 误差,避免 train-test gap
  4. 在 Vicuna-7B 上 4-5× 加速,比 EAGLE-2 进一步提升 ~30%
EAGLE 系列的训练成本

虽然 draft 模型很小(1 层 transformer,参数量约几十 M),但要在 target 模型的特征上重新训练,需要 target 模型的 forward pass dataset(蒸馏式)。一次 EAGLE-3 训练几小时到几天,不是免费的。

8.4 Lookahead Decoding:Jacobi 迭代

Jacobi 视角(Fu et al., ICML 2024):自回归生成等价于解非线性方程组 $x_i = f(x_{

Step 0:  x = [<random>, <random>, ..., <random>]
Step 1:  x'_i = f(x_{<i})  ∀ i  并行
Step 2:  x = x',再来一轮
... 直到 fixed point

Lookahead Decoding:

效果:MT-bench 1.8×,code completion 多卡 4× 加速。但对没有重复模式的开放生成(如复杂 reasoning)加速有限。

8.5 长 context 专属:TriForce / MagicDec

长 context 下 vanilla SD 失效

当 context 长(如 128K),target 模型每 forward 都要扫整条 KV cache,HBM 带宽 才是瓶颈,而不是 weight 加载。普通 SD 节省的是 weight loading 频率,到这种 regime 收益变小。

TriForce(Sun et al., 2024, arxiv 2404.11912):三层 hierarchy。

  1. 第一层 draft:小 LM
  2. 第二层 draft:target 模型 + sparse KV cache(只保留 heavy-hitter / retrieved 部分)
  3. 第三层 verify:target 模型 + 完整 KV cache

加速核心:第二层 draft 用 sparse cache 走得快,再让完整 cache 的 target 一次 verify 一长串。

MagicDec(Sadhukhan et al., 2024, arxiv 2408.11049, ICLR 2025):观察到长 context 下 KV cache 才是瓶颈,所以 draft 用 StreamingLLM 风格 sparse cache(attention sink + sliding window),target 用 full cache verify。

8.6 Self-Speculative Decoding

不需要外部 draft model 的极端版

Zhang et al. 2024 提出 "self-speculative":用 target 模型本身跳层(skip a subset of layers) 作为 draft。

§9 复杂度 / 资源核算

9.1 KV cache 显存(汇总)

以 LLaMA-2/3-70B 架构($N_\text{layers}=80, d_\text{head}=128$, MHA $H=64$)为基准,$L_\text{ctx}=4096$, fp16:

方案per-token-per-layer bytes70B / 4K context (全模型)
MHA ($H_\text{kv}=64$)$2 \cdot 2 \cdot 64 \cdot 128 = 32768$10.0 GB
GQA ($H_\text{kv}=8$)$2 \cdot 2 \cdot 8 \cdot 128 = 4096$1.25 GB
MQA ($H_\text{kv}=1$)$2 \cdot 2 \cdot 1 \cdot 128 = 512$0.16 GB
MLA ($d_c=512, d_r=64$, 60 layers)$2 \cdot (512+64) = 1152$~0.27 GB
MHA + INT4 KV$32768 / 4 = 8192$2.5 GB

9.2 Speculative decoding 期望吞吐

$$\text{tokens / sec}_\text{SD} = \frac{\text{tokens / sec}_\text{baseline} \cdot E[\tau]}{1 + Kc}$$

经验数字(A100, LLaMA-2-7B target + 68M draft):

长 context(128K+)regime:vanilla SD 加速降到 1.1-1.3×;TriForce / MagicDec 仍能保 2-2.5×。

9.3 预算粗算(70B + 4K context + GQA)

显存
weights (fp16)140 GB
KV cache @ batch 8$8 \times 1.25$ GB $= 10$ GB
activation peak (decode, batch 8)~2 GB
~152 GB → 2×80GB A100 紧;通常 4×80GB

如果用 vanilla SD,draft model(7B fp16)+14 GB;EAGLE 的 draft head 只 ~200 MB。

§10 25 高频面试题

按难度分 L1(必会)/ L2(进阶)/ L3(顶级 lab)。所有题点开看答案 + 易踩坑。

L1 必会题(任何 inference / serving 岗位都会问)

Q1.KV cache 公式是什么?
  • 单 sample:$2 \cdot L_\text{ctx} \cdot N_\text{layers} \cdot N_\text{kv\_heads} \cdot d_\text{head} \cdot \text{bytes}$
  • "2" 来自 K + V
  • $N_\text{kv\_heads}$:MHA = $H$, MQA = 1, GQA = $G$
  • LLaMA-3-70B(GQA, $H_\text{kv}=8$)@4K fp16 ≈ 1.25 GB/sample

写 $H$(Q heads);忘乘 2;忘 $L_\text{ctx}$ 是当前长度不是 max length。

Q2.为什么训练时不用 KV cache?
  • 训练时所有位置同时算(teacher forcing,已知 ground truth)
  • 没有"先有部分序列、再 append 新位置"这个时序
  • KV cache 是推理专属优化

把 KV cache 当作通用优化用在训练。

Q3.Prefill 和 decode 阶段的瓶颈分别是什么?
  • Prefill:$O(L^2)$ attention FLOPs,compute-bound
  • Decode:每步只算 1 个 token,但要读完整 cache + weights,memory-bandwidth-bound
  • Arithmetic intensity 极低 → GPU FLOPs 利用率往往 < 10%

说"decode 也是 compute-bound"——错。Decode batch 太小时 GPU 大部分时间在等内存。

Q4.MQA / GQA / MHA 区别?
  • MHA:$H$ 个 K/V head(同 Q)
  • MQA:所有 Q head 共享 1 组 K/V
  • GQA:$H$ 个 Q head 分 $G$ 组($1
  • 主要省的是 KV cache 显存 + 显存带宽,不省 Q projection

以为 MQA 省的是 Q 计算;说 GQA 质量"基本不掉"过于绝对。

Q5.Speculative decoding 公式?
  • Draft $q$ 提议 $K$ 个 token,target $p$ 一次 forward verify
  • 每位置接受概率 $r = \min(1, p(\tilde x)/q(\tilde x))$
  • 整体接受率 $\alpha = \sum_x \min(p(x), q(x))$
  • 期望生成 $E[\tau] = \dfrac{1 - \alpha^{K+1}}{1 - \alpha}$

说 "spec decoding 是近似采样"——错,是 exact(rejection sampling 保证)。

Q6.PagedAttention 解决什么?
  • 朴素 KV cache 必须连续大 tensor,预分配 max length → internal fragmentation
  • 不同 request 长度不一 → external fragmentation
  • 显存利用率仅 ~70%
  • PagedAttention:切 page + block table,利用率提升到 ~96%
  • 支持 prefix sharing(COW)

说 PagedAttention 减少 attention FLOPs——错,FLOPs 不变;它优化的是显存利用率 + 并发请求数

Q7.Continuous batching 是什么?
  • 调度粒度从 request 改成 iteration(每 forward 都重检查 batch)
  • 完成的请求立即踢出,腾出 slot 给新请求
  • 提出:Orca (Yu et al., OSDI 2022)
  • 缩短平均等待时间,提高 GPU 利用率
  • vLLM = Orca continuous batching + PagedAttention

以为 continuous batching 是把不同长度 sequence 都 pad 到最长——那是 static batching 的老做法。

Q8.Draft 模型怎么选?
  • 大小:典型 target / 30 - target / 10(如 70B target + 7B draft)
  • 同 tokenizer、同 vocab(否则 rejection sampling 算不出 $p/q$)
  • 同 prompt format / 同 RLHF 后训练(否则 distribution gap 大,$\alpha$ 低)
  • 经验:$\alpha \in [0.5, 0.8]$,太低就别上 SD

选 draft 太大(如 target / 3);或选不同 tokenizer。

Q9.KV cache 量化最常见做法?
  • FP8(H100 原生支持)几乎无损
  • INT8 per-token quant 也可以接受
  • INT4 / INT2(KIVI, KVQuant)需要精细 outlier 处理
  • KIVI 的关键:K per-channel, V per-token 不对称量化

把 K 和 V 用同一个 quant 方案——容易掉点;K 和 V 的 outlier 分布不一样。

Q10.Prefix caching 是什么?
  • 多个请求共享同一段 prompt 前缀(system prompt、few-shot)
  • 用 hash(prefix) 索引 page 池,命中跳过 prefill
  • 配合 COW 处理后续分叉
  • ChatGPT 这种 system-prompt heavy 服务命中率 90%+

以为 prefix caching = prompt 全部缓存——只 cache prefix;用户特定部分还要 prefill。

L2 进阶题(research-oriented / inference 系统岗)

Q11.推 spec decoding 的接受概率 $\alpha$,并解释它为什么保证 exact sampling。
  • 设 draft $\tilde x \sim q$,接受规则 $r = \min(1, p/q)$
  • $\Pr[\text{accept} \land X=x] = q(x) \cdot \min(1, p(x)/q(x)) = \min(p(x), q(x))$
  • $\alpha = \sum_x \min(p, q) = 1 - \tfrac{1}{2} \|p - q\|_1$
  • 被拒后从残差 $p'(x) = \max(0, p-q) / (1-\alpha)$ 重采
  • 总输出概率 $= \min(p, q) + \max(0, p-q) = p(x)$ ∀x
  • 所以每位置等价于直接从 $p$ 单 step 采样

只写 accept 部分,漏 reject 残差分布;忘证 $\min + \max = p$;说 spec 是近似。

Q12.MLA 为什么必须 decoupled RoPE?详细推导。
  • 朴素 MLA absorb trick:$q^\top k = c_q^\top (W^{UQ\top} W^{UK}) c_{kv}$,中间是常数矩阵 $\tilde W^{QK}$,预乘即可
  • 加 RoPE 后:$q^{R\top} k^R = c_q^\top W^{UQ\top} R_{s-t} W^{UK} c_{kv}$
  • 中间块 $W^{UQ\top} R_{s-t} W^{UK}$ 依赖相对位置 $(s-t)$,不能预乘
  • absorb 失效 → cache 是省了,compute 退化回 MHA
  • 解:拆出独立 RoPE 通道 $k^R \in \mathbb{R}^{d_r}$(所有 head 共享),content 通道走 absorb,RoPE 通道走标准 dot product
  • 总 cache:$d_c + d_r$ per token

只说"加 RoPE 出问题"不展开;不知道 RoPE 的 $R_t^\top R_s = R_{s-t}$ 性质;不知道 $k^R$ 是所有 head 共享。

Q13.Continuous batching 在 prefill + decode 混跑时怎么处理?
  • Prefill 一次性算长段,FLOPs 大;decode 单 token,FLOPs 小
  • 直接混 batch 会让 decode 等 prefill 长时间(HOL blocking)
  • Sarathi-Serve 的 chunked prefill:把长 prefill 切等大小 chunk
  • 每 iteration coalesce 一个 prefill chunk + 多个 decode token
  • stall-free schedule:保证 decode 永远跟着跑

以为 prefill 必须一次跑完;忘记 Sarathi-Serve 是 OSDI 2024。

Q14.Tree attention(Medusa / EAGLE 用)的 mask 怎么写?
  • 把 tree 节点按 BFS 拍平成线性序列 $[t_0, \dots, t_M]$
  • $\mathcal A(i)$ = 节点 $i$ 的祖先(含自身)
  • attention mask $M[i, j] = 1 \iff j \in \mathcal A(i)$
  • 即"causal 在树上的推广"
  • 用于一次 forward 同时 verify tree 里所有路径

写成 lower-triangular causal mask(只适用 chain,不适用 tree);忘记把 mask 的形状从 $[L,L]$ 推广。

Q15.spec decoding 的实际加速公式?为什么 draft 太大会反效果?
  • $\text{speedup} = E[\tau] / (1 + Kc)$,$c = T_q/T_p$
  • $E[\tau] = (1-\alpha^{K+1})/(1-\alpha)$
  • 分母里 $Kc$ 是 $K$ 次 draft forward 的开销
  • 若 $c$ 太大(draft 太大),即使 $\alpha$ 高也会被分母吃光
  • 极端:$c=1$ 时 speedup ≤ 1(draft 跟 target 一样慢)

只写 $E[\tau]$ 不算 draft 开销;漏 bonus token 那一项。

Q16.Self-speculative decoding 和普通 spec decoding 区别?
  • Self-spec:draft 是 target 自己跳层 / 早 exit
  • 不需要独立 draft model,零额外训练
  • 但 draft 与 target 高度相关,$\alpha$ 通常较高
  • 加速一般 1.5-2×(不如 EAGLE 但更省事)
  • 论文:Zhang et al. 2024("Draft & Verify")

说必须有额外训练;混 self-spec 和 layer skipping inference(后者不是 exact)。

Q17.KV cache eviction 和 sparse attention 怎么影响 spec decoding?
  • 长 context 下 KV cache 才是带宽瓶颈,weights 已经被 prefill 摊销
  • 这时 draft 用 sparse / sliding window KV(StreamingLLM 风格)能跑得快
  • target 用完整 cache 做 verify 保证 exact
  • 代表:MagicDec、TriForce(hierarchical:小 draft → sparse target → full target)
  • 收益:长 context 下 vanilla SD 失效(1.1×),MagicDec 能保 2×+

把 sparse KV 当成 lossy 近似(实际只用于 draft,verify 时全 cache 仍 exact)。

Q18.Medusa 用 typical acceptance 替代 rejection sampling,损失了什么?
  • 严格意义上丢了 exact sampling——不再保证输出分布等于 target
  • 但 typical acceptance 用 target 自身的 typical set 阈值约束,质量基本不掉(论文实测和 base 模型 score 接近)
  • 如要严格 exact,可把 verification 换成标准 rejection sampling(Leviathan/Chen 公式)
  • Medusa-1 vs Medusa-2 区分点是训练范式:Medusa-1 冻 backbone 只训 head;Medusa-2 联合训 backbone + head;二者默认都用 typical acceptance

把 Medusa-1 / Medusa-2 的区别说成 "exact vs 非 exact"(错——它们是训练范式不同);说 Medusa 完全等价于 target sampling。

Q19.EAGLE 和 Medusa 的核心差异?
  • Medusa:多 head 直接预测未来 token,independent (不 autoregressive)
  • EAGLE:draft 在 feature space autoregressive(前一步 hidden + 前 token → 下一步 hidden + token)
  • EAGLE 更准(feature 信息丰富),但需要训 draft(含 transformer 层)
  • EAGLE-3 进一步抛 feature regression,直接 token + 多层 fusion + training-time test
  • 实测 EAGLE > Medusa 接受率,但 Medusa 部署更简单(参数更少)

把 EAGLE 当 Medusa 的小改进;说"EAGLE 也是多 head"——错,EAGLE 是 1 个 mini-transformer。

Q20.PagedAttention 和 FlashAttention 关系?
  • FlashAttention:attention kernel 内部 SRAM tiling + online softmax,单 kernel 内优化(避免 materialize $L^2$ scores)
  • PagedAttention:把 KV cache 切 page,按 page table 间接寻址;memory layout 优化
  • 二者正交,可以叠加:vLLM 用 paged + flash 思路写 paged attention kernel
  • 区分点:FlashAttention 减 HBM IO;PagedAttention 减显存碎片

混淆二者;以为 PagedAttention 是 attention 算法变体(实际只是内存管理 + 配套 kernel)。

L3 顶级 lab 题(最严苛级别)

Q21.推 spec decoding 的 acceptance $\alpha$ 完整证明,并解释 sampling 等价性如何推广到 temperature / top-p。
  • 单 token:$\Pr[X=x] = q(x) \min(1, p(x)/q(x)) + (1-\alpha) p'(x)$,代入 $p'$ 得 $\min(p,q) + \max(0, p-q) = p$
  • $\alpha = \sum_x \min(p, q) = 1 - \tfrac{1}{2}\|p-q\|_1$
  • 等价于 TV distance 的连接公式
  • 关键原则:rejection sampling 的等价性只依赖于 "draft proposal 分布 $\tilde q$" 和 "target 目标分布 $\tilde p$" 各自有效。只要把 $p, q$ 在公式里替换成 sampler 处理后的 $\tilde p, \tilde q$,整套等价性照旧
  • Temperature $T$:常见做法是 $\tilde p_T(x) \propto p(x)^{1/T}$ 和 $\tilde q_T(x) \propto q(x)^{1/T}$;把它们代进 $\alpha, p'$ 公式即可
  • Top-p:把 $p$ truncate + renorm 到 $p$ 自己的 top-p 集合得到 $\tilde p$,draft proposal 分布 $\tilde q$ 是 draft 实际采样的那个分布;只要二者都是合法分布,rejection 都 exact
  • 实践中 draft 用与 target 相同的 sampler 是惯例(让 $\tilde q$ 接近 $\tilde p$ 提高 $\alpha$),但不是数学必需——draft 完全 greedy 也合法,只是 $\alpha$ 会暴跌
  • 多 token:每位置 $\alpha_i$ 用对应的 $\tilde p_i, \tilde q_i$;bonus token 用 $K+1$ 位置的修正 logits(经 sampler 处理后)直接采

只写单 token 等价;把"draft 必须用同 sampler"误说成数学必需(实际只是高 $\alpha$ 的策略);忽略 bonus token。

Q22.MLA 的 absorb trick 完整数学推导:为什么 inference 时不用还原 K/V?
  • KV cache:$c_t^{KV} = W^{DKV} h_t \in \mathbb{R}^{d_c}$
  • K, V 升投影:$k_t^{(i)} = W^{UK,(i)} c_t^{KV}, v_t^{(i)} = W^{UV,(i)} c_t^{KV}$
  • Q 同理:$q_t^{(i)} = W^{UQ,(i)} c_t^Q$
  • attention 分数(无 RoPE):$(q_t^{(i)})^\top k_s^{(i)} = (c_t^Q)^\top \underbrace{W^{UQ,(i)\top} W^{UK,(i)}}_{\tilde W^{QK,(i)}} c_s^{KV}$
  • $\tilde W^{QK,(i)}$ 形状 $d_c' \times d_c$,与 (t, s) 无关,加载模型时预乘
  • inference 时直接 $(c_t^Q)^\top \tilde W^{QK,(i)} c_s^{KV}$,完全不算 $k_s$
  • 类似地 attention output:$\text{out}^{(i)} = \sum_s w_s v_s^{(i)} = (\sum_s w_s c_s^{KV})^\top W^{UV,(i)\top}$
  • 把 $W^{UV,(i)}$ 吸进 $W^O$:$W^O_\text{absorbed} = W^O (\text{blockdiag}(W^{UV,(i)}))$
  • 结论:cache 只 latent,compute 在 latent 空间,省 cache 不增 compute

朴素地说"还原 K/V 不就行了"——还原后 compute 退化到 MHA;不知道 absorb 是 inference 专属,训练时不能 absorb 因为要 backprop。

Q23.解释为什么 MLA 在加 RoPE 时必须分离一个独立通道,能不能用别的方式保住 absorb?
  • 核心:RoPE 把 $R_{s-t}$ 塞进 $\tilde W^{QK,(i)}$,破坏"常数矩阵"性质
  • 替代方案 1:把 RoPE 直接放在 latent $c^{KV}$ 上——但 latent 维度小,旋转语义不对(RoPE 设计在 head dim 上配对 sin/cos)
  • 替代方案 2:用 ALiBi(直接加 bias 不旋转)——但破坏 LLaMA-3 兼容预训练
  • 替代方案 3:放弃 absorb,每 step 还原 K/V——compute 退化到 MHA
  • DeepSeek-V2 的选择:decoupled RoPE 通道 $d_r=64$ 所有 head 共享,cache 增量极小(约 5%),content 通道保持 absorb
  • 妙处:这个独立通道在所有 head 间共享 $k_t^R$,是"省 cache 的最后一公里"

说"加 RoPE 不影响 MLA"——错;不知道 decoupled 通道是 head-shared。

Q24.长 context(128K+)下,为什么 vanilla speculative decoding 收益坍塌?怎么救?
  • Vanilla SD 的收益假设:weight loading 是瓶颈,一次 verify 摊销 $K$ 个 token 的 weight load
  • 长 context 下 KV cache 远大于 weights,bandwidth 主要花在读 cache 上
  • 每 verify 读完整 cache 一次,省不了 cache loading
  • 直觉:vanilla SD 加速比 $\propto E[\tau] / (1 + Kc)$ 假设 $T_p$ 主要是 weight loading,但长 context 下 $T_p \approx T_\text{cache\_read} + T_\text{weight\_read}$ 且前者占大头;每次 verify 仍要读全 cache,$K$ 个 token 不能摊销 cache loading,所以 $E[\tau]$ 的优势被吃掉
  • 救法 1:MagicDec — draft 用 sparse KV(StreamingLLM),target 用 full cache verify
  • 救法 2:TriForce — 三层:小 LM → target+sparse cache → target+full cache
  • 救法 3:合并 KV cache 压缩(H2O eviction)+ SD:cache 小了 vanilla SD 也救活

只说"长 context spec decoding 不 work",不知道为什么;不知道 MagicDec/TriForce 是 2024 长 context SD 的 SOTA。

Q25.设计 LLM serving 系统时,决定上什么优化的 mental model 是什么?
  • Step 1 测 workload:prompt 长度分布、生成长度分布、QPS
  • Step 2 按瓶颈选优化:(a) 显存不够装 batch → PagedAttention + prefix caching + KV 量化;(b) 长 prefill 卡 decode → Sarathi-Serve chunked prefill;(c) 短 batch decode 带宽 bound → spec decoding(小 batch 收益最大);(d) 长 context 带宽 bound → MagicDec / TriForce;(e) 跨请求 prompt 重复 → prefix caching + COW
  • Step 3 注意互动:SD + large batch 收益降(large batch 已经 compute-bound);PagedAttention + SD cache rollback 用 page table 改指针;KV 量化 + SD 要 draft/target 用一致 quant scheme
  • Step 4 监控 metrics:tokens/sec, p95 TTFT, p95 TPOT, GPU utilization
  • 关键 trade-off:throughput vs latency,SD 偏 latency 改善,continuous batching 偏 throughput

只罗列技术名词不讲触发条件;不知道 SD 在 large batch 下收益降;忽略真实 workload 测量。

§A 附录:参考实现 + Sanity Check

A.1 组件汇总

参考 from-scratch 实现包含:

A.2 Sanity check 期望输出

[a] naive cache append    prefill (1,16,128) → decode 8 token  ✓
[b] MQA/GQA/MHA shape + cache 大小一致                          ✓
[c] MLA cache = d_c + d_r 元素                                  ✓
[d] spec decode rejection: 100k 样本估 α 与理论值差 < 1%        ✓
[e] spec decode 输出 vs target 直接采样: TV < 0.01              ✓
[f] paged cache COW: ref_count + share 正确                    ✓

A.3 主要参考文献

代码与公式均经独立 reviewer 静态检查(gpt-5.5 xhigh,跨模型),数学等价性论证通过。