KV Cache + Speculative Decoding 面试 Cheat Sheet
公式推导 + From-Scratch 代码 + 25 高频题(L1 必会 · L2 进阶 · L3 顶级 lab)
§0 TL;DR Cheat Sheet
一页拿下面试核心要点(详见后文 §2–§9 推导)。
- KV cache 公式:单 sample 显存 $= 2 \cdot L_\text{ctx} \cdot N_\text{layers} \cdot N_\text{kv\_heads} \cdot d_\text{head} \cdot \text{bytes}$,"2" 来自 K+V。LLaMA-3-70B(GQA, $H_\text{kv}=8$)4K context fp16 ≈ 1.25 GB/sample——这就是为什么不用 MHA。
- Prefill vs Decode 不对称:prefill 处理整段 prompt($O(L^2)$ FLOPs,compute-bound);decode 一次生成 1 token(每 step $O(L)$ FLOPs,但要读全部 KV,memory-bandwidth-bound)。这条不对称解释了一切现代 inference 系统设计。
- PagedAttention(Kwon et al., SOSP 2023, vLLM):把 KV cache 切成 page,用 block table 解 fragmentation;显存利用率从 ~70% 提升到 ~96%。
- Continuous batching(Orca, Yu et al., OSDI 2022):iteration-level scheduling,请求完成不等待整 batch,配合 PagedAttention 是 vLLM 的两根支柱。
- MQA → GQA → MLA:MQA(Shazeer 2019)极端共享 K/V,质量略掉;GQA(Ainslie et al., EMNLP 2023)$G$ 组折中;MLA(DeepSeek-V2, May 2024)low-rank latent $c_t^{KV}$ + decoupled RoPE——RoPE 不能直接吸进 latent compression,必须留一个独立小维度 $d_\text{head}^R$ 携带位置。
- Speculative Decoding 核心(Leviathan et al., ICML 2023; Chen et al., 2023):用小 draft 模型 $q$ 提议 $K$ 个 token,target 模型 $p$ 一次 forward 并行验证;rejection sampling 保证输出分布与 $p$ 完全等价(exact,不是近似)。
- 接受概率公式:单 token 接受率 $\alpha = \mathbb{E}_{x \sim q}[\min(1, p(x)/q(x))]$;期望生成 token 数 $E[\tau] = \dfrac{1-\alpha^{K+1}}{1-\alpha}$($K$ 是 draft 长度,含最后 bonus token)。
- Medusa / EAGLE / Lookahead:Medusa(Cai et al., ICML 2024)多头 + 静态 tree attention;EAGLE/2/3(Li et al., 2024-2025)特征级 draft + 动态 tree;Lookahead Decoding(Fu et al., ICML 2024)Jacobi iteration——都是同一接受率框架下的不同 drafter。
§1 直觉
1.1 为什么 inference 系统这么"反直觉"
训练时大家关心 FLOPs:模型多大、batch 多大、什么时候 OOM。但部署一个 70B 模型时,瓶颈往往不是算力,而是 HBM 带宽 和 显存 ——这俩都被 KV cache 吃掉。
一条核心 mental model:
Modern LLM inference is bandwidth-bound during decode and memory-bound during long-context prefill, not compute-bound.
KV cache 是把"重复计算"换成"存储 + 带宽"的经典权衡。一旦把整段对话历史的 K/V 都缓存起来,每生成一个新 token 只需要:
- 算新 token 的 Q/K/V(极小算量)
- 把新 K/V 追加进 cache
- 用新 Q 对完整 cache 做一次 attention
但代价是:每生成一个 token,整条 KV cache 都要从 HBM 读到 SRAM——这就是为什么 8 张 H100 + LLaMA-3-70B 跑 batch=1 的 decode 远低于理论 FLOPs 利用率(往往 1-5%)。
Speculative decoding 攻击的正是这个不对称:既然 decode 是 bandwidth-bound 而 GPU 上还有富余算力,何不一次 forward 算 $K$ 个候选 token,反正 weight 只读一次?
1.2 与训练时 attention 的区别
| 阶段 | 输入 | KV cache 行为 | 瓶颈 |
|---|---|---|---|
| Training | $[B, L, D]$ 全序列 | 不需要——所有位置同时算 | FLOPs (compute) |
| Prefill (inference) | $[B, L_\text{prompt}, D]$ 整段 prompt | 写入 cache,覆盖 $L_\text{prompt}$ 个位置 | FLOPs($L^2$ attention) |
| Decode (inference) | $[B, 1, D]$ 单 token | 读 + append 1 个位置 | HBM 带宽(每 step 必须读全部 cache + weights) |
面试常被反问"训练能不能用 KV cache"——不能:训练时所有位置一次性算,没有"已有 K/V 等着 append"这件事。把 KV cache 用在训练上是新手错误。
§2 KV Cache 显存核算
2.1 精确公式
单 sample(batch=1),fp16:
$$\boxed{\;\text{KV cache}_\text{bytes} = 2 \cdot L_\text{ctx} \cdot N_\text{layers} \cdot N_\text{kv\_heads} \cdot d_\text{head} \cdot \text{bytes\_per\_elem}\;}$$
各因子含义:
2:一个 K 张量 + 一个 V 张量- $L_\text{ctx}$:当前上下文长度(prompt + 已生成 token 数)
- $N_\text{layers}$:Transformer 层数(每层独立 cache)
- $N_\text{kv\_heads}$:K/V 头数。MHA 下 = $H$;MQA 下 = 1;GQA 下 = $G$($1 < G < H$)
- $d_\text{head}$:每 head 维度,一般 64 或 128
bytes_per_elem:fp16 = 2, fp8/int8 = 1, int4 = 0.5
KV cache 只跟 K/V heads 走,不跟 Q heads 走。MQA 把 K/V heads 砍成 1 时,Q 仍是 $H$ 个 head,所以 Q projection 计算量不变。
2.2 几个具体模型的 cache 大小(fp16, $L_\text{ctx}=4096$)
代入 §2.1 公式 $2 \cdot L_\text{ctx} \cdot N_\text{layers} \cdot N_\text{kv\_heads} \cdot d_\text{head} \cdot 2\text{B}$ 直接算:
| 模型 | $N_\text{layers}$ | $N_\text{kv\_heads}$ | $d_\text{head}$ | cache/sample | 备注 |
|---|---|---|---|---|---|
| LLaMA-2-7B (MHA) | 32 | 32 | 128 | 2.0 GB | 全 MHA,cache 大 |
| LLaMA-2-70B (假设 MHA) | 80 | 64 | 128 | 10.0 GB | 这就是为什么 70B 用 GQA |
| LLaMA-2-70B (GQA) | 80 | 8 | 128 | 1.25 GB | GQA 把 $H$=64 砍成 $G$=8 |
| LLaMA-3-70B (GQA) | 80 | 8 | 128 | 1.25 GB | 同 LLaMA-2-70B |
| DeepSeek-V2 (MLA) | 60 | — | $d_c$=512 + $d_r$=64 | ~0.27 GB | MLA 公式:$N_\text{layers} \cdot L_\text{ctx} \cdot (d_c+d_r) \cdot 2\text{B}$ |
DeepSeek-V2 cache 用 $d_c=512$(latent dim, K/V 共享同一个 latent 向量)+ decoupled RoPE 分量 $d_r=64$(所有 head 共享),按 $N_\text{layers} \cdot L_\text{ctx} \cdot (d_c + d_r) \cdot \text{bytes}$(没有那个 "$\times 2$",因为 K/V 不再各存一份),fp16 下 $60 \cdot 4096 \cdot 576 \cdot 2 \approx$ 0.27 GB / sample——相比同规模 MHA 缩减一个数量级。
2.3 Batch 维度
实际服务里,KV cache 还要乘 active batch size。一个 70B + 4K context + GQA 部署在 8×A100 80GB 上:
- weights ≈ 140 GB (fp16, 跨卡分片)
- 单 sample cache ≈ 1.25 GB
- 剩余可用显存约 $8 \times 80 - 140 \approx 500$ GB → 减去 activation + 框架开销,可分给 KV cache 约 400 GB
- 理论 max batch ≈ $400 / 1.25 \approx$ 320——但实际因为 fragmentation 达不到,所以才有 PagedAttention。
§3 Prefill vs Decode 不对称
3.1 两个阶段的 FLOPs / 带宽差异
设 prompt 长度 $L$,模型 hidden $D$,FFN 中间 $4D$,layer 数 $N$。每层 attention + FFN 大约:
$$\text{FLOPs}_\text{layer} \approx \underbrace{6BLD^2}_\text{QKV proj (3 mat)} + \underbrace{4BL^2 D}_\text{attention (QK + AV)} + \underbrace{2BLD^2}_\text{O proj} + \underbrace{16 BLD^2}_\text{FFN (up + down)}$$
Prefill 阶段 $L$ 取 $L_\text{prompt}$:所有项都是 $\Omega(L D^2)$ 或 $\Omega(L^2 D)$,compute-bound,能跑满 GPU。
Decode 阶段 $L=1$(每 step 只算 1 个 token),但 attention 那项变 $4 B L_\text{ctx} D$(QK 和 AV 各 $2 B L_\text{ctx} D$,因 $L_q=1, L_k=L_\text{ctx}$):
- 每 step FLOPs ≈ $\Theta(B (L_\text{ctx} D + D^2))$
- 每 step HBM 读取量:weights ≈ 模型参数总 bytes(70B fp16 ≈ 140 GB) + KV cache ≈ $2 B L_\text{ctx} N H_\text{kv} d_\text{head} \cdot \text{bytes}$
- Arithmetic intensity(FLOPs / byte) $\to$ 远低于 GPU roofline(A100 BF16 ≈ 200 FLOPs/byte)
所以 小到中等 batch 的 decode 是 memory-bandwidth-bound(batch 大到 GPU 算力被打满之前都是这样)——这一行背下来就够拿 80% 的面试分。
3.2 Chunked Prefill(Sarathi-Serve)
prompt 越长,prefill 一次 forward 把 GPU 算力全占住,正在 decode 的请求会被卡住(head-of-line blocking)。
Sarathi-Serve(Agrawal et al., OSDI 2024)把长 prefill 切成等大小 chunk,每个 iteration 调度一个 prefill chunk + 一批 decode token 一起跑:
传统:prompt 4096 → 一次 prefill 占满 GPU → decode 卡 100+ ms
Sarathi: prompt 切 4 个 chunk × 1024 → 每 iteration 跟 decode coalesce
Stall-free schedule:保证每个 iteration 有 decode + prefill chunk 混跑,decode latency 抖动消失。论文实测 Mistral-7B 单 A100 服务能力相比 vLLM 提升 2.6×,Yi-34B 在 2×A100 上 3.7×。
3.3 Continuous Batching(Orca)
传统 static batching:等整个 batch 都跑完才放新请求进来——短请求被长请求拖死。Orca(Yu et al., OSDI 2022)把调度粒度从 request 改成 iteration:每次 forward 都检查一遍 batch 里有没有完成的(EOS),完成的踢出去、空位让新请求进来。
这两条加在一起把 LLM serving 吞吐拉到了之前的 24× 左右。
3.4 Prefix Caching
如果多个请求共享 prompt 前缀(如 system prompt、few-shot prompt),完全可以复用同一份 KV cache:
- vLLM 的 prefix caching:用 hash(prompt prefix) 索引 cache page,命中则跳过 prefill
- ChatGPT 这种 prompt-heavy 服务,prefix cache hit rate 可达 90%+,大幅减 prefill 开销
- 实现关键:page-level 共享 + COW(copy-on-write),分支的请求只对自己的 page 写
§4 KV Cache 优化路线
4.1 路线总览
| 路线 | 核心 idea | 减少了什么 | 代表 |
|---|---|---|---|
| 共享 head(MQA/GQA) | 多个 Q head 共享一组 K/V | KV head 数 $H \to G$ 或 1 | LLaMA-2/3, Mistral, PaLM |
| 低秩压缩(MLA) | 投影到低维 latent,cache latent 而非 K/V | head 维度有效减小 $d_\text{head} \to d_c/H$ | DeepSeek-V2/V3 |
| 量化(KIVI/KVQuant) | 把 cache 元素 fp16 → int4/int2 | bytes per element | KIVI, KVQuant, FP8 KV |
| 稀疏化 / eviction | 只保留"重要"位置的 K/V | $L_\text{ctx}$ 有效缩短 | H2O, StreamingLLM, TriForce |
| 内存管理(PagedAttention) | 不减小总量,但消除碎片 | fragmentation overhead | vLLM |
4.2 MQA / GQA 公式回顾
MHA:每个 head 独立 $W_k^{(h)}, W_v^{(h)} \in \mathbb{R}^{D \times d_\text{head}}$;共 $H \cdot 2 \cdot D \cdot d_\text{head} = 2 D^2$ 参数(K+V)。
MQA(Shazeer 2019):$H$ 个 Q head 共享 1 组 K/V。$W_k, W_v \in \mathbb{R}^{D \times d_\text{head}}$,K+V 参数 $= 2 D d_\text{head}$,减小 $H$ 倍。前向时 K, V 各 broadcast 到 $H$ 个 head 上做 attention。
GQA(Ainslie 2023):$H$ 个 Q head 分成 $G$ 组,每组共享一组 K/V。K+V 参数 $= 2 G D d_\text{head}$。LLaMA-2-70B 用 $H=64, G=8 \Rightarrow$ KV cache 缩小 8×。
从头训练 MQA 模型相比 MHA 经常出现质量小幅下降甚至训练不稳定。GQA 论文给的实践:先用 MHA 训完,再 "uptrain" 成 GQA——把 $H$ 组 K/V 沿 head 维 mean-pool 成 $G$ 组初始化,再小 budget(5% 原始训练 compute)finetune 一下。这是为什么 LLaMA-2 70B 能 0-shot 切到 GQA。
4.3 MLA:low-rank latent K/V(DeepSeek-V2 核心创新)
把 K/V 投影成一个共享的低维 latent $c_t^{KV} \in \mathbb{R}^{d_c}$($d_c \ll H d_\text{head}$),只 cache latent;每次 attention 时线性还原回各 head 的 K/V。
4.3.1 Compression / decompression
输入 hidden state $h_t \in \mathbb{R}^D$。MLA 引入:
$$c_t^{KV} = W^{DKV} h_t \in \mathbb{R}^{d_c}, \quad d_c \ll H d_\text{head}$$
然后只 cache $c_t^{KV}$。生成第 $i$ 个 head 的 K 和 V 时:
$$k_t^{C, (i)} = W^{UK, (i)} c_t^{KV}, \quad v_t^{(i)} = W^{UV, (i)} c_t^{KV}$$
其中 $W^{UK, (i)}, W^{UV, (i)} \in \mathbb{R}^{d_\text{head} \times d_c}$ 是 head-specific up-projection。
类似地,Q 也低秩压缩(这步可选,主要为减训练显存而非 inference):
$$c_t^Q = W^{DQ} h_t, \quad q_t^{C, (i)} = W^{UQ, (i)} c_t^Q$$
4.3.2 Cache size 对比
| 方案 | cache 每 token 的元素数 |
|---|---|
| MHA | $2 \cdot H \cdot d_\text{head}$ |
| GQA | $2 \cdot G \cdot d_\text{head}$ |
| MQA | $2 \cdot d_\text{head}$ |
| MLA(裸 latent 部分) | $d_c$(单个 vector,不乘 2——因为 K 和 V 共享同一个 latent) |
DeepSeek-V2 取 $d_c = 4 d_\text{head}$,相比 MHA($2 H d_\text{head}$)压缩比约 $H/2$ 倍——配 $H=128$ 大概 64×。
4.3.3 Inference 等价变换(absorb trick)
朴素实现里,每 step 要先把 $c_t^{KV}$ 升回 $k_t, v_t$ 再算 attention,那"省 cache"的意义就没了(还得做升投影)。MLA 的妙处在矩阵吸收:
attention 分数(忽略 RoPE 部分,只看 content):
$$q_t^{(i)\top} k_s^{(i)} = (W^{UQ, (i)} c_t^Q)^\top (W^{UK, (i)} c_s^{KV}) = c_t^{Q\top} \underbrace{(W^{UQ, (i)\top} W^{UK, (i)})}_\text{常数矩阵 \tilde W^{QK,(i)}} c_s^{KV}$$
$\tilde W^{QK, (i)} \in \mathbb{R}^{d_c' \times d_c}$ 在推理时是固定的,可以在加载模型时预乘一次。这样:
- inference 时只 cache $c_s^{KV}$
- 计算 attention score 时直接 $c_t^{Q\top} \tilde W^{QK, (i)} c_s^{KV}$,完全不需要还原 K/V
- V 同理可以把 $W^{UV, (i)}$ 吸进 output projection $W^O$
这就是为什么 MLA cache 那么小但推理 FLOPs 没爆炸:矩阵吸收把"省 cache"和"省 compute"解耦了。
4.4 MLA 的 RoPE 难题——为什么必须 decoupled
4.4.1 问题:RoPE 破坏 absorb
RoPE 把位置信息以旋转矩阵 $R_t \in \mathbb{R}^{d_\text{head} \times d_\text{head}}$ 形式注入 Q 和 K:
$$q_t^{\text{RoPE}, (i)} = R_t q_t^{(i)}, \quad k_s^{\text{RoPE}, (i)} = R_s k_s^{(i)}$$
attention 分数变成:
$$q_t^{\text{RoPE}, (i)\top} k_s^{\text{RoPE}, (i)} = q_t^{(i)\top} R_t^\top R_s k_s^{(i)} = q_t^{(i)\top} R_{s-t} k_s^{(i)}$$
(用了 $R_t^\top R_s = R_{s-t}$,这是 RoPE 的精髓——相对位置只依赖 $s-t$。)
现在把 latent 形式塞进去:
$$q_t^{\text{RoPE}, (i)\top} k_s^{\text{RoPE}, (i)} = c_t^{Q\top} \underbrace{W^{UQ, (i)\top} R_{s-t} W^{UK, (i)}}_\text{不是常数!依赖 (s-t)} c_s^{KV}$$
中间那块矩阵随 $s-t$ 变化——意味着不能再"预吸收成一个常数矩阵"了。每一对 $(t, s)$ 都得现算 $R_{s-t}$,absorb trick 直接报废,回到 MHA 同等 compute 量级。
4.4.2 解法:把 RoPE 分量独立出来
DeepSeek-V2 的解法:给 RoPE 一个独立的小维度通道。
- Latent 通道(无 RoPE):负责内容信息,cache 一份 $c_t^{KV} \in \mathbb{R}^{d_c}$,attention 算分时走 absorb
- RoPE 通道(有 RoPE):负责位置信息,cache 一份 $k_t^R \in \mathbb{R}^{d_r}$,attention 算分时走标准带旋转的 dot product
具体来说,引入两个新投影 $W^{KR} \in \mathbb{R}^{D \times d_r}$ 和 $W^{QR, (i)} \in \mathbb{R}^{d_c' \times d_r}$(per-head)。$k_t^R$ 在所有 head 间共享:
$$k_t^R = \text{RoPE}_t(W^{KR} h_t), \quad q_t^{R, (i)} = \text{RoPE}_t(W^{QR, (i)} c_t^Q)$$
完整 K/Q 是两段 concat:
$$k_t^{(i)} = [k_t^{C, (i)}; k_t^R], \quad q_t^{(i)} = [q_t^{C, (i)}; q_t^{R, (i)}], \quad k_t^{(i)}, q_t^{(i)} \in \mathbb{R}^{d_\text{head} + d_r}$$
attention 分数变成两部分相加:
$$q_t^{(i)\top} k_s^{(i)} = \underbrace{q_t^{C, (i)\top} k_s^{C, (i)}}_\text{latent, absorb} + \underbrace{q_t^{R, (i)\top} k_s^{R}}_\text{RoPE, 标准 dot}$$
一个独立的 RoPE 维度 $k_t^R$ 给所有 head 共用,cache 只多 $d_r$ 个元素(典型 $d_r = d_\text{head}/2 = 64$)。这是 MLA 设计上"省 cache 的最后一公里"。
4.4.3 总 cache 公式
$$\boxed{\;\text{MLA cache}_\text{per token} = \underbrace{d_c}_\text{latent K/V 共享} + \underbrace{d_r}_\text{RoPE K (head 间共享)}\;}$$
DeepSeek-V2:$d_c = 512, d_r = 64$,每 token 576 个 fp16 元素。对比 LLaMA-3-70B(GQA, $H_\text{kv}=8, d_\text{head}=128$)每 token $2 \cdot 8 \cdot 128 = 2048$ 个元素——MLA 约为 GQA 的 1/3.5;对比同规模 MHA($2 \cdot 64 \cdot 128 = 16384$)约 1/28。DeepSeek-V2 论文报告相比其内部 MHA baseline 93.3% KV 缩减(不同模型规模和 head 数下数字会不同;这里 1/28 是另一组参数下的估算)。
说"MLA 就是 GQA 的极端版":错。GQA 仍然 cache 完整 K 和 V,只是 head 数变少;MLA cache 的是 latent,K/V 是 inference 时从 latent 还原出来的。两者数学上不同(MLA 改了 attention 结构,GQA 没改)。
4.5 KV 量化(KIVI / KVQuant / FP8)
KV cache 量化路线把每个 cache element 从 fp16(2 bytes)压到更少:
| 方法 | 量化粒度 | 精度损失 | 备注 |
|---|---|---|---|
| FP8 KV | per-tensor / per-channel FP8 | 几乎无损 | H100 原生支持,工业级常用 |
| KIVI (Liu et al., ICML 2024) | K per-channel, V per-token 2-bit | <1 PPL | tuning-free,asymmetric quant |
| KVQuant (Hooper et al., NeurIPS 2024) | per-channel 4-bit + outlier 处理 | 极小 | 论文显示 10M context 可行 |
K 和 V 的 outlier 分布不一样。K 在 channel 维有显著 outlier(少数 channel 数值大),用 per-channel quant 能吸掉;V 没有 channel-level outlier 但有 token-level 异质,用 per-token quant 更合适。Asymmetric scheme 把这俩别开来处理是 KIVI 的核心贡献。
§5 PagedAttention(vLLM 内存管理)
5.1 问题:KV cache 的 fragmentation
传统 attention 实现把每个 request 的 KV cache 当成连续大 tensor $[L_\text{max}, n_\text{layers}, 2, H_\text{kv}, d_\text{head}]$。问题:
- 必须预分配 $L_\text{max}$ 长度(实际可能只用了 10%)→ internal fragmentation
- 不同 request 长度不同,cache 块大小不一 → external fragmentation
- 释放掉一个 request 后碎片化,新 request 想要大块时找不到 → 显存利用率 ~70%
5.2 解法:虚拟内存式分页
PagedAttention(Kwon et al., SOSP 2023)从操作系统 paging 借鉴:
- 把 KV cache 切成等大小 page(如每 page 16 个 token)
- 每个 request 维护一张 block table:logical block idx → physical block idx
- 物理 page 池子在全局,需要时分配;释放时回收
- attention kernel 改成 paged attention:按 block table 间接寻址(gather)
效果:
- 显存利用率从 ~70% → ~96%
- 同 GPU 上 active batch size 翻 2-4×,吞吐相应提升
- 支持 copy-on-write 共享:beam search、parallel sampling、prefix caching 自然实现
间接寻址会引入 ~1-5% kernel overhead(block lookup + scattered HBM access)。但批量大了之后带来的吞吐增益完全 dominate。CUDA Graph 不易兼容(每次 block table 变化要重 capture),所以 vLLM 用 piecewise CUDA Graph。
5.3 Block Table 数据结构(简略)
Request A: logical_blocks = [0, 1, 2, 3] → physical = [12, 7, 34, 19]
Request B: logical_blocks = [0, 1, 2] → physical = [12, 7, 5] ← prefix 共享!
Request A 和 B 共享前 32 个 token(2 blocks × 16 tokens);vLLM 内部维护每个 physical block 的 ref count,A 想写新内容时如果 ref > 1 触发 COW(复制 + 改 mapping)。
§6 KV Cache 实现代码
6.1 Naive append + autoregressive decode
import math
import torch
import torch.nn.functional as F
from torch import nn
class NaiveCachedAttention(nn.Module):
"""单层 attention with KV cache (MHA / 学习用,不要部署)."""
def __init__(self, d_model, num_heads):
super().__init__()
self.H, self.d = num_heads, d_model // num_heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
def _split(self, x): # [B, L, D] → [B, H, L, d]
B, L, _ = x.shape
return x.view(B, L, self.H, self.d).transpose(1, 2)
def forward(self, x, cache=None):
"""
x: [B, L_new, D] 新输入(prefill 时 L_new=L_prompt,decode 时 L_new=1)
cache: dict with 'k','v' of shape [B, H, L_past, d] or None
Returns: out [B, L_new, D], new_cache
"""
B, L_new, D = x.shape
q = self._split(self.W_q(x)) # [B, H, L_new, d]
k = self._split(self.W_k(x)) # [B, H, L_new, d]
v = self._split(self.W_v(x)) # [B, H, L_new, d]
# ── KV cache append ────────────────────────────────────────
if cache is not None:
k = torch.cat([cache['k'], k], dim=2) # [B, H, L_total, d]
v = torch.cat([cache['v'], v], dim=2)
new_cache = {'k': k, 'v': v}
# ── causal attention (decode 时 L_new=1 → causal 自动满足) ──
L_total = k.size(2)
scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d)
if L_new > 1: # prefill 才需要 causal mask
causal = torch.tril(torch.ones(L_new, L_total, dtype=torch.bool,
device=x.device), diagonal=L_total - L_new)
scores = scores.masked_fill(~causal, float('-inf'))
w = F.softmax(scores, dim=-1)
out = (w @ v).transpose(1, 2).contiguous().view(B, L_new, D)
return self.W_o(out), new_cache
# ── 自回归生成 loop(简化版;真模型要加 sampling/stop/多层)──────────
@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens, embed, lm_head):
cache = None
out, cache = model(embed(prompt_ids), cache=cache) # prefill
next_tok = lm_head(out[:, -1:]).argmax(-1)
generated = [next_tok]
for _ in range(max_new_tokens - 1):
out, cache = model(embed(next_tok), cache=cache) # decode L_new=1
next_tok = lm_head(out[:, -1:]).argmax(-1)
generated.append(next_tok)
return torch.cat(generated, dim=1)
注意:每步 torch.cat 会触发一次新显存分配 + memcpy;生产里换成预分配 buffer + index assignment 或 PagedAttention。
6.2 PagedAttention 数据结构 sketch
class PagedKVCache:
"""简化 page table (不含 CUDA kernel,演示数据结构 + COW)."""
def __init__(self, n_layers, n_kv_heads, head_dim, page_size=16,
n_pages=1024, dtype=torch.float16, device='cuda'):
self.page_size = page_size
# 全局 page 池:[n_pages, page_size, n_layers, 2 (K,V), n_kv_heads, head_dim]
self.pool = torch.empty(n_pages, page_size, n_layers, 2,
n_kv_heads, head_dim, dtype=dtype, device=device)
self.free_list = list(range(n_pages))
self.ref_count = [0] * n_pages
self.block_table = {} # req_id → list of physical page ids
def allocate(self, req_id, n_tokens):
n = (n_tokens + self.page_size - 1) // self.page_size
assert len(self.free_list) >= n, "OOM"
physical = [self.free_list.pop() for _ in range(n)]
for pid in physical: self.ref_count[pid] = 1
self.block_table[req_id] = physical
def append_token(self, req_id, pos, layer, k_new, v_new):
"""k_new, v_new: [n_kv_heads, head_dim]"""
page_idx, slot = pos // self.page_size, pos % self.page_size
if page_idx >= len(self.block_table[req_id]):
new_pid = self.free_list.pop()
self.ref_count[new_pid] = 1
self.block_table[req_id].append(new_pid)
pid = self.block_table[req_id][page_idx]
if self.ref_count[pid] > 1: # COW
new_pid = self.free_list.pop()
self.pool[new_pid] = self.pool[pid]
self.ref_count[pid] -= 1; self.ref_count[new_pid] = 1
self.block_table[req_id][page_idx] = new_pid
pid = new_pid
self.pool[pid, slot, layer, 0] = k_new
self.pool[pid, slot, layer, 1] = v_new
def free(self, req_id):
for pid in self.block_table[req_id]:
self.ref_count[pid] -= 1
if self.ref_count[pid] == 0: self.free_list.append(pid)
del self.block_table[req_id]
def share_prefix(self, src, dst, n_tokens):
"""beam search / parallel sampling:复用前 n_tokens 的 page。"""
n = n_tokens // self.page_size
prefix = self.block_table[src][:n]
for pid in prefix: self.ref_count[pid] += 1
self.block_table[dst] = list(prefix)
生产实现还需 fused paged attention kernel(per-block gather + flash attention 思路)和 device 上的 block table layout。
6.3 MQA / GQA / MLA 在 forward 里的差异
class MQA_GQA_Attention(nn.Module):
"""MHA / MQA / GQA 通用版本 (num_kv_heads ≤ num_heads)."""
def __init__(self, d_model, num_heads, num_kv_heads):
super().__init__()
assert num_heads % num_kv_heads == 0, "H 必须能被 H_kv 整除"
self.H, self.H_kv = num_heads, num_kv_heads
self.d = d_model // num_heads
self.group = num_heads // num_kv_heads # 每个 KV head 服务的 Q head 数
self.W_q = nn.Linear(d_model, num_heads * self.d, bias=False)
self.W_k = nn.Linear(d_model, num_kv_heads * self.d, bias=False) # ← 小了
self.W_v = nn.Linear(d_model, num_kv_heads * self.d, bias=False) # ← 小了
self.W_o = nn.Linear(num_heads * self.d, d_model, bias=False)
def forward(self, x, cache=None):
B, L_new, _ = x.shape
q = self.W_q(x).view(B, L_new, self.H, self.d).transpose(1, 2)
k = self.W_k(x).view(B, L_new, self.H_kv, self.d).transpose(1, 2)
v = self.W_v(x).view(B, L_new, self.H_kv, self.d).transpose(1, 2)
if cache is not None:
k = torch.cat([cache['k'], k], dim=2)
v = torch.cat([cache['v'], v], dim=2)
new_cache = {'k': k, 'v': v}
# ── 关键:把 K/V broadcast 到 Q heads ────────────────────────
k = k.repeat_interleave(self.group, dim=1) # [B, H, L_total, d]
v = v.repeat_interleave(self.group, dim=1)
# repeat_interleave 是显式 broadcast;生产用 torch 的隐式 broadcast 或 fused kernel
scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d)
L_total = k.size(2)
if L_new > 1:
causal = torch.tril(torch.ones(L_new, L_total, dtype=torch.bool,
device=x.device), diagonal=L_total - L_new)
scores = scores.masked_fill(~causal, float('-inf'))
w = F.softmax(scores, dim=-1)
out = (w @ v).transpose(1, 2).contiguous().view(B, L_new, -1)
return self.W_o(out), new_cache
class MLAAttention(nn.Module):
"""MLA with decoupled RoPE (DeepSeek-V2 简化版)."""
def __init__(self, d_model, num_heads, d_head, d_c, d_r):
super().__init__()
self.H = num_heads
self.d_head = d_head # content head dim
self.d_c = d_c # latent dim (KV 共享)
self.d_r = d_r # RoPE head dim (per query head)
# KV 端:先压成 latent,再升回 K_C, V
self.W_DKV = nn.Linear(d_model, d_c, bias=False)
self.W_UK = nn.Linear(d_c, num_heads * d_head, bias=False)
self.W_UV = nn.Linear(d_c, num_heads * d_head, bias=False)
# RoPE 通道:K 共享 (head 间共享)
self.W_KR = nn.Linear(d_model, d_r, bias=False)
# Q 端:先压成 latent (省训练显存),再升回 Q_C 和 Q_R
d_c_q = 4 * d_head # 论文里 d_c_q ≠ d_c
self.W_DQ = nn.Linear(d_model, d_c_q, bias=False)
self.W_UQ = nn.Linear(d_c_q, num_heads * (d_head + d_r), bias=False)
self.W_o = nn.Linear(num_heads * d_head, d_model, bias=False)
def _rope(self, x, positions, l_axis=-2, base=10000.0):
"""标准 RoPE 2D 旋转。L 在 l_axis 维,d_r 在最后一维;positions: [L]。"""
assert x.size(-1) % 2 == 0
D, L = x.size(-1), x.size(l_axis)
assert positions.shape == (L,)
i = torch.arange(D // 2, device=x.device, dtype=x.dtype)
theta = base ** (-2.0 * i / D) # [D/2]
angles = positions.to(x.dtype).unsqueeze(-1) * theta # [L, D/2]
cos, sin = angles.cos(), angles.sin()
ndim = x.dim()
cos_shape = [1] * ndim
cos_shape[l_axis % ndim] = L
cos_shape[-1] = D // 2
cos_b, sin_b = cos.view(cos_shape), sin.view(cos_shape)
x2 = x.view(*x.shape[:-1], D // 2, 2)
x_real, x_imag = x2.unbind(-1)
rot_real = x_real * cos_b - x_imag * sin_b
rot_imag = x_real * sin_b + x_imag * cos_b
return torch.stack([rot_real, rot_imag], dim=-1).flatten(-2)
def forward(self, x, positions, cache=None):
B, L_new, _ = x.shape
# ── KV 端:算 latent + RoPE-K ────────────────────────────────
c_kv = self.W_DKV(x) # [B, L_new, d_c]
k_r = self._rope(self.W_KR(x), positions) # [B, L_new, d_r] (head 间共享)
if cache is not None:
c_kv = torch.cat([cache['c_kv'], c_kv], dim=1)
k_r = torch.cat([cache['k_r'], k_r], dim=1)
new_cache = {'c_kv': c_kv, 'k_r': k_r} # ← 只缓存 latent+RoPE,不缓存完整 K/V
# 升回 K_C, V (这步推理时可被 absorb 进 Q/O 投影;演示先朴素算)
k_c = self.W_UK(c_kv).view(B, -1, self.H, self.d_head).transpose(1,2) # [B,H,L_tot,d_head]
v = self.W_UV(c_kv).view(B, -1, self.H, self.d_head).transpose(1,2)
# ── Q 端:拆 content 和 RoPE 两段 ────────────────────────────
c_q = self.W_DQ(x) # [B, L_new, d_c_q]
q_full = self.W_UQ(c_q).view(B, L_new, self.H, self.d_head + self.d_r)
q_c, q_r_raw = q_full.split([self.d_head, self.d_r], dim=-1)
q_c = q_c.transpose(1, 2) # [B, H, L_new, d_head]
q_r = self._rope(q_r_raw, positions, l_axis=1).transpose(1, 2) # [B, H, L_new, d_r]
# ── attention score 两部分相加 ───────────────────────────────
L_tot = c_kv.size(1)
scale = math.sqrt(self.d_head + self.d_r)
scores_c = q_c @ k_c.transpose(-2, -1) # content 部分 [B, H, L_new, L_tot]
# k_r 是所有 head 共享的,broadcast:
scores_r = q_r @ k_r.transpose(-2, -1).unsqueeze(1) # [B, H, L_new, L_tot]
scores = (scores_c + scores_r) / scale
if L_new > 1:
causal = torch.tril(torch.ones(L_new, L_tot, dtype=torch.bool,
device=x.device), diagonal=L_tot - L_new)
scores = scores.masked_fill(~causal, float('-inf'))
w = F.softmax(scores, dim=-1)
out = (w @ v).transpose(1, 2).contiguous().view(B, L_new, -1)
return self.W_o(out), new_cache
上面 MLA 实现"朴素地"在每 step 升回完整 K/V,并没真省 compute;生产实现要做 §4.3.3 的 absorb trick——在加载模型时把 $W^{UQ\top} W^{UK}$ 和 $W^{UV} W^O$ 预乘一次,inference 时直接对 latent 算分。
§7 Speculative Decoding 核心机制
7.1 设定
- Target model $p$:要加速的大模型(如 LLaMA-70B),$p(x_{t+1} | x_{\le t})$ 是其每步条件分布
- Draft model $q$:小很多的模型(如 LLaMA-7B 或 EAGLE 的特征头),$q(x_{t+1} | x_{\le t})$ 是其条件分布
- 目标:让最终输出分布恰好等于 $p$(exact),不是近似
每个 spec step:
- 用 $q$ 自回归地 draft $K$ 个 token: $\tilde x_1, \tilde x_2, \dots, \tilde x_K$
- 把 prefix + $K$ 个 draft 一次性送进 $p$,并行算出 $p(x_{t+i} | x_{\le t+i-1})$ for $i=1..K$(外加 $i=K+1$ 的 logits 做 bonus token)
- 对每个 draft 位置做 rejection sampling:以概率 $\min(1, p(\tilde x_i) / q(\tilde x_i))$ 接受
- 第一个被拒的位置 $j$:从修正后的残差分布 $p'$ 重新采一个;位置 $j+1, \dots, K$ 全丢
- 如果全部接受($j = K+1$),还能从 target 的最后一组 logits 免费采一个 bonus token
7.2 接受概率 $\alpha$ 的推导(必考)
Rejection sampling 使整个 spec step 的输出分布与从 $p$ 单步采样完全等价。
推导:设我们在某一位置,draft 给出 $\tilde x \sim q(\cdot)$。
接受规则:以概率 $r(\tilde x) = \min(1, p(\tilde x)/q(\tilde x))$ 接受。
被接受的 token 出现概率:
$$\Pr[\text{accept} \land X = x] = q(x) \cdot r(x) = q(x) \cdot \min\!\left(1, \frac{p(x)}{q(x)}\right) = \min(q(x), p(x))$$
被拒概率:
$$\beta = 1 - \alpha = \sum_x q(x) - \sum_x \min(q(x), p(x)) = \sum_x \max(0, q(x) - p(x))$$
整体接受率:
$$\boxed{\;\alpha = \sum_x \min(q(x), p(x)) = 1 - \tfrac{1}{2}\|p - q\|_1\;}$$
最后一步用恒等式 $\sum_x \min(p, q) = 1 - \tfrac{1}{2} \sum_x |p - q|$(注意 $\sum_x p = \sum_x q = 1$)。
推论:$p$ 和 $q$ 越接近,TV distance 越小,$\alpha$ 越接近 1。
7.3 残差分布 $p'$(接受失败时怎么采)
被拒后,我们要从 $p$ "排除掉被接受 mass" 的剩余采一个新 token:
$$p'(x) = \frac{\max(0, p(x) - q(x))}{\sum_x \max(0, p(x) - q(x))} = \frac{\max(0, p(x) - q(x))}{1 - \alpha}$$
等价性证明(关键,面试要会推):考察某 token $x$ 在一个 spec step 内被输出的总概率:
$$\Pr[X = x] = \underbrace{q(x) \min(1, p(x)/q(x))}_\text{accept path} + \underbrace{(1-\alpha) \cdot p'(x)}_\text{reject path}$$
- 第一项 $= \min(p(x), q(x))$
- 第二项 $= (1-\alpha) \cdot \dfrac{\max(0, p(x) - q(x))}{1-\alpha} = \max(0, p(x) - q(x))$
加和:$\min(p, q) + \max(0, p - q) = p$. ✅
所以每个位置的输出分布严格等于 $p$——这是 spec decoding "exact" 的数学根据。
7.4 期望加速:$E[\tau]$ 公式
设每个 draft 位置接受概率独立同分布(实际略相关,但论文常用此近似)。$K$ 个 draft + 1 个 bonus:
- 若前 $j$ 个全接受、第 $j+1$ 个拒($j < K$):输出 $j$ 个接受 + 1 个重采 = $j+1$ 个 token
- 若全部 $K$ 个接受:输出 $K$ 个 + 1 个 bonus = $K+1$ 个 token
期望 token 数:
$$E[\tau] = \sum_{j=0}^{K-1} \alpha^j (1-\alpha) (j+1) + \alpha^K (K+1)$$
化简(几何级数标准技巧):
$$\boxed{\;E[\tau] = \frac{1 - \alpha^{K+1}}{1 - \alpha}\;}$$
极限分析:
- $\alpha \to 1$(draft 完美):$E[\tau] \to K+1$,加速 $K+1$ 倍
- $\alpha \to 0$(draft 全错):$E[\tau] \to 1$,无加速但也没倒退
- 实际 LLaMA-7B draft LLaMA-70B:$\alpha \approx 0.6-0.7$,$K=4$ 时 $E[\tau] \approx 2.7$
还要扣 draft 模型自己的 forward 开销。设 $c = T_q / T_p$(draft 单 step 时间 / target 单 step 时间,典型 0.05-0.15):
$$\text{speedup} = \frac{E[\tau]}{1 + Kc}$$
分子是平均接受 token 数;分母 1 是一次 target verify,$Kc$ 是 $K$ 次 draft forward。$c$ 太大(draft 太大)会吃掉收益,所以 draft 选小很重要。
7.5 Temperature / Top-p 下的 sampling 等价性
Rejection sampling 等价性只要求两件事:(1) target 端用经 sampler 处理后的分布 $\tilde p$ 替换 $p$ 来算接受率和残差;(2) draft proposal 分布 $\tilde q$ 是任何合法概率分布即可。数学上不强制 draft 用与 target 相同的 sampler——draft 完全 greedy 也合法,只是 $\tilde q$ 跟 $\tilde p$ 偏离大、$\alpha$ 暴跌。实践常把 draft 也用同一组 temperature/top-p 让 $\tilde q$ 贴近 $\tilde p$。
"draft 用同样 sampling 后只比 token 一致" 不对——这样会丢分布等价性。正确做法是按 §7.3 的 rejection 公式,比较概率而不是 token 一致。
7.6 代码:speculative decoding loop
下面给出 single-batch 演示版。约定:两个模型暴露 forward(input_ids, cache);cache 对象有 length 属性 + truncate(L) 方法(生产里 PagedAttention 用 block table 改指针实现 $O(1)$ rollback)。核心不变式:每轮迭代开始时,cache.length == seq.size(1) - 1(即 cache 里有除最后 1 个 token 之外的所有 prefix)。
import torch
@torch.no_grad()
def speculative_decode(target, draft, prompt_ids, max_new_tokens, K=4, temperature=1.0):
"""
target, draft: callable(input_ids, cache) → (logits [1,L_new,V], new_cache)。
数学上 exact 等价于直接从 target 采样(Leviathan/Chen 2023)。
"""
seq = prompt_ids.clone() # [1, L_prompt]
L_prompt = seq.size(1)
# Prefill 前 L-1 个 token;最后 1 个 token 留作首轮 verify input。
_, target_cache = target(seq[:, :-1], cache=None)
_, draft_cache = draft(seq[:, :-1], cache=None)
while seq.size(1) - L_prompt < max_new_tokens:
last_tok = seq[:, -1:] # [1, 1],尚未入 cache
draft_chk, target_chk = draft_cache.length, target_cache.length
# ── 1. Draft:依次喂 last_tok, d_1, ..., d_{K-1};采样 d_1..d_K ──
cur = last_tok
draft_tokens, draft_probs = [], []
for _ in range(K):
logits, draft_cache = draft(cur, cache=draft_cache)
probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
tok = torch.multinomial(probs, 1)
draft_tokens.append(tok); draft_probs.append(probs)
cur = tok
draft_tokens = torch.cat(draft_tokens, dim=1) # [1, K]
# ── 2. Target:一次 forward 看 [last_tok, d_1..d_K],输出 K+1 个分布 ──
target_in = torch.cat([last_tok, draft_tokens], dim=1) # [1, K+1]
target_logits, target_cache = target(target_in, cache=target_cache)
# target_logits[:, i, :] 用来 verify d_{i+1}(i<K)或 bonus 采样(i=K)
# ── 3. Rejection sampling 逐位置 ──
accepted = 0; rejected = False
for i in range(K):
p_i = torch.softmax(target_logits[:, i, :] / temperature, dim=-1)
q_i = draft_probs[i]
x = draft_tokens[:, i:i+1]
p_x = p_i.gather(-1, x).squeeze(-1)
q_x = q_i.gather(-1, x).squeeze(-1)
r = (p_x / (q_x + 1e-9)).clamp(max=1.0)
if torch.rand_like(r).item() < r.item(): # accept
accepted += 1
else: # reject
p_prime = (p_i - q_i).clamp(min=0.0) # 残差分布 p'
p_prime = p_prime / (p_prime.sum(-1, keepdim=True) + 1e-9)
new_tok = torch.multinomial(p_prime, 1)
seq = torch.cat([seq, draft_tokens[:, :accepted], new_tok], dim=1)
# rollback:保留 last_tok + accepted 个 draft;new_tok 暂不入 cache
draft_cache.truncate(draft_chk + 1 + accepted)
target_cache.truncate(target_chk + 1 + accepted)
rejected = True
break
if not rejected: # 全接受 → bonus token
p_bonus = torch.softmax(target_logits[:, K, :] / temperature, dim=-1)
bonus = torch.multinomial(p_bonus, 1)
seq = torch.cat([seq, draft_tokens, bonus], dim=1)
# draft_cache 之前只见过 d_1..d_{K-1};补喂 d_K 保持不变式
_, draft_cache = draft(draft_tokens[:, -1:], cache=draft_cache)
# target_cache 已含 last_tok + d_1..d_K,对应不变式
return seq[:, :L_prompt + max_new_tokens] # overshoot 截断
(1) Cache rollback 必须真正回退 KV cache 写入位置;vLLM 用 PagedAttention 的 block table 改指针实现 $O(1)$。(2) 数值稳定:$r=p/q$ 在 $q \to 0$ 时爆炸,用 fp32 计算并 clamp。(3) 不变式:每轮开始时 cache 含 seq 除最后 1 个 token 之外的所有 prefix——所有 rollback / 补喂逻辑都为维护这一条。
§8 Speculative Decoding 主要变体
8.1 变体总览
| 方法 | Draft 来源 | 多 token 结构 | 是否需 draft 模型训练 | 代表论文 |
|---|---|---|---|---|
| Vanilla SD | 独立小模型 | 线性 chain | 否(用现成小 LM) | Leviathan 2023, Chen 2023 |
| SpecInfer | 多个 draft 一起 | 静态 tree | 否 | Miao 2024 (ASPLOS) |
| Medusa | 在 target 上加 N 个 head | 静态 tree | 是(finetune heads) | Cai 2024 (ICML) |
| EAGLE-1 | feature-level autoregression + 小 model | tree | 是(draft 头小) | Li 2024 (ICML) |
| EAGLE-2 | 同 EAGLE-1 | dynamic tree | 是 | Li 2024 |
| EAGLE-3 | 多层 feature fusion + training-time test | 动态 tree | 是 | Li 2025 |
| Hydra | sequential draft heads | 静态 tree | 是 | Ankner 2024 |
| Lookahead Decoding | Jacobi 迭代 + n-gram pool | 自己 verify | 否 | Fu 2024 (ICML) |
| REST | retrieval (datastore) | 静态 tree | 否 | He 2024 (NAACL) |
| Self-Speculative | target 自己跳层 | 线性 | 否(用 target 部分层) | Zhang 2024 |
| TriForce | 分层(small LM + sparse target) | hierarchical | 否 | Sun 2024 |
| MagicDec | small draft + sparse KV | 线性 | 否 | Sadhukhan 2024 |
8.2 Medusa:多头并行
Medusa(Cai et al., ICML 2024)的核心:在 target 模型的最后 hidden state 上加 $N$ 个并行的 prediction heads,第 $k$ 个 head 直接预测"未来第 $k+1$ 个 token"。
- 不需要单独 draft 模型;只 finetune 这几个 head(参数量小)
- $N$ 个 head 的 top-$K$ 候选组合成一棵 静态 tree(如每 head 取 top-5,总 $5^N$ 条路径但用 typical acceptance 剪枝)
- Tree attention:把 tree 里所有节点拍平输入 target 一次 forward,每节点的 attention mask 只看其祖先(causal in the tree)
把 tree 节点按 BFS 排成线性序列 $[t_0, t_1, \dots, t_M]$,节点 $i$ 的 ancestor 集合 $\mathcal A(i)$(含自己),mask $M[i, j] = 1$ iff $j \in \mathcal A(i)$。这样每个节点只看到从 root 到自己的路径,logits 正确。
- Verification 默认用 typical acceptance(论文提出的方案):根据 target 分布的 typical set 阈值接受 draft token;这个规则不严格保证 exact sampling,但实践中质量基本不掉。若需要 exact,可改用标准 rejection sampling(Leviathan/Chen 公式)。
- Medusa-1 vs Medusa-2 是训练范式区分:Medusa-1 冻结 backbone 只训 head;Medusa-2 联合训 backbone + head 拿更高质量;二者都默认 typical acceptance。
8.3 EAGLE 系列:特征级 autoregression
EAGLE(Li et al., ICML 2024)核心洞察:target 上一层的 hidden feature $h_{t-1}$ 比 token 包含更多信息——draft 阶段在 feature space 做 autoregression 比在 token space 更准。
- Draft 模型:一层 transformer,输入 $h_{t-1}$(target 的特征)+ token embedding,预测 $h_t$ 和 $x_t$
- 训练目标:让 draft 的 $h_t$ 逼近 target 的 $h_t$(feature loss)+ token prediction loss
EAGLE-2(Li 2024):把静态 tree 换成 dynamic tree——每步用 draft 给出的概率挑当前最 promising 的几条路径展开。
EAGLE-3(Li et al., 2025):
- 抛弃 feature regression,直接 token prediction(feature 误差累积是瓶颈)
- 用 多层 feature fusion 而非只用 top layer
- Training-Time Test(TTT):训练时模拟 inference 时的 draft chain 误差,避免 train-test gap
- 在 Vicuna-7B 上 4-5× 加速,比 EAGLE-2 进一步提升 ~30%
虽然 draft 模型很小(1 层 transformer,参数量约几十 M),但要在 target 模型的特征上重新训练,需要 target 模型的 forward pass dataset(蒸馏式)。一次 EAGLE-3 训练几小时到几天,不是免费的。
8.4 Lookahead Decoding:Jacobi 迭代
Jacobi 视角(Fu et al., ICML 2024):自回归生成等价于解非线性方程组 $x_i = f(x_{
Step 0: x = [<random>, <random>, ..., <random>]
Step 1: x'_i = f(x_{<i}) ∀ i 并行
Step 2: x = x',再来一轮
... 直到 fixed point
Lookahead Decoding:
- Lookahead branch:维护 2D window(lookahead size × window depth),每步 Jacobi-style 并行更新整个 window
- 从 trajectory 里抽取看起来稳定的 n-gram 存入 pool
- Verification branch:每个 forward 里同时 verify pool 里 promising 的 n-gram(多 path tree attention)
- 不需要 draft 模型;不需要训练;纯 inference-time trick
效果:MT-bench 1.8×,code completion 多卡 4× 加速。但对没有重复模式的开放生成(如复杂 reasoning)加速有限。
8.5 长 context 专属:TriForce / MagicDec
当 context 长(如 128K),target 模型每 forward 都要扫整条 KV cache,HBM 带宽 才是瓶颈,而不是 weight 加载。普通 SD 节省的是 weight loading 频率,到这种 regime 收益变小。
TriForce(Sun et al., 2024, arxiv 2404.11912):三层 hierarchy。
- 第一层 draft:小 LM
- 第二层 draft:target 模型 + sparse KV cache(只保留 heavy-hitter / retrieved 部分)
- 第三层 verify:target 模型 + 完整 KV cache
加速核心:第二层 draft 用 sparse cache 走得快,再让完整 cache 的 target 一次 verify 一长串。
MagicDec(Sadhukhan et al., 2024, arxiv 2408.11049, ICLR 2025):观察到长 context 下 KV cache 才是瓶颈,所以 draft 用 StreamingLLM 风格 sparse cache(attention sink + sliding window),target 用 full cache verify。
8.6 Self-Speculative Decoding
Zhang et al. 2024 提出 "self-speculative":用 target 模型本身跳层(skip a subset of layers) 作为 draft。
- 在 target 第 $L_d < L$ 层取 logits 作为 draft(早期 exit)
- 用完整 target 做 verify
- 完全 backwards-compatible:不需要训新模型,不需要 finetune
- 加速一般 1.5-2×(不如 EAGLE,但零额外训练)
§9 复杂度 / 资源核算
9.1 KV cache 显存(汇总)
以 LLaMA-2/3-70B 架构($N_\text{layers}=80, d_\text{head}=128$, MHA $H=64$)为基准,$L_\text{ctx}=4096$, fp16:
| 方案 | per-token-per-layer bytes | 70B / 4K context (全模型) |
|---|---|---|
| MHA ($H_\text{kv}=64$) | $2 \cdot 2 \cdot 64 \cdot 128 = 32768$ | 10.0 GB |
| GQA ($H_\text{kv}=8$) | $2 \cdot 2 \cdot 8 \cdot 128 = 4096$ | 1.25 GB |
| MQA ($H_\text{kv}=1$) | $2 \cdot 2 \cdot 1 \cdot 128 = 512$ | 0.16 GB |
| MLA ($d_c=512, d_r=64$, 60 layers) | $2 \cdot (512+64) = 1152$ | ~0.27 GB |
| MHA + INT4 KV | $32768 / 4 = 8192$ | 2.5 GB |
9.2 Speculative decoding 期望吞吐
$$\text{tokens / sec}_\text{SD} = \frac{\text{tokens / sec}_\text{baseline} \cdot E[\tau]}{1 + Kc}$$
经验数字(A100, LLaMA-2-7B target + 68M draft):
- vanilla SD: 1.6-2.2×
- Medusa-2: 2.5-3.0×
- EAGLE-2: 3.0-3.5×
- EAGLE-3: ~4.0×(短 context)
长 context(128K+)regime:vanilla SD 加速降到 1.1-1.3×;TriForce / MagicDec 仍能保 2-2.5×。
9.3 预算粗算(70B + 4K context + GQA)
| 项 | 显存 |
|---|---|
| weights (fp16) | 140 GB |
| KV cache @ batch 8 | $8 \times 1.25$ GB $= 10$ GB |
| activation peak (decode, batch 8) | ~2 GB |
| 总 | ~152 GB → 2×80GB A100 紧;通常 4×80GB |
如果用 vanilla SD,draft model(7B fp16)+14 GB;EAGLE 的 draft head 只 ~200 MB。
§10 25 高频面试题
按难度分 L1(必会)/ L2(进阶)/ L3(顶级 lab)。所有题点开看答案 + 易踩坑。
L1 必会题(任何 inference / serving 岗位都会问)
Q1.KV cache 公式是什么?
- 单 sample:$2 \cdot L_\text{ctx} \cdot N_\text{layers} \cdot N_\text{kv\_heads} \cdot d_\text{head} \cdot \text{bytes}$
- "2" 来自 K + V
- $N_\text{kv\_heads}$:MHA = $H$, MQA = 1, GQA = $G$
- LLaMA-3-70B(GQA, $H_\text{kv}=8$)@4K fp16 ≈ 1.25 GB/sample
写 $H$(Q heads);忘乘 2;忘 $L_\text{ctx}$ 是当前长度不是 max length。
Q2.为什么训练时不用 KV cache?
- 训练时所有位置同时算(teacher forcing,已知 ground truth)
- 没有"先有部分序列、再 append 新位置"这个时序
- KV cache 是推理专属优化
把 KV cache 当作通用优化用在训练。
Q3.Prefill 和 decode 阶段的瓶颈分别是什么?
- Prefill:$O(L^2)$ attention FLOPs,compute-bound
- Decode:每步只算 1 个 token,但要读完整 cache + weights,memory-bandwidth-bound
- Arithmetic intensity 极低 → GPU FLOPs 利用率往往 < 10%
说"decode 也是 compute-bound"——错。Decode batch 太小时 GPU 大部分时间在等内存。
Q4.MQA / GQA / MHA 区别?
- MHA:$H$ 个 K/V head(同 Q)
- MQA:所有 Q head 共享 1 组 K/V
- GQA:$H$ 个 Q head 分 $G$ 组($1
- 主要省的是 KV cache 显存 + 显存带宽,不省 Q projection
以为 MQA 省的是 Q 计算;说 GQA 质量"基本不掉"过于绝对。
Q5.Speculative decoding 公式?
- Draft $q$ 提议 $K$ 个 token,target $p$ 一次 forward verify
- 每位置接受概率 $r = \min(1, p(\tilde x)/q(\tilde x))$
- 整体接受率 $\alpha = \sum_x \min(p(x), q(x))$
- 期望生成 $E[\tau] = \dfrac{1 - \alpha^{K+1}}{1 - \alpha}$
说 "spec decoding 是近似采样"——错,是 exact(rejection sampling 保证)。
Q6.PagedAttention 解决什么?
- 朴素 KV cache 必须连续大 tensor,预分配 max length → internal fragmentation
- 不同 request 长度不一 → external fragmentation
- 显存利用率仅 ~70%
- PagedAttention:切 page + block table,利用率提升到 ~96%
- 支持 prefix sharing(COW)
说 PagedAttention 减少 attention FLOPs——错,FLOPs 不变;它优化的是显存利用率 + 并发请求数。
Q7.Continuous batching 是什么?
- 调度粒度从 request 改成 iteration(每 forward 都重检查 batch)
- 完成的请求立即踢出,腾出 slot 给新请求
- 提出:Orca (Yu et al., OSDI 2022)
- 缩短平均等待时间,提高 GPU 利用率
- vLLM = Orca continuous batching + PagedAttention
以为 continuous batching 是把不同长度 sequence 都 pad 到最长——那是 static batching 的老做法。
Q8.Draft 模型怎么选?
- 大小:典型 target / 30 - target / 10(如 70B target + 7B draft)
- 同 tokenizer、同 vocab(否则 rejection sampling 算不出 $p/q$)
- 同 prompt format / 同 RLHF 后训练(否则 distribution gap 大,$\alpha$ 低)
- 经验:$\alpha \in [0.5, 0.8]$,太低就别上 SD
选 draft 太大(如 target / 3);或选不同 tokenizer。
Q9.KV cache 量化最常见做法?
- FP8(H100 原生支持)几乎无损
- INT8 per-token quant 也可以接受
- INT4 / INT2(KIVI, KVQuant)需要精细 outlier 处理
- KIVI 的关键:K per-channel, V per-token 不对称量化
把 K 和 V 用同一个 quant 方案——容易掉点;K 和 V 的 outlier 分布不一样。
Q10.Prefix caching 是什么?
- 多个请求共享同一段 prompt 前缀(system prompt、few-shot)
- 用 hash(prefix) 索引 page 池,命中跳过 prefill
- 配合 COW 处理后续分叉
- ChatGPT 这种 system-prompt heavy 服务命中率 90%+
以为 prefix caching = prompt 全部缓存——只 cache prefix;用户特定部分还要 prefill。
L2 进阶题(research-oriented / inference 系统岗)
Q11.推 spec decoding 的接受概率 $\alpha$,并解释它为什么保证 exact sampling。
- 设 draft $\tilde x \sim q$,接受规则 $r = \min(1, p/q)$
- $\Pr[\text{accept} \land X=x] = q(x) \cdot \min(1, p(x)/q(x)) = \min(p(x), q(x))$
- $\alpha = \sum_x \min(p, q) = 1 - \tfrac{1}{2} \|p - q\|_1$
- 被拒后从残差 $p'(x) = \max(0, p-q) / (1-\alpha)$ 重采
- 总输出概率 $= \min(p, q) + \max(0, p-q) = p(x)$ ∀x
- 所以每位置等价于直接从 $p$ 单 step 采样
只写 accept 部分,漏 reject 残差分布;忘证 $\min + \max = p$;说 spec 是近似。
Q12.MLA 为什么必须 decoupled RoPE?详细推导。
- 朴素 MLA absorb trick:$q^\top k = c_q^\top (W^{UQ\top} W^{UK}) c_{kv}$,中间是常数矩阵 $\tilde W^{QK}$,预乘即可
- 加 RoPE 后:$q^{R\top} k^R = c_q^\top W^{UQ\top} R_{s-t} W^{UK} c_{kv}$
- 中间块 $W^{UQ\top} R_{s-t} W^{UK}$ 依赖相对位置 $(s-t)$,不能预乘
- absorb 失效 → cache 是省了,compute 退化回 MHA
- 解:拆出独立 RoPE 通道 $k^R \in \mathbb{R}^{d_r}$(所有 head 共享),content 通道走 absorb,RoPE 通道走标准 dot product
- 总 cache:$d_c + d_r$ per token
只说"加 RoPE 出问题"不展开;不知道 RoPE 的 $R_t^\top R_s = R_{s-t}$ 性质;不知道 $k^R$ 是所有 head 共享。
Q13.Continuous batching 在 prefill + decode 混跑时怎么处理?
- Prefill 一次性算长段,FLOPs 大;decode 单 token,FLOPs 小
- 直接混 batch 会让 decode 等 prefill 长时间(HOL blocking)
- Sarathi-Serve 的 chunked prefill:把长 prefill 切等大小 chunk
- 每 iteration coalesce 一个 prefill chunk + 多个 decode token
- stall-free schedule:保证 decode 永远跟着跑
以为 prefill 必须一次跑完;忘记 Sarathi-Serve 是 OSDI 2024。
Q14.Tree attention(Medusa / EAGLE 用)的 mask 怎么写?
- 把 tree 节点按 BFS 拍平成线性序列 $[t_0, \dots, t_M]$
- $\mathcal A(i)$ = 节点 $i$ 的祖先(含自身)
- attention mask $M[i, j] = 1 \iff j \in \mathcal A(i)$
- 即"causal 在树上的推广"
- 用于一次 forward 同时 verify tree 里所有路径
写成 lower-triangular causal mask(只适用 chain,不适用 tree);忘记把 mask 的形状从 $[L,L]$ 推广。
Q15.spec decoding 的实际加速公式?为什么 draft 太大会反效果?
- $\text{speedup} = E[\tau] / (1 + Kc)$,$c = T_q/T_p$
- $E[\tau] = (1-\alpha^{K+1})/(1-\alpha)$
- 分母里 $Kc$ 是 $K$ 次 draft forward 的开销
- 若 $c$ 太大(draft 太大),即使 $\alpha$ 高也会被分母吃光
- 极端:$c=1$ 时 speedup ≤ 1(draft 跟 target 一样慢)
只写 $E[\tau]$ 不算 draft 开销;漏 bonus token 那一项。
Q16.Self-speculative decoding 和普通 spec decoding 区别?
- Self-spec:draft 是 target 自己跳层 / 早 exit
- 不需要独立 draft model,零额外训练
- 但 draft 与 target 高度相关,$\alpha$ 通常较高
- 加速一般 1.5-2×(不如 EAGLE 但更省事)
- 论文:Zhang et al. 2024("Draft & Verify")
说必须有额外训练;混 self-spec 和 layer skipping inference(后者不是 exact)。
Q17.KV cache eviction 和 sparse attention 怎么影响 spec decoding?
- 长 context 下 KV cache 才是带宽瓶颈,weights 已经被 prefill 摊销
- 这时 draft 用 sparse / sliding window KV(StreamingLLM 风格)能跑得快
- target 用完整 cache 做 verify 保证 exact
- 代表:MagicDec、TriForce(hierarchical:小 draft → sparse target → full target)
- 收益:长 context 下 vanilla SD 失效(1.1×),MagicDec 能保 2×+
把 sparse KV 当成 lossy 近似(实际只用于 draft,verify 时全 cache 仍 exact)。
Q18.Medusa 用 typical acceptance 替代 rejection sampling,损失了什么?
- 严格意义上丢了 exact sampling——不再保证输出分布等于 target
- 但 typical acceptance 用 target 自身的 typical set 阈值约束,质量基本不掉(论文实测和 base 模型 score 接近)
- 如要严格 exact,可把 verification 换成标准 rejection sampling(Leviathan/Chen 公式)
- Medusa-1 vs Medusa-2 区分点是训练范式:Medusa-1 冻 backbone 只训 head;Medusa-2 联合训 backbone + head;二者默认都用 typical acceptance
把 Medusa-1 / Medusa-2 的区别说成 "exact vs 非 exact"(错——它们是训练范式不同);说 Medusa 完全等价于 target sampling。
Q19.EAGLE 和 Medusa 的核心差异?
- Medusa:多 head 直接预测未来 token,independent (不 autoregressive)
- EAGLE:draft 在 feature space autoregressive(前一步 hidden + 前 token → 下一步 hidden + token)
- EAGLE 更准(feature 信息丰富),但需要训 draft(含 transformer 层)
- EAGLE-3 进一步抛 feature regression,直接 token + 多层 fusion + training-time test
- 实测 EAGLE > Medusa 接受率,但 Medusa 部署更简单(参数更少)
把 EAGLE 当 Medusa 的小改进;说"EAGLE 也是多 head"——错,EAGLE 是 1 个 mini-transformer。
Q20.PagedAttention 和 FlashAttention 关系?
- FlashAttention:attention kernel 内部 SRAM tiling + online softmax,单 kernel 内优化(避免 materialize $L^2$ scores)
- PagedAttention:把 KV cache 切 page,按 page table 间接寻址;memory layout 优化
- 二者正交,可以叠加:vLLM 用 paged + flash 思路写 paged attention kernel
- 区分点:FlashAttention 减 HBM IO;PagedAttention 减显存碎片
混淆二者;以为 PagedAttention 是 attention 算法变体(实际只是内存管理 + 配套 kernel)。
L3 顶级 lab 题(最严苛级别)
Q21.推 spec decoding 的 acceptance $\alpha$ 完整证明,并解释 sampling 等价性如何推广到 temperature / top-p。
- 单 token:$\Pr[X=x] = q(x) \min(1, p(x)/q(x)) + (1-\alpha) p'(x)$,代入 $p'$ 得 $\min(p,q) + \max(0, p-q) = p$
- $\alpha = \sum_x \min(p, q) = 1 - \tfrac{1}{2}\|p-q\|_1$
- 等价于 TV distance 的连接公式
- 关键原则:rejection sampling 的等价性只依赖于 "draft proposal 分布 $\tilde q$" 和 "target 目标分布 $\tilde p$" 各自有效。只要把 $p, q$ 在公式里替换成 sampler 处理后的 $\tilde p, \tilde q$,整套等价性照旧
- Temperature $T$:常见做法是 $\tilde p_T(x) \propto p(x)^{1/T}$ 和 $\tilde q_T(x) \propto q(x)^{1/T}$;把它们代进 $\alpha, p'$ 公式即可
- Top-p:把 $p$ truncate + renorm 到 $p$ 自己的 top-p 集合得到 $\tilde p$,draft proposal 分布 $\tilde q$ 是 draft 实际采样的那个分布;只要二者都是合法分布,rejection 都 exact
- 实践中 draft 用与 target 相同的 sampler 是惯例(让 $\tilde q$ 接近 $\tilde p$ 提高 $\alpha$),但不是数学必需——draft 完全 greedy 也合法,只是 $\alpha$ 会暴跌
- 多 token:每位置 $\alpha_i$ 用对应的 $\tilde p_i, \tilde q_i$;bonus token 用 $K+1$ 位置的修正 logits(经 sampler 处理后)直接采
只写单 token 等价;把"draft 必须用同 sampler"误说成数学必需(实际只是高 $\alpha$ 的策略);忽略 bonus token。
Q22.MLA 的 absorb trick 完整数学推导:为什么 inference 时不用还原 K/V?
- KV cache:$c_t^{KV} = W^{DKV} h_t \in \mathbb{R}^{d_c}$
- K, V 升投影:$k_t^{(i)} = W^{UK,(i)} c_t^{KV}, v_t^{(i)} = W^{UV,(i)} c_t^{KV}$
- Q 同理:$q_t^{(i)} = W^{UQ,(i)} c_t^Q$
- attention 分数(无 RoPE):$(q_t^{(i)})^\top k_s^{(i)} = (c_t^Q)^\top \underbrace{W^{UQ,(i)\top} W^{UK,(i)}}_{\tilde W^{QK,(i)}} c_s^{KV}$
- $\tilde W^{QK,(i)}$ 形状 $d_c' \times d_c$,与 (t, s) 无关,加载模型时预乘
- inference 时直接 $(c_t^Q)^\top \tilde W^{QK,(i)} c_s^{KV}$,完全不算 $k_s$
- 类似地 attention output:$\text{out}^{(i)} = \sum_s w_s v_s^{(i)} = (\sum_s w_s c_s^{KV})^\top W^{UV,(i)\top}$
- 把 $W^{UV,(i)}$ 吸进 $W^O$:$W^O_\text{absorbed} = W^O (\text{blockdiag}(W^{UV,(i)}))$
- 结论:cache 只 latent,compute 在 latent 空间,省 cache 不增 compute
朴素地说"还原 K/V 不就行了"——还原后 compute 退化到 MHA;不知道 absorb 是 inference 专属,训练时不能 absorb 因为要 backprop。
Q23.解释为什么 MLA 在加 RoPE 时必须分离一个独立通道,能不能用别的方式保住 absorb?
- 核心:RoPE 把 $R_{s-t}$ 塞进 $\tilde W^{QK,(i)}$,破坏"常数矩阵"性质
- 替代方案 1:把 RoPE 直接放在 latent $c^{KV}$ 上——但 latent 维度小,旋转语义不对(RoPE 设计在 head dim 上配对 sin/cos)
- 替代方案 2:用 ALiBi(直接加 bias 不旋转)——但破坏 LLaMA-3 兼容预训练
- 替代方案 3:放弃 absorb,每 step 还原 K/V——compute 退化到 MHA
- DeepSeek-V2 的选择:decoupled RoPE 通道 $d_r=64$ 所有 head 共享,cache 增量极小(约 5%),content 通道保持 absorb
- 妙处:这个独立通道在所有 head 间共享 $k_t^R$,是"省 cache 的最后一公里"
说"加 RoPE 不影响 MLA"——错;不知道 decoupled 通道是 head-shared。
Q24.长 context(128K+)下,为什么 vanilla speculative decoding 收益坍塌?怎么救?
- Vanilla SD 的收益假设:weight loading 是瓶颈,一次 verify 摊销 $K$ 个 token 的 weight load
- 长 context 下 KV cache 远大于 weights,bandwidth 主要花在读 cache 上
- 每 verify 读完整 cache 一次,省不了 cache loading
- 直觉:vanilla SD 加速比 $\propto E[\tau] / (1 + Kc)$ 假设 $T_p$ 主要是 weight loading,但长 context 下 $T_p \approx T_\text{cache\_read} + T_\text{weight\_read}$ 且前者占大头;每次 verify 仍要读全 cache,$K$ 个 token 不能摊销 cache loading,所以 $E[\tau]$ 的优势被吃掉
- 救法 1:MagicDec — draft 用 sparse KV(StreamingLLM),target 用 full cache verify
- 救法 2:TriForce — 三层:小 LM → target+sparse cache → target+full cache
- 救法 3:合并 KV cache 压缩(H2O eviction)+ SD:cache 小了 vanilla SD 也救活
只说"长 context spec decoding 不 work",不知道为什么;不知道 MagicDec/TriForce 是 2024 长 context SD 的 SOTA。
Q25.设计 LLM serving 系统时,决定上什么优化的 mental model 是什么?
- Step 1 测 workload:prompt 长度分布、生成长度分布、QPS
- Step 2 按瓶颈选优化:(a) 显存不够装 batch → PagedAttention + prefix caching + KV 量化;(b) 长 prefill 卡 decode → Sarathi-Serve chunked prefill;(c) 短 batch decode 带宽 bound → spec decoding(小 batch 收益最大);(d) 长 context 带宽 bound → MagicDec / TriForce;(e) 跨请求 prompt 重复 → prefix caching + COW
- Step 3 注意互动:SD + large batch 收益降(large batch 已经 compute-bound);PagedAttention + SD cache rollback 用 page table 改指针;KV 量化 + SD 要 draft/target 用一致 quant scheme
- Step 4 监控 metrics:tokens/sec, p95 TTFT, p95 TPOT, GPU utilization
- 关键 trade-off:throughput vs latency,SD 偏 latency 改善,continuous batching 偏 throughput
只罗列技术名词不讲触发条件;不知道 SD 在 large batch 下收益降;忽略真实 workload 测量。
§A 附录:参考实现 + Sanity Check
A.1 组件汇总
参考 from-scratch 实现包含:
NaiveCachedAttention—— 单层 MHA + KV cache appendPagedKVCache—— page table + COW 共享 sketchMQA_GQA_Attention—— 三合一通用版本MLAAttention—— 含 decoupled RoPE 通道speculative_decode—— exact 数学等价的 spec loop(含 rejection + bonus token)
A.2 Sanity check 期望输出
[a] naive cache append prefill (1,16,128) → decode 8 token ✓
[b] MQA/GQA/MHA shape + cache 大小一致 ✓
[c] MLA cache = d_c + d_r 元素 ✓
[d] spec decode rejection: 100k 样本估 α 与理论值差 < 1% ✓
[e] spec decode 输出 vs target 直接采样: TV < 0.01 ✓
[f] paged cache COW: ref_count + share 正确 ✓
A.3 主要参考文献
KV / Serving 系统
- Kwon et al., "Efficient Memory Management for Large Language Model Serving with PagedAttention", SOSP 2023.
- Yu et al., "Orca: A Distributed Serving System for Transformer-Based Generative Models", OSDI 2022.
- Agrawal et al., "Taming Throughput-Latency Tradeoff in LLM Inference with Sarathi-Serve", OSDI 2024.
Attention 变体
- Shazeer, "Fast Transformer Decoding: One Write-Head is All You Need", arXiv:1911.02150, 2019 (MQA).
- Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints", EMNLP 2023.
- DeepSeek-AI, "DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model", arXiv:2405.04434, May 2024 (MLA).
KV cache 量化
- Liu et al., "KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache", ICML 2024.
- Hooper et al., "KVQuant: Towards 10 Million Context Length LLM Inference with KV Cache Quantization", NeurIPS 2024.
Speculative decoding
- Leviathan, Kalman, Matias, "Fast Inference from Transformers via Speculative Decoding", ICML 2023.
- Chen et al., "Accelerating Large Language Model Decoding with Speculative Sampling", arXiv:2302.01318, 2023 (DeepMind).
- Cai et al., "Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads", ICML 2024.
- Li et al., "EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty", ICML 2024.
- Li et al., "EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees", EMNLP 2024 (arXiv:2406.16858).
- Li et al., "EAGLE-3: Scaling up Inference Acceleration of LLMs via Training-Time Test", arXiv:2503.01840, 2025.
- Fu et al., "Break the Sequential Dependency of LLM Inference Using Lookahead Decoding", ICML 2024.
- Miao et al., "SpecInfer: Accelerating Large Language Model Serving with Tree-based Speculative Inference and Verification", ASPLOS 2024.
- Sun et al., "TriForce: Lossless Acceleration of Long Sequence Generation with Hierarchical Speculative Decoding", arXiv:2404.11912, 2024.
- Sadhukhan et al., "MagicDec: Breaking the Latency-Throughput Tradeoff for Long Context Generation with Speculative Decoding", arXiv:2408.11049, 2024 (ICLR 2025).
- Zhang et al., "Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding", ACL 2024.
代码与公式均经独立 reviewer 静态检查(gpt-5.5 xhigh,跨模型),数学等价性论证通过。