全词表直接扩散 → 借用冻结 encoder(ELF)→ 专用 VAE latent(Cola)→ 统一框架(DSL)。
沿途文章:CFM · FLM/FMLM · DFM · LangFlow · ELF · Cola-DLM · DSL / DSL-LLaDA 等 2026 上半年同期工作。
🎯 一张图先建立直觉(给熟悉图像扩散的读者):2026 上半年的 continuous DLM,可以粗略挂到图像扩散的两条熟悉坐标轴上 ——
怎么读:前三格在空间轴(全空间 → 借来的表征 → 专用 VAE latent,从左到右压缩递增);DSL 在正交的框架轴,是贯穿三格的"统一框架"而非"更压缩的第四格"。这是帮你快速入门的直觉类比,不是算法同构:DDPM↔FM 只对应"原始全维、不压缩";ELF 的 latent 是 token-aligned、不降序列的 frozen 通用 encoder 特征,≠ 图像 VAE 压缩 latent;Cola 的 VAE 与 DiT 联合训练、不压序列(≠ SD 那种先训好再冻结的固定 VAE);DSL 强调的是 per-token SNR path 如何把多种范式收成特例。
📎 姊妹篇:图像 / 视频扩散侧「借表征 / 用流形」的完整脉络(SSL · Consistency Models · REPA · RAE · JiT · V-JEPA2 …),见同一作者的 《扩散模型 × 表征学习 × 流形学习》——本文「借 frozen T5 表征 / x-prediction」正是那条脉络在语言上的落点。
这篇怎么读:以 ELF 为锚(§2、§4–§8)→ §3 把它放进同期 5 家 Flow-Matching 工作里横向定位 → §9–§10 沿「连续表征坐标系从哪来」这条轴,拉出三条并列路线:ELF(借 frozen T5 embedding)/ Cola-DLM(字节,学一个 VAE latent,§9)/ DSL(UC Riverside,token 单位球面上 stochastic localization、几何随 SNR 动态涌现,§10;含 8B 实证 DSL-LLaDA §10.6)。
ELF(Embedded Language Flows, MIT; Hu/Qiu et al., senior 作者 Kim/Andreas/He, arXiv 2605.10938, 2026-05-11) = 在冻结的 T5-small contextual embedding 空间里做连续时间 Flow Matching,只在最后一步离散化; denoiser 和 final-step decoder 共享同一个 Transformer 权重。 一句话定位:"continuous DLM 的瓶颈可能不在'连续'本身——而在过去工作把 encoder 放进训练联合学。ELF 把它冻结,diffusion 只学如何在已有几何中 transport,在其 ablation 中优于 learnable encoder 的选项。"
关键数字(完整对比见 §3 / §7):ELF-B 仅 105M 参数(+35M frozen T5)、OWT Gen-PPL 24.08(32 步、无蒸馏),训练只用 45.2B tokens——约同期 MDLM/Duo/FLM 的 1/12。
Gen-PPL = 模型自采样经 GPT-2 Large 打分的 perplexity,越低越像自然语言。
三个值得记住的点(关于这条研究线,不只关于 ELF):
历史定位:2026 上半年 continuous-DLM 出现 6+ 篇集中工作。FM 家族按发布时间排: CFM (Feb)、FLM/FMLM (Feb)、DFM (Apr)、LangFlow (Apr)、ELF (May); 外加 ByteDance Cola-DLM (May 末) 的 latent-VAE 路线,以及 UC Riverside 的 DSL / DSL-LLaDA 的 unit-sphere stochastic-localization 路线。 三条路线各占一个位置:ELF 在 FM 家族里 size 最小(105M)、结构最简(无 distillation / 无 latent VAE),32 步无蒸馏即接近 dataset reference PPL;Cola 走最大尺度、reasoning-focused 的 latent-VAE 路线;DSL 把离散 masked diffusion 收成连续 SNR 谱的特例,统一 continuous / discrete / AR。 详见 §3 (FM 家族 5-way)、§9 (Cola 对比) 和 §10 (DSL 第三条路线 + 8B 实证)。
| 年份 | 工作 | 主要贡献 | 评测尺度 |
|---|---|---|---|
| 2021 | D3PM (Austin et al. NeurIPS) | 定义离散扩散框架:mask / uniform / embedding transition | text8, LM1B 小尺度 |
| 2024 | SEDD (Lou et al. ICML Best Paper) | Score entropy loss 给离散空间一个 clean 损失 | OWT scale |
| 2024 | MDLM (Sahoo et al. NeurIPS) | Masked diffusion ELBO ≡ weighted CE — 极大简化训练 | OWT scale |
| 2025 | LLaDA (Nie et al.) | 第一个 8B-scale 离散 DLM,证明 scaling 可行 | 8B params |
| 2025 | Dream 7B (Ye et al.) | 大规模 diffusion LM(具体机制细节见原文) | 7B params |
| 2026 | 外部 landscape | discrete:BD3-LM (semi-AR block) / ReMDM / PRISM(test-time remasking);continuous-latent:VADD / LADD / HDLM / CADD;inference-time coupling:CoDD-style PC layer | 各种 scale |
这套路线都把 D-LLM 定义为:mask/uniform 离散腐蚀 → 逐位置独立 softmax reverse。 随着工作越做越多,发现它撞上四堵墙:
Token 在嵌入前是孤立范畴点,几何邻近性需要从零学。
对比连续扩散在 image / video 上,模型可以利用像素几何("红"和"暗红"邻近),
但离散 token 空间"猫"和"狗"是几何上无关的两个 one-hot vector。
模型必须把所有词之间的语义距离从训练数据慢慢"学"出来。
结果:参数效率低 — 一个 7B 离散 DLM 学完 token 几何后剩下的容量才用来学语言。
标准 D-LLM 的 reverse 是每个 token 独立的 softmax(可类比为图模型的 fully factorized reverse head,treewidth-0)。 真实联合后验 P(X0|Xt) 的所有 token 间依赖被一刀切掉。典型症状(示意):
Categorical score / transition 在数学上没有干净的 x / v / ε 对应物。 Image diffusion 圈两年发展出来的工具——self-conditioning, classifier-free guidance, flow matching, rectified-flow, EDM-style noise schedule——没有一个能直接搬到离散空间。 每个都要重新设计,速度被拖慢。
每步都要 unmask / 取 argmax / round;离散误差不可微地累积。 经验上需要很多步才能 recover:按 ELF 论文的 step-efficiency 对比,MDLM / Duo 等要用上千步(如 1024)才能接近 ELF-B 仅 32 步的 Gen-PPL 水平。 即使做 distillation(MDLM+SDTT, Duo+DCD),在极少步下仍难免 PPL 退化;ELF 的卖点是无需蒸馏就在 32 步取得低 Gen-PPL。 作为对比,image diffusion 已经发展出 consistency model / mean-flow / flow-distillation 等 1-2 步直采路线。
| 方法 | 在哪个接口"修" | 问题 |
|---|---|---|
| CoDD | 用 tractable probabilistic / circuit layer 替换或增广 factorized output | 训练时还可能是 factorized;inference-time coupling 不改根本 |
| CoDAR | Continuous latent diffusion + fixed encoder + separate / contextual AR decoder (ELF Tab 2 把它列在 latent diffusion 类) | 引入 AR decoder,部分丢了 DLM 全并行性 |
| E2D2 / 类 block diffusion | Semi-autoregressive:block 内 joint denoising,block 间 AR | block-AR 牺牲并行性 |
| VADD / LADD / HDLM / CADD (外部) | 加 latent / hierarchical 结构 | 架构越来越复杂,没有 clean theoretical story |
| ReMDM / PRISM-style | test-time remasking / search | 训练阶段不变,只动推理 |
这些都是 修——保持"离散 + factorized"这个根基不变,在边缘 patching。 ELF 的回答更激进:跳出离散空间。如果 token 嵌入已经被 T5 学好了, 为什么 diffusion 还要在离散符号上跑?为什么不直接在 T5 那个连续 + 有几何结构的空间里跑 flow?
ELF 在两个意义上是 continuous(paper §2 明确说 "continuous in two senses"):
但 ELF 还有一个关键的"冻结":encoder 不学。这是它和过去同类工作的最大区别。
很容易混淆"每个 token 位置对应一个 vector"是不是指"每个词有一个 vector"。不是。先理清三个数字:
| 数字 | 含义 |
|---|---|
| 1024 | 序列长度——位置数量("句子里有 1024 个 token 位") |
| 512 | T5-small contextual embedding 维度——每个位置的 vector 有多少维 |
| 32128 | 词表大小——T5 SentencePiece 总共 32128 种不同的 token |
ELF 的 tensor 流(OWT, L=1024, D=512):
input_ids: [B, 1024] ← 1024 个 token ID(每个是 0~32127 之间的整数)
↓ T5-small frozen encoder
T5 contextual emb: [B, 1024, 512] ← 每个"位置"得到一个 512-d vector
↓ + noise
z_t (noisy): [B, 1024, 512] ← ELF denoise 就在这个 tensor 上
↓ 32 步 Flow Matching denoising
clean embedding: [B, 1024, 512] ← 终点 ≈ 干净的 contextual embedding
↓ 再 forward 一次 → factored decoder head → 32128 logits
vocab logits: [B, 1024, 32128] ← 只在 t=1 这一步才出现
↓ per-position argmax
output token IDs: [B, 1024]
关键澄清:
[B, 1024, 512] 这个连续 tensor。对比离散 DLM(MDLM / LLaDA):
| 离散 DLM (MDLM) | ELF | |
|---|---|---|
| Denoise 状态 | token IDs [B, 1024](离散) | contextual embedding [B, 1024, 512](连续) |
| 每步输出空间 | 32128 vocab 上的 softmax 分布 | 512-d 连续 velocity |
| 词表 32128 出现频率 | 1024 步每步都做 vocab classification | 只在 t=1 一次 |
这就是 §2.4 / §9 里说的"ELF 跳到 continuous embedding,但仍然每个位置一个 vector,最后做 per-position CE"—— ELF 的连续性已经改了(中间不再 round 到 vocab),但 factorization 在 final CE 那一步仍然 per-position 独立(每个位置独立 softmax,没有 joint posterior 跨位置耦合)。
过去 5 年的 continuous DLM 路线(paper Tab 2 p.15 给了完整 landscape)分两条线:
共同问题:
ELF 在 Tab 2 里是唯一一行同时满足「frozen encoder + 训练/推理都不加 per-step token CE + 不另起 separate decoder」的——架构上最干净的设计点。
ELF 的 framing:把"语义几何"和"transport 动力学"两件事在架构上分离:
过去的方法:拍电影 + 现场标定摄像机 + 同时调灯光 — 三件事一起做,一砸就一片乱。
ELF:用已标定好的摄像机 (T5),专心拍电影 (diffusion transport)。
"语言什么样"的问题已经被 Google 用 T5 预训练(~1T tokens)解决了,
ELF 不再重复学这个,把所有训练算力都集中在 transport 上。
45B vs 524B 训练 token 的差距就是这个 framing 的直接结果。
ELF 上面这套"冻结 encoder 才稳"的论证,在 2 周后字节的 Cola-DLM(2026-05 末)那里被正面反驳。 Cola 选择joint training——VAE encoder 和 DiT prior 一起从 Stage 2 开始联合优化,不冻结 encoder。 他们的 RQ2/RQ3 ablation 关键发现:
所以 frozen-vs-joint 这个设计选择没有一个干净的答案。两边都对,只是 regime 不同:
| 规模 | 谁的 ablation 占优 | 解读 |
|---|---|---|
| ~100M-650M 参数 + ~45B tokens | ELF frozen | encoder 容量已够,再 unfreeze 反而拖慢 transport 学习 |
| ~2B 参数 + ~2000 EFLOPs | Cola joint | 大算力下 encoder 也有 headroom,joint co-adapt 释放更多潜力 |
注意两边都没测对方的 regime——ELF 没在 2B+ 规模 sweep joint,Cola 也没在 100M 规模测 frozen。所以这个 tension 严格说是open question。详细分析见 §9 Cola-DLM 对比。
| 方面 | T5-small (ELF 用) | 其它候选 |
|---|---|---|
| 大小 | 35M (encoder-only) | BERT-base 110M、RoBERTa 125M、Sentence-BERT 110M |
| 训练任务 | Span-corruption denoising 预训练;后续 text-to-text multi-task transfer | MLM (BERT) / 对比 (Sentence-BERT) |
| 表示性质 | contextual(同一个 token 在不同句子里不同向量) | BERT 也 contextual;static embedding (word2vec/GloVe) 则不 |
| 几何性质 | span corruption 训练让模型学到"什么 span 合理",几何比较平滑 | BERT MLM 也类似 |
| Vocab | SentencePiece 32128 | BERT WordPiece 30K |
| 是否 generative | encoder-only 用法 | — |
论文没有 sweep encoder 选择(只 ablate T5-small / base / large 三个 size,App C.1)。 这是 ELF 设计中最值得追问的一点:能不能换成 BERT? Sentence-BERT? CLIP text? 一个 LLaMA-7B hidden state? 猜测:T5 的 span corruption 让 hidden state 几何特别平滑,恰好适合 flow matching。 但这是猜测,不是 paper claim。
一句话总结:ELF 没有自己学语言的几何,它直接搭便车坐在 T5 已经画好的"文本流形"上做 transport。
最早的词向量做法(word2vec / GloVe / Diffusion-LM 用的 learned embedding matrix):
E,形状 [32128, 512](vocab × dim)E[4321]所以无论是 "I deposited money in the bank"(金融机构)还是 "We walked along the river bank"(河岸),`bank` 都被映射到同一个固定向量。模型自己没法从 embedding 那一层看出"这个 bank 是哪个意思"——只能靠后面的 transformer 自己消歧。
预训练 encoder(BERT / T5 encoder 等)的做法:跑一遍 transformer encoder over the whole sequence,每个 token 位置的输出向量都受整句话影响。
input_ids: [B, 1024] 一个句子,1024 个位置
↓ T5 encoder (6 层 self-attention + MLP)
output: [B, 1024, 512] 每个位置一个 512-d 向量,
且这个向量是 attention over 所有位置算出来的
所以同样是 "bank":
| Static embedding | Contextual embedding | |
|---|---|---|
| 存储 | lookup table [V, D] | encoder forward [L, D] per sequence |
| "bank" 在不同句子里 | 同一个向量 | 不同向量 |
| 信息量 | V 个固定点(vocab=32128) | 几乎无穷多个点(每句话每位置都不同) |
| 几何结构 | 学到的"词类别"聚类 | 语义+句法消歧后的"上下文状态空间" |
| 来源 | 训练时学一张表 | 跑一遍 frozen encoder |
T5 预训练(~1T tokens)已经把文本数据组织成了一个有结构的 512-d 流形:语义相近的句子 / 位置在这个空间里几何上也相近,上下文消歧、词性、句法关系都被编码进了向量的位置和方向。
ELF 把 T5 encoder 冻住当"坐标系"——diffusion 的工作变成了"在这张已经画好的语义地图上做 transport",而不是"先画地图再 transport"。
对比之下:
因为 ELF 实际能用的"语言知识"远不止 45B,而是 45B + T5 那 1T tokens 的预训练迁移。这一点也常被点出:
"45B 'tokens' exclude pretrained T5-small prior; frozen encoder doing real work. What at 2B from scratch?"(一句常见的质疑)
代价:上限被 T5-small 锁死。换 7B encoder 还有效吗?scale 到 LLaMA-70B 行不行?这是 ELF 的命门——也有评论认为 ELF 更像"T5 的生成式插件而非通用架构"(详见 §11 与 ELF 的命门讨论)。
整个 flow 从 ε(标准高斯 × noise_scale=2.0)→ clean embedding 都在连续空间。 但 t=1 时模型必须输出离散 token。ELF 的做法:
decoder_step_active=True,这次 transformer 输出经过 factored decoder head(backbone 768-d hidden → 512 GELU → 32128)映射到 vocab logits关键:denoise 主干 + 离散化 decoder 是同一份 transformer 权重,靠 4 个 mode token 切换。 论文 App C.4 显示这种 in-context conditioning 比 adaLN-Zero 略好且省 43M 参数(详见 §5)。
2026 年 2-5 月,一共出现了 5 篇基于 Flow Matching 的语言模型工作, 其中 ELF 是 paper 自己 Tab 2 显示的最后一行。这一节把这 5 篇放在一起对比,明确 ELF 在 FM 家族里的独特定位。
CFM — Categorical Flow Maps (Roos et al., UvA/Oxford, arXiv 2602.12233, Feb 2026)。 核心决策:把 categorical generation 写成"endpoint-prediction flow map", 模型预测 simplex 上的 endpoint 分布 πs,t,再用 self-distillation 做少步生成。 评测 Text8 NLL + LM1B Gen-PPL,单步 LM1B Gen-PPL 274.87。
FLM / FMLM — Flow Map Language Models (Lee et al., KAIST + CMU, arXiv 2602.16813, Feb 2026)。 核心决策:在 token-wise one-hot 连续空间建模,simplex-valued denoiser,CE 后验预测; FMLM 是 FLM 蒸馏后的少步版本。LM1B + OWT 评测, FLM 1024-step LM1B Gen-PPL 96.91、OWT 62.23;FMLM 单步 LM1B 119.34、OWT 168.30。
DFM — Discrete Flow Maps (Potaptchik et al., Harvard/Oxford/MIT/NYU, arXiv 2604.09784, Apr 2026)。 核心决策:把 flow map 从 average velocity 重参数化为 mean denoiser ψs,t, 让 off-diagonal flow map 自然落在 probability simplex 上。直接批判欧氏 L2 与概率几何的不匹配。 LM1B 评测,DFM-ESD 单步 Gen-PPL 68.11(entropy 3.79);DFM-PSD 单步 94.08(entropy 4.06)。
LangFlow — Chen et al., UIUC, arXiv 2604.11748, Apr 2026。 核心决策:回到 learned embedding-space,用 Bregman-divergence 把 token CE 解释为 Flow-Matching posterior matching; 推出 ODE-based NLL upper bound,为 embedding-space DLM 提供了一个可信的 likelihood 估计。 LM1B PPL 30.0 / OWT PPL 24.6(注意:是 held-out NLL upper bound,不是 Gen-PPL); 对应 Gen-PPL 是 LM1B 92.2 / OWT 36.5。
ELF — Embedded Language Flows (Hu/Qiu et al., MIT, arXiv 2605.10938, May 2026, 主角)。 核心决策:在冻结的 T5-small contextual embedding里做 FM,只在 t=1 用共享 transformer 切到 decode mode。 80% MSE + 20% CE,无蒸馏。OWT 32-step SDE Gen-PPL 24.08 ± 0.16。
| 维度 | CFM | FLM/FMLM | DFM | LangFlow | ELF |
|---|---|---|---|---|---|
| State space | Probability simplex ΔK−1(endpoint predictor πs,t) | one-hot 连续,simplex-valued denoiser | Mean denoiser ψs,t: ℝK → ΔK-1 | Learned token embedding (V×D matrix) | Frozen T5-small contextual embedding (512-d) |
| Interpolant 几何 | Straight-line stochastic interpolant | Gaussian/one-hot interpolant + simplex repar | Linear interpolant (β-reparameterized) | VP γ-path, deterministic ODE | Rectified linear: zt = t·x + (1−t)·ε |
| Training loss | CE diagonal + endpoint consistency (CSD/ECLD) | FLM: CE posterior; FMLM: KL/CE flow-map distillation | CE diagonal + PSD/LSD/ESD KL consistency | CE-as-Bregman (FM posterior matching) | 80% MSE on velocity + 20% CE on decode |
| 默认步数 | 1-step 主推(NFE 1-64 sweep) | FLM: 1024-step / FMLM: 1-step | 1-2-4-8 step | 128-step (LM1B) / 1024-step (OWT) | 32-step SDE (headline) / 64-step (scaling) |
| 是否蒸馏 | ✅ Self-distillation (CSD/ECLD) | ✅ FMLM 从 FLM teacher 蒸馏 | ✅ Diagonal 1M + off-diagonal 200k/100k | ❌ 无 distillation,明说留给未来 | ❌ 无 distillation |
| Time grid | Logit-normal w/ 0.75 diagonal fraction | Decoding-error-rate reparameterization | argmax-linearized + convex mix β̃(t) | Learnable Gumbel scheduler | Logit-normal (P_mean=−1.5) |
| Endpoint 处理 | Simplex endpoint π; argmax 或采样 | Simplex posterior; argmax | Simplex mean denoiser; softmax | Argmax over token probability | 共享 transformer t=1 decode mode + argmax |
| Headline eval | LM1B 1-step Gen-PPL 274.87 | FLM OWT 1024-step 62.23; FMLM 1-step 168.30 | LM1B 1-step Gen-PPL 68.11 (ESD) | OWT Gen-PPL 36.5 @ 1024-step (PPL upper-bound 24.6) | OWT 32-step Gen-PPL 24.08 ± 0.16 |
§3.2 表第一行的 State space 是 5 篇 paper 的核心分歧点。CFM / FLM / DFM 一类说自己"在 simplex 上走 flow", LangFlow 说自己"在 learned embedding 上走",ELF 说自己"在 contextual embedding 上走"——这些到底什么意思?把它讲透。
一个 token 表示成 vocab-size 长的向量,只有一位是 1,其他全是 0:
vocab V = 32128
token "bank" (id=4321) → one_hot 长度 V 的向量
= [0, 0, ..., 0, 1, 0, ..., 0]
↑ 第 4321 位
整个句子(L=64 个 token)→ 形状 [64, 32128],每行是一个 one-hot。
V 维空间里所有"概率分布"组成的集合:长度 V 的非负向量,且各分量之和 = 1。直观理解:
V=3 时的 simplex(直观图):
(1,0,0) ← token A 的 one-hot
/\
/ \
/ . \ ← 内部任意点 = 概率混合
/ .. . \
/________\
(0,1,0) (0,0,1)
token B token C
Flow matching 这一族的训练对象是一条从噪声 ε 到数据 x 的路径;最常见、也是 ELF 用的形式是 rectified-linear 插值 zt = t·x + (1−t)·ε(§3.2 表里其余几篇用的是别的插值)。 关键问题是:x 长什么样、ε 长什么样、在哪个空间里走。
| 路线 | x 长什么样 | 走 flow 的空间 | 哪几篇 |
|---|---|---|---|
| One-hot / Simplex | V 维 one-hot(或 simplex 内的概率向量) | L × V(V ≈ 32k) | CFM, FLM/FMLM, DFM |
| Learned embedding | D 维 static 向量(embedding lookup) | L × D(D ≈ 128) | LangFlow, Diffusion-LM |
| Contextual embedding | D 维 context-aware 向量(encoder 输出) | L × D(D = 512) | ELF |
想象 V=32128 维空间里数据点 x 长什么样:
| 路线 | 数据 manifold 形状 |
|---|---|
| ELF(contextual) | T5 已组织好的"语义流形",相近词在附近 |
| LangFlow / Diffusion-LM(static) | 学一张 lookup 表,V 个固定点的离散点云 |
| CFM / FLM / DFM(one-hot/simplex) | V 个相互垂直的"角"(汉明距离 = 2,谁离谁都一样远) |
最关键的一点:one-hot 空间里,"bank" 和 "dog" 的欧氏距离 = "bank" 和 "shore" 的欧氏距离 = √2。没有任何语义几何——token 之间的相似性必须由 Transformer 自己从头学出来。
假设词表只有 4 个 token:river / bank / money / shore
river: [1, 0, 0, 0]
bank : [0, 1, 0, 0]
money: [0, 0, 1, 0]
shore: [0, 0, 0, 1]
所有 token 互相垂直,"bank-river" 跟 "bank-money" 完全等距。Transformer 必须从一堆垂直 one-hot 输入中,靠 attention 自己摸索出 "bank" 在 "river" 旁边时该往哪边走回到 §2.4 那个比喻:
这就解释了为什么 CFM / FLM / DFM 都需要更多 trick:
对照之下 ELF 用 MSE 是合法的——因为 contextual embedding 不是概率分布,是带几何结构的语义向量,欧氏距离反映语义距离。
ELF 是唯一不学 embedding 也不约束在 simplex 上的方案。这是它最独特的设计点。
ELF 选 "32-64 步直接采样" 这条路而不蒸馏,是反潮流的—— 当时所有 LM-FM 工作都在卷"少步"。
ELF 的 MSE 合法性建立在"embedding 不是分布"这个根本前提。这也是它和其它 4 篇的分水岭。
注意:LangFlow LM1B 30.0 / OWT 24.6 是 ODE NLL upper bound,不是 Gen-PPL。 和 ELF 的 24.08(Gen-PPL)不可直接横比。可比的是 LangFlow OWT Gen-PPL 36.5(1024 步)vs ELF 24.08(32 步)。
不对称的原因:CFM/FMLM/DFM 的状态在 simplex 或 one-hot 上,flow map 有 clean 的几何对象(mean denoiser)可学;ELF 的状态是 T5 embedding,要做 flow map distillation 需要先证明 T5 embedding 上的 flow map 也有 clean 形式——paper 没做。
ELF vs CFM:两者都把离散文本放到连续 FM 里,但CFM 的连续对象是 simplex-valued endpoint,ELF 的是 T5 contextual embedding。 CFM 押注"少步 self-distillation"(单步 LM1B 274.87);ELF 押注"无蒸馏 32 步"(OWT 24.08)——注意两数跨不同数据集(LM1B vs OWT),不可直接横比,这里只对比路线取向。 两条完全互补的路线——蒸馏路 vs 采样路。(点击此标题 → 滚到 §3.2 表 + 高亮 CFM 列)
ELF vs FLM/FMLM:两者都做 continuous flow LM,也都在 OWT 上用 Gen-PPL; 但 FLM 用 V 维 one-hot(vocab 越大状态越大),ELF 用 512 维 T5 embedding + 128 bottleneck。 FLM 的卖点是天然可蒸馏到 FMLM;ELF 的卖点是无蒸馏即少步。 直接对比:FLM 1024 步 OWT Gen-PPL 62.23 vs ELF 32 步 24.08 —— ELF 在这一对照下明显更优,但代价是放弃了 vocab-level 状态空间的"天然性"。
ELF vs DFM:两者都认真做 ablation,但变量完全不同—— DFM 在变 PSD/ESD consistency 把 flow map 压到 few-step;ELF 在变 ODE/SDE/γ/SC-CFG 把 base flow 调到 32 步。 DFM 的强点是 1-4 NFE;ELF 的强点是无蒸馏 32 NFE。再次互补。
ELF vs LangFlow:两者都在 embedding-space 做 FM,也都报 OWT 数字;但 LangFlow 的 24.6 是 held-out NLL upper bound,ELF 的 24.08 是 Gen-PPL,不能直接比。 可比项:LangFlow OWT Gen-PPL 36.5 (1024 步) vs ELF 24.08 (32 步) —— ELF 在生成质量上明显优。 反过来,LangFlow 有 likelihood story(可信 NLL bound + 4/7 zero-shot benchmark 超过 AR),ELF 目前没有对应 likelihood claim。 两边互补。
这 5 篇本质在回答同一个问题:语言 FM 应该把"连续性"放在哪里?
Simplex(CFM/DFM)、one-hot(FLM)、learned embedding(LangFlow),还是 frozen contextual embedding(ELF)?
ELF 押注最后一种,并用无蒸馏 32 步 OWT Gen-PPL 24.08 证明这条路在 sample quality 上最有说服力。
从横轴到纵轴。§3 这条横轴把 ELF 放进同期 FM 家族里、确认了它"借 frozen contextual embedding"这一押注的独特性。但 ELF 借坐标系只是"连续表征从哪来"的一种答案——接下来 §4–§8 先把 ELF 这个锚讲透(训练、架构、采样、结果),然后 §9–§10 沿同一条轴拉出另外两条并列路线:字节 Cola-DLM 的"学一个 VAE latent"(§9),和 UC Riverside DSL 的"在 token 球面上做 stochastic localization、让几何随 SNR 动态涌现"(§10)。三条路线放在一起,才是这篇 blog 想给的完整图景。
本节先把 ELF 网络当成一个抽象函数:
x̂ = netθ(z, t, c, ω, mode)
U(·) 拿到 vocab logits(decode mode 时用)这一节不依赖具体架构——只要 netθ 是个能接受 (z, t, c, ω, mode) 的可学函数即可。具体 Transformer 实现(T5 encoder、DiT block、RoPE、bottleneck、factored decoder head 等)下一节 §5 展开。
ELF 用的是 x-prediction parameterization(App C.1 ablation 选出来的—— x-pred 在高维比 v-pred / ε-pred 稳定,因为"clean text 在低维流形上"):
# sampling_utils.py:115-127
def net_out_to_v_x(net_out, z, t, t_eps=5e-2):
x = net_out # ← 网络输出直接就是 x̂(clean embedding pred)
denom = torch.clamp(1.0 - t, min=t_eps)
v = (x - z) / denom # ← velocity 是后处理算出来的
return v, x
但这不代表跳过逐步去噪。推理仍然 32 步 Euler 迭代:
for step i in range(32):
x̂ = net_θ(z_t, t, c, ω, "denoise") # 每步预测 clean embedding
v = (x̂ - z_t) / (1 - t) # 由 x̂ 换算瞬时 velocity
z_t = z_t + dt · v # Euler 走一小步
# 最后一步:x̂ = net_θ(z_t≈x, t=1, c, ω, "decode") → unembed → token
三件事要分清:
| 真值 | |
|---|---|
| 网络直接输出什么 | clean embedding x̂(一直在预测最终目标 x0) |
| Flow Matching 框架下的 transport 量 | velocity v |
| 推理过程 | 仍然 32 步 Euler 迭代;每步用网络的 x̂ 算 v 再走一小步 |
为什么这么设计:
‖v_pred − v_target‖²,但 v_pred / v_target 都从 x_pred / x0 换算,梯度其实在监督 x_pred → x0和 consistency model / mean flow 的区别:那些工作目标是真的 1 步从 noise 直接跳到 x(学一个"时间无关的 x-prediction")。ELF 没走那条路——仍是 32 步,但每步的局部预测对象是 x 而非 v。FMLM 就是 FLM 的 consistency-distilled 1-步版本;ELF 的 future work 也提了这个方向。
论文 App B.1 把 Alg 3 / 4 写得像两条独立 pipeline。PyTorch port 实际是这样执行(按 train_step.py):
input_ids → x₀ ∈ [B, L, 512],(x₀−μ)/0.2t,加噪 z = t·x + (1−t)·ε·2.0p,z̃ = p·x + (1−p)·ε·5.0z_mixed[row] = decoder_z if decoder_step_active[row]==1 else denoiser_z.detach()关键认识:CFG target 的第二个 no-grad forward 在主 gradient forward 之后。 数学目标不变(用 v_target 监督 v_pred),但实现顺序不是"两个 no-grad 然后一个 grad"。
ELF 用两条不同的 loss 学两件不同的事。这是 ELF 整篇 paper 最容易被误解的点:
| Loss | 学什么 | 作用 | 训练频率 |
|---|---|---|---|
| LMSE(denoiser) | 在 连续 embedding 空间里 transport 噪声到 clean embedding | 学"动力学"——给定 noisy zt 和时间 t,怎么把 zt 沿 flow 推到 x。这是 Flow Matching 的核心。 | 80%(per-example Bernoulli(0.2) 抽 decoder 模式,剩下都是 denoiser) |
| LCE(decoder) | 在 离散 token 空间里把 clean embedding 投影回 vocab | 学"离散化"——把 flow 终点 (clean embedding ≈ x) 映射到 32128 个 token id。这是 ELF 唯一触及离散 token 的训练步骤。 | 20% |
为什么必须有 CE? 你既然懂 MSE,问题就清楚了: MSE 只能让模型预测 clean embedding,但下游任务要的是生成 token。 如果只有 MSE,推理时拿到 clean embedding 后没办法回到 token(找最近的 T5 embedding 找最近邻?精度差且不可微)。 所以必须额外训练一个 decoder head,把 embedding 映射回 vocab logits —— 这个 head 必须用 CE 训练(因为 vocab 是离散的)。 ELF 的优雅之处是:decoder head 和 denoiser 共享同一份 transformer 权重,只用 4 个 mode token 切换 mode。
所以 ELF 的训练目标本质是:
Ltotal = 𝔼(s, c) [ (1−pdecode) · LMSE(transport) + pdecode · LCE(round-to-token) ]
其中 pdecode = 0.2 是 decoder 分支抽样概率(论文 Tab 4 默认)。
Self-conditioning(来自 Chen, Zhang, Hinton, "Analog Bits", ICLR 2023,ELF 引用 [9])是 diffusion / flow matching 的一个推理时迭代精修技巧:
普通 diffusion 每一步 forward 只看 zt:x̂t = net(zt, t)。
Self-cond 让网络额外接收上一步的预测作为输入:x̂t = net(zt, t, x̂t−1)。
推理时这相当于"我已经有一个 partial 估计,refine 它",比从零预测更稳。
但训练时模型并没有"上一步预测" — 因此训练用以下 trick 模拟:
stopgrad(x̂no_sc),对这次反传梯度。
这样模型学到"给 x̂prev 作 input 时怎么 refine"。为什么要 stopgrad?因为不希望梯度通过第一次 forward 回流——
那会让训练目标变成"让 x̂no_sc 也参与优化",破坏 mathematical setup。
ELF 的 self-cond 完全沿用 Chen et al. 这个标准做法。
实现上 self-cond 把网络输入维度从 D 扩到 2D(concatenate [z, x̂prev]),
然后用一个 self_cond_proj: 2D → D 线性层压回 D。
所以你在 Alg 3 看到的 self_cond_proj(concat([z, ...], dim=-1)) 就是这个压回操作;不是 ELF 独创。
有 5 个理由,按重要性排序:
(1−1/ω)·(v_cond − v_uncond) 里:
stopgrad(x̂_no_sc)没 self-cond 就没 SC-CFG,没 SC-CFG 就没有"推理只跑 1 次 forward"的训练时 CFG 优化(见 §4.4)。 所以整套 Eq 3 训练时 CFG trick 都建立在 self-cond 之上。
Alg 3 里 self-cond 看起来占大半篇幅是因为记账复杂,不是概念复杂。剥掉 plumbing 后本质就是 "plain FM + Chen 2023 的 self-cond + Eq 3 的训练时 CFG"。下面这 5 个 step 都是 implementation 细节,不是 5 个独立 ELF 创新:
| Alg 3 里的步骤 | 它在干嘛 |
|---|---|
x_no_sc = net(concat([z, 0])) | no-cond reference,self-cond 输入填 0 |
stopgrad(x_no_sc) | 不让梯度回流第一次 forward(否则训练目标污染) |
x_sc = net(concat([z, stopgrad(x_no_sc)])) | self-cond reference forward |
(1−1/ω)·(v_sc − v_no_sc) | SC-CFG guidance 烤进 v_target(Eq 3) |
where(self_cond_mask, ..., ...) | 50% 概率二选一(让模型也学"没 prior" 的情况,覆盖推理第 1 步) |
一句话:self-cond 是 ELF 站在 Chen 2023 + DiT-style 训练时 CFG 这两个肩膀上的"工程接缝"。 拿掉 self-cond 也能跑 plain FM,但你会失去 (a) 32 步达到同质量的能力 (b) Eq 3 训练时 CFG 这两个 ELF 系统级卖点。
论文 App B 写成两个独立算法。PyTorch port 改成 per-example mix,数学上等价(见 §4.6)。两个算法的关键差异:
| Denoiser (Alg 3) | Decoder (Alg 4) | |
|---|---|---|
| Mode token gate | 0(无 mode 信号) | 1(mode token 激活) |
| 时间 t | per-sample t = σ(N(Pm=−1.5, Ps²=0.8²)) | t = 1(始终终点) |
| Corruption ratio | per-sample t | per-token p = σ(N(0.8, 0.8²))(独立!) |
| Noise scale | 2.0 | 5.0 (OWT) / 1.0 (XSum, WMT) |
| Self-cond input | 50% stopgrad(x'); 50% zeros | 始终 zeros(不学 self-cond) |
| Loss | MSE on velocity(base + CFG-augmented target) | CE per token |
| Output head | FinalLayer (768→512 flow output) | Factored decoder (768→512→32128) |
Algorithm 3 伪代码(denoiser,paper App B.1,去 LaTeX):
# Algorithm 3: ELF denoiser training with conditioning and guidance
# net(z, t, c, w, mode): ELF network with in-context conditioning
# self_cond_proj(z): concat-to-original-dim projection
# self_cond_prob: 0.5
# s: discrete token sequence
# c: condition (optional, only for XSum/WMT)
x = encode(s) # T5 frozen forward
t = sample_t() # logit-normal scalar per sample
w = sample_sc_cfg_scale() # ω ∈ [0.5, 5], paper: power-biased
# PyTorch port: shifted log-uniform
e = randn_like(x) # standard Gaussian(与 paper 一致)
z = t * x + (1 - t) * e # rectified-flow interpolant
v = x - e # base velocity target(e 为 standard Gaussian = 论文 noise_scale=1 抽象写法;PyTorch port 实际 noise_scale=2.0,见 §4.3)
# (1) 不带 self-cond 的 forward (no_grad)
z_no_sc = self_cond_proj(concat([z, zeros_like(z)], dim=-1))
x_no_sc = net(z_no_sc, t, c, w, mode="denoise")
v_no_sc = (x_no_sc - z) / (1 - t)
# (2) 带 self-cond 的 forward (no_grad, stopgrad on x_no_sc)
z_sc = self_cond_proj(concat([z, stopgrad(x_no_sc)], dim=-1))
x_sc = net(z_sc, t, c, w, mode="denoise")
v_sc = (x_sc - z) / (1 - t)
# (3) CFG target: post-combination quantity
v_target = v + (1 - 1/w) * (v_sc - v_no_sc)
# Per-example self-cond mask
self_cond_mask = uniform(B) < self_cond_prob
v_pred = where(self_cond_mask, v_sc, v_no_sc)
v_target = where(self_cond_mask, v_target, v)
v_target = stopgrad(v_target)
loss_denoise = mse_loss(v_pred, v_target)
看 Alg 3 最容易迷的就是 (3) 之后 where 和 stopgrad 那 5 行。它们做了 4 件事:
# (3) CFG target: post-combination quantity
v_target = v + (1 - 1/w) * (v_sc - v_no_sc) # 计算"带 CFG"的 target
# Per-example self-cond mask
self_cond_mask = uniform(B) < self_cond_prob # 每个 example 独立抽 Bernoulli(0.5)
v_pred = where(self_cond_mask, v_sc, v_no_sc) # gradient forward 走哪条
v_target = where(self_cond_mask, v_target, v) # target 选哪个
v_target = stopgrad(v_target) # 截梯度
关键 insight:这是同一个 batch 同时训练两种模式。 mask 把 batch 切两半:50% 走 self-cond + CFG 路径,50% 走 plain FM 路径。 这两条路径必须用同一个 mask挑 v_pred 和 v_target,否则配对错乱。
| mask 值 | 输入有 self-cond? | v_pred (gradient) | v_target | 网络学到什么 |
|---|---|---|---|---|
| True (50%) | 有,self-cond 输入 = stopgrad(x̂_no_sc) | v_sc(带 self-cond 的 forward) | (x − ε) + (1 − 1/ω)·(v_sc − v_no_sc) | "给 prior 估计 + ω,输出 CFG-amplified velocity" |
| False (50%) | 没有,self-cond 输入 = 0 | v_no_sc(不带 self-cond 的 forward) | v = (x − ε)(base velocity) | "没 prior,给 plain FM velocity"(推理第 1 步用得到) |
no_grad 上下文里算出来的(已经无梯度),v = (x − ε) 来自 leaf tensor x 和 ε 也无梯度。
所以理论上 v_target 本身就没有 gradient flow 进网络。
stopgrad 是防御性的——(a) 文档上明确"v_target 是 supervision,不可微";(b) 防止有人 fork 代码后去掉 no_grad 时还能保住语义。
这是好的工程习惯。就是 mse_loss(v_pred, v_target) 算 L2 距离:
mse_loss 展开是什么?就是 velocity 上的 L2 距离 + per-token mean,标准 Flow Matching 损失。具体计算:
# v_pred ∈ [B, L, 512],来自 ELF 主干 gradient-tracked forward 后转 velocity:
# x_pred = ELF_transformer(z_self_cond, t, c, ω, decode_mode=False)
# v_pred = (x_pred - z) / clamp(1 - t, t_eps) # paper Eq 4
# v_target ∈ [B, L, 512],由 Eq 3 训练时 CFG 构造(前面 4.4 节详):
# v_base = (x_0 - z) / clamp(1 - t, t_eps) = (x - ε·noise_scale)
# v_target = v_base + (1 - 1/ω) · (v_cond - v_uncond)
# v_target = v_target.detach() # 梯度只走 v_pred
l2_per_token = ((v_pred - v_target) ** 2).mean(dim=-1) # [B, L] channel-wise mean
l2_per_token *= loss_mask # 排除 padding + cond positions
数学上 ELF 的 MSE 等价于:
LMSE = 𝔼(x, c) 𝔼t, ε [ Σi ∈ valid ‖ vθ(zt, x'prev, t, c, ω)i − vtarget,i ‖22 / D ]
其中:
| 符号 | 含义 |
|---|---|
| x | frozen T5-small encode 后的 contextual embedding,归一化后(× 1/0.2 = ×5) |
| zt | noisy embedding:zt = t·x + (1−t)·ε·noise_scale,其中 noise_scale = 2.0 |
| t | per-sequence corruption 时间,t ~ σ(N(−1.5, 0.64)),logit-normal |
| ε | 标准高斯 ∈ ℝD,D = 512(T5-small encoder dim) |
| vθ | ELF 主干预测的速度场(实际预测 x_pred,再转 v = (x_pred − z) / (1−t)) |
| vtarget | 训练时 CFG 烤进的 target velocity(Eq 3,.detach() 截梯度) |
| x'prev | self-conditioning 输入(50% 概率 = stopgrad(uncond x_pred);50% = zeros) |
| ‖·‖22 / D | channel 维(512)取均值,等价于 per-channel MSE |
| valid 位置 | 排除 padding + 条件生成的 cond positions |
Algorithm 4 伪代码(decoder,paper App B.1):
# Algorithm 4: ELF decoder training with conditioning and guidance
x = encode(s)
p = sample_per_token_p() # logit-normal PER TOKEN (different from denoiser)
w = sample_sc_cfg_scale()
e = randn_like(x) # standard Gaussian(与 paper 一致)
z = p * x + (1 - p) * e # per-token corruption ratio
# decoder always uses zero self-cond input
z = self_cond_proj(concat([z, zeros_like(z)], dim=-1))
h = net(z, t=1, c, w, mode="decode") # mode token gate = 1
s_pred = unembed(h) # factored decoder head
loss_decode = ce_loss(s_pred, s)
ce_loss 展开是什么?就是标准的 per-token cross-entropy,没有任何 ELF-specific 改造。具体计算(来自 src/train_step.py):
# decoder_logits ∈ [B, L, 32128],由 ELF 主干 forward + factored decoder head 给出:
# hidden_768 = ELF_transformer(z̃, t=1, c, ω, decode_mode=True)[B, L, 768]
# hidden_512 = GELU_tanh(hidden_768 @ proj_kernel + proj_bias) # 768 → 512
# decoder_logits = hidden_512 @ unembed_kernel + unembed_bias # 512 → 32128
log_probs = F.log_softmax(decoder_logits.float(), dim=-1) # [B, L, 32128]
ce_per_token = -log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
# ⇔ -log p_θ(s_i | z̃_i, t=1, c, ω, mode=decode) # [B, L]
# 然后 mask 掉 padding 位置 + 条件生成的 cond 位置,再聚合
ce_per_token *= loss_mask # loss_mask = (1 - cond_seq_mask) * attention_mask
数学上 ELF 的 CE 等价于:
LCE = − 𝔼(s, c) 𝔼pi, εi [ Σi ∈ valid log pθ(si | z̃i, t=1, c, ω, mode=decode) ]
关键变量含义:
| 符号 | 含义 |
|---|---|
| si | 第 i 个位置的真实 token id(ground truth) |
| z̃i | 该位置的 per-token corrupted clean embedding: z̃i = pi·xi + (1−pi)·εi·noise_scale, 其中 pi ~ σ(N(0.8, 0.64)) 每个 token 独立抽,noise_scale = 5.0 (OWT) / 1.0 (cond) |
| t=1 | decoder 分支永远在终点;时间 token 编码 t=1 |
| c | 条件输入(XSum / WMT 才有;OWT 无条件 c=∅) |
| ω | SC-CFG scale ∈ [0.5, 5](虽然 decoder 分支不学 SC-CFG guidance,但 ω token 仍 prepend 作为输入) |
| mode=decode | 4 个 mode_token gate=1(denoiser 分支时 gate=0,token 被乘 0) |
| pθ(·) | factored decoder head 输出的 softmax 分布(768 → 512 GELU → 32128 → softmax) |
| valid 位置 | 排除 padding token + 条件生成的 cond positions(cond 不要求模型预测) |
两条分支用不同的时间/腐蚀分布。Paper App B.1 + 代码 src/configs/training_configs/train_owt_ELF-B.yml 默认:
| 分支 | 分布 | P_mean | P_std | Noise scale | 说明 |
|---|---|---|---|---|---|
| denoiser | per-sequence logit-normal | −1.5 | 0.8 | 2.0 | t = σ(N(−1.5, 0.64)) → 偏向小 t(噪声多) |
| decoder (OWT) | per-token logit-normal | 0.8 | 0.8 | 5.0 | p = σ(N(0.8, 0.64)) → 偏向大 p(接近 clean);每个 token 独立 |
| decoder (XSum/WMT) | per-token logit-normal | 0.8 | 0.8 | 1.0 | 条件生成用更小 noise |
核心思路:denoiser 训练时多接触噪声大的样本(学习如何 transport), decoder 训练时多接触接近 clean 的样本(学习如何 round 回 token)。 而且 decoder 是每个 token 独立抽 corruption ratio,模拟推理时不同位置的 reconstruction 质量。
ELF 用 Lipman et al. 2023 rectified-flow 标准插值公式:
z = p · x + (1 − p) · ε · noise_scale
| p 值 | z 长什么样 | 翻译 |
|---|---|---|
| p = 0 | z = ε · 5.0 = 全是噪声 | 完全 noisy |
| p = 0.5 | z = 0.5·x + 0.5·ε·5.0 | 一半信号一半噪声 |
| p = 1 | z = x = 完全 clean embedding | 完全 clean |
所以 p 是"信号比例",1−p 是"噪声比例"。p 越大,z 里 clean x 成分越多。
denoiser 的 t 同理:z_t = t·x + (1−t)·ε·2.0,t=1 是 clean 端点,t=0 是噪声端点。
关键原则:训练分布必须匹配推理分布。两条分支推理时见到的样子完全不同:
| denoiser | decoder | |
|---|---|---|
| 推理时跑在哪 | 整条 32 步 ODE 轨迹,所有 token 位置同步沿 t 推进 | 32 步完后跑一次 forward,把 zt=1 映射到 token |
| 推理时各 token 位置的 corruption 状态 | 同一时刻 t,所有位置共享同一个 corruption level | 不同位置 ODE 落地质量不均——有的 95% 干净,有的 70% 干净(attention 把噪声去除得不均衡) |
| 训练 sample 单位 | per-sequence t ~ logit-normal(−1.5) | per-token pi ~ logit-normal(+0.8) |
| shape | t: (B,) → broadcast 到 (B, 1, 1) |
p: (B, L, 1) 每个位置独立 |
| 动机 | 所有位置同 t 才能让 attention 学到正确的 transport 动力学 | per-token p 训练 decoder 在per-position 残余噪声不均时仍能 round 回 token |
Denoiser 不能 per-token:如果训练时同一序列内不同位置取不同 t,attention 看到"有的位置很 noisy 有的很 clean"——但推理时根本不会出现这种状态。模型会学到错误的 transport 模式。
Decoder 必须 per-token:推理时 32 步 ODE 落地的 zt=1 不是完美 clean,不同位置残余噪声幅度不同(attention 在 32 步里对各位置的去噪进度不同步)。 Per-token pi 训练 decoder看上下文判断该位置真正的 token——paper App B.1 原话: "encourages the shared-weight decoder mode to recover corrupted embeddings from their surrounding context, making final-step discretization more robust to imperfect embeddings produced by the denoiser at inference time."
per-token pi 只通过 z̃i 的"信号/噪声混合比例"隐式传递给网络——网络不直接知道某个位置的 pi 是多少。网络的时间输入是固定的 t=1 scalar,decoder 通过 attention 看周围上下文自动判断该位置可信度。
| denoiser | decoder | |
|---|---|---|
| P_mean | −1.5(偏小 t) | +0.8(偏大 p) |
| σ(P_mean) 中位 | ~0.18(多 noisy) | ~0.69(多 clean) |
| 大部分样本落在 | [0.05, 0.4] noisy 区 | [0.5, 0.95] clean 区 |
| Granularity | per-sequence(match transport 同步性) | per-token(match 终点 per-position 不均) |
| Noise scale | 2.0 | OWT 5.0 / 条件 1.0 |
# sampling_utils.py::add_noise (denoiser)
def add_noise(x0, noise, t, config, cond_seq_mask=None):
t_expanded = t.reshape(-1, 1, 1) # [B, 1, 1]
z = t_expanded * x0 + (1 - t_expanded) * noise * config.denoiser_noise_scale
if cond_seq_mask is not None: # 条件生成时保留 cond token 不腐蚀
z = cond_seq_mask * x0 + (1 - cond_seq_mask) * z
return z
# train_step.py — decoder branch (per-token p sampling)
decoder_z_vals = (
torch.randn((B * L,), dtype=dtype, device=device)
* config.decoder_p_std + config.decoder_p_mean # = N(0.8, 0.8²)
)
decoder_lambda_t = torch.sigmoid(decoder_z_vals).reshape(B, L, 1) # [B, L, 1] per-token
decoder_noise = torch.randn(x0.shape, dtype=dtype, device=device) * config.decoder_noise_scale
decoder_z = decoder_lambda_t * x0 + (1 - decoder_lambda_t) * decoder_noise
Classifier-Free Guidance (CFG)(Ho & Salimans, "Classifier-Free Diffusion Guidance", 2022) 是 image diffusion 的一个条件信号放大器:
训练时让一个网络同时学有条件 vθ(z, t, c) 和无条件 vθ(z, t, ∅)(10% 概率把 c 置空)。 推理时按系数 ω > 1 把两者外推:
vfinal = vuncond + ω · (vcond − vuncond)
直觉:朝"和无条件的方向"反方向多走一步,等价于更强地遵从条件 c。ω=1 是原始条件 forward;ω=3 时条件信号被显著放大(细节更锐利、和 prompt 更贴)。 代价:推理每步要跑 2 次 forward(cond 一次 + uncond 一次),算力 ×2。 这是 DALL·E / Stable Diffusion / Imagen 等图像 diffusion 的标配。
ELF OWT 默认 32 步 × 1 forward = 32 forward。如果套标准 CFG → 64 forward,推理算力翻倍。 更要命的是:ELF 把 sampling step 从 baseline 的 1024 步压缩到 32 步是它的核心卖点,再 ×2 就失去了步数效率优势。
这个 trick 是 Chen 等人 2025 年 image diffusion 的工作(ELF App C 引用)。核心想法:
不在推理时做 CFG 组合,而是在训练时就把 CFG 组合烤进 v_target。
让模型直接学 post-combination quantity v_θcfg,推理只需一次 forward 就拿到等价于 CFG 后的 velocity。
| 标准推理时 CFG | 训练时 CFG (Chen 2025 / ELF) | |
|---|---|---|
| 训练 forward 数 | 1(只学 v_cond / v_uncond 之一) | 3(uncond + cond + grad-tracked) |
| 推理 forward 数 / 步 | 2(cond + uncond) | 1 |
| 32 步总推理 forward | 64 | 32 |
| 训练 1.5× ↔ 推理 2× 谁划算 | 训练 5 epoch 1 次,推理跑无数次 — 训练时 CFG 明显更划算 | |
注意 ELF 的 c 在 Eq 3 里不是文本 condition,而是self-conditioning 输入(见前面 self-cond 那个 callout):
所以 ELF 实际上把两个独立的 CFG 机制分开处理:
| 机制 | 训练时还是推理时 | 放大什么 | OWT 用 | XSum/WMT 用 |
|---|---|---|---|---|
| SC-CFG(Eq 3) | 训练时(baked-in) | self-conditioning 信号 | ω=3 | ω=1(已 baked,不额外推理) |
| Input-cond CFG(标准) | 推理时 | 文本 condition 信号 | —(无条件) | ω=2(推理时 ×2 forward) |
为什么不把 input-cond CFG 也烤进训练? 因为 input-cond CFG 的 ω 通常需要在推理时 sweep 不同值(找最佳质量),烤进训练就锁死了。 SC-CFG 的 ω 在 ELF 里是固定行为(不用 sweep),所以烤进训练划算。 这是一个 nuanced 的工程权衡。
Eq 3 = Chen 2025 训练时 CFG 这个 image diffusion trick,移植到 ELF 的 self-conditioning 维度, 让 32 步无条件采样的算力优势在加 CFG 后仍然保住。剥掉这个 trick, ELF 要么没 SC-CFG(质量低),要么推理 32 步 ×2 forward(失去步数效率卖点)。
Image diffusion 的 classifier-free guidance 通常是推理时跑两遍 forward 然后线性组合:
推理时 CFG: vfinal = vuncond + ω · (vcond − vuncond)
ELF 把它烤进训练——模型直接学 post-combination quantity,推理只需一次 forward。Paper Eq 3:
vtarget = (x − ε) + (1 − 1/ω) · (vθcfg(zt | t, c, ω) − vθcfg(zt | t, ∅, ω))
这个系数不是 ad hoc 写的,是从标准 inference-time CFG 5 步代数推导出来的。Chen et al. ICML 2025 "Visual Generation without Guidance" 的核心贡献就是这个 reparameterization。
Step 1 · 标准 CFG 公式(Ho & Salimans 2022 推理时 CFG):
vfinal(ω) = vu + ω·(vc − vu) = vc + (ω − 1)(vc − vu)
vc, vu 是"不加 CFG"时网络对 cond / uncond 的 logical base 输出。
Step 2 · 训练时 CFG 的网络已经学 post-combination quantity
ELF 网络输出 vθcfg(c, ω) 直接就是 vfinal(推理不再组合)。所以:
两者之差已经被 ω 预放大:
vcfgc − vcfgu = [vu + ω(vc−vu)] − vu = ω·(vc − vu)
Step 3 · 反推 logical base 差
vc − vu = (1/ω)·(vcfgc − vcfgu)
Step 4 · 构造 training target
我们要让网络的 vcfgc 收敛到 vfinal。代入 Step 3:
vtarget = vc + (ω − 1)·(vc − vu)
= vc + (ω − 1)·(1/ω)·(vcfgc − vcfgu)
= vc + (1 − 1/ω)·(vcfgc − vcfgu)
所以 (1 − 1/ω) 等于 (ω − 1) × (1/ω) 化简的结果——前一项是标准 CFG"超越 cond 的放大量",后一项是把"已被 ω 预放大的差值"还原回 logical base。
Step 5 · FM target 替换 vc
base FM target 就是 vc = x − ε(rectified-flow 标准 target),代入得 Eq 3:
vtarget = (x − ε) + (1 − 1/ω)·(vcfgc − vcfgu)
| ω | (1 − 1/ω) | vtarget | 含义 |
|---|---|---|---|
| 1 | 0 | x − ε | 无 CFG,退化为 plain FM ✓ |
| 2 | 0.5 | (x−ε) + 0.5·(vcfgc − vcfgu) | 中等放大 |
| 3 | 0.667 | (x−ε) + 0.667·(...) | ELF 默认 SC-CFG=3 |
| 5 | 0.8 | (x−ε) + 0.8·(...) | SC-CFG 上限 |
| ∞ | 1 | (x−ε) + (vcfgc − vcfgu) | 极限放大 |
ω=1 退化 ✓、ω→∞ 系数趋于 1 ✓、ω=3 对应 ELF OWT 默认 ✓。
这里的 trick 是 self-cond 的"condition"不是输入文本 c,而是 self-cond 输入 x'。 所以"uncond"就是 x'=0,"cond"就是 x'=stopgrad(net1)。CFG scale ω∈[0.5, 5] 是 ELF 自己也学的输入参数(4 个 SC-CFG token 编码)。
实现需要 3 次 forward。代码实际执行顺序是:
梯度只过第 2 次。最终 L2 loss = ‖v_pred − stopgrad(v + (1−1/ω)(v_cond − v_uncond))‖²。
# src/train_step.py — Eq 3 的自条件 CFG target 构造(简化版)
def compute_shared_uncond(z, t_input, x_tokens):
# forward #1: self-cond input = zeros
z_uncond = restore_cond(torch.zeros_like(z), x_tokens, cond_seq_mask)
z_input_uncond = torch.cat([z, z_uncond], dim=-1)
with torch.no_grad(), autocast(bf16):
net_out_uncond = model(z_input_uncond, t_input,
self_cond_cfg_scale=self_cond_cfg_scale)
return net_out_uncond
def get_sc_cond_and_uncond(z, t_input, cond_mask, x_tokens, shared_net_out_uncond):
v_uncond, x_uncond = net_out_to_v_x(shared_net_out_uncond, z, t_input, t_eps)
x_uncond = restore_cond(x_uncond, x_tokens, cond_mask)
# forward #2: self-cond input = stopgrad(x_uncond)
z_input_cond = torch.cat([z, x_uncond], dim=-1) # 注意 stop-grad on x_uncond
with torch.no_grad(), autocast(bf16):
net_out_cond = model(z_input_cond, t_input,
self_cond_cfg_scale=self_cond_cfg_scale)
v_cond, _ = net_out_to_v_x(net_out_cond, z, t_input, t_eps)
return v_cond, v_uncond
def get_sc_guided_v(z, t_input, base_v_target, x_tokens, shared_net_out_uncond):
v_cond, v_uncond = get_sc_cond_and_uncond(...)
sc_w = self_cond_cfg_scale.reshape(B, 1, 1)
sc_guidance = (1 - 1 / sc_w) * (v_cond - v_uncond) # ← Eq 3 第二项
sc_guidance = torch.where(use_self_cond_mask.bool(),
sc_guidance,
torch.zeros_like(sc_guidance))
return (base_v_target + sc_guidance).detach() # ← .detach() 是关键
# forward #3 (gradient-tracked) 在主 batch forward 里:
net_out, decoder_logits = model(model_input, t_mixed,
self_cond_cfg_scale=self_cond_cfg_scale,
decoder_step_active=decoder_step_active)
v_pred, _ = net_out_to_v_x(net_out, denoiser_z, t, t_eps=0.05) # gradient-tracked
v_final_target = get_sc_guided_v(denoiser_z, t, base_v_target=v_target, ...)
l2_per_token = ((v_pred - v_final_target) ** 2).mean(dim=-1) # MSE
推理时 CFG 要每步跑两遍 forward(cond + uncond),32 步采样 = 64 次 forward。 ELF 把 SC-CFG 烤进训练后,推理只跑 1 次 forward / step,效率 ×2。代价:训练时 3 次 forward。但训练只跑一次(5 epochs),推理跑无数次。划算。
注意 ELF 还另外保留了输入条件的推理时 CFG(label drop 训练的)。两个机制独立:
SC-CFG 用 Eq 3 的 self_cond_cfg_scale(推理时通常 = 1,因为已烤入);
input-cond CFG 是标准推理时 cond+uncond 组合(XSum/WMT 默认 = 2)。
label drop 的训练信号只上游改变条件 embedding 状态,不进入 Eq 3 的 SC-CFG 公式。
Paper Algs 3+4 写成两个独立 training step,按 0.8 / 0.2 概率轮换。PyTorch port 改成:
# src/train_step.py — 关键混合 forward
# 每行独立 Bernoulli(0.2) → decoder mode;否则 → denoiser mode
decoder_step_active = torch.bernoulli(
torch.full((B,), config.decoder_prob, dtype=torch.float32),
generator=gen,
).to(device=device, dtype=dtype) # (B,) 1.0=decode 0.0=denoise
decoder_mask_B11 = decoder_step_active.view(-1, 1, 1)
decoder_mask_B1 = decoder_step_active.view(-1, 1)
# t、z 都按 per-example 混合
denoiser_t = t # logit-normal per-sample
decoder_t = torch.ones_like(t) # 永远 1
t_mixed = decoder_step_active * decoder_t + (1.0 - decoder_step_active) * t
z_mixed = decoder_mask_B11 * decoder_z + (1.0 - decoder_mask_B11) * denoiser_z
# 单次 forward — mode token gate 也是 per-example
net_out, decoder_logits = model(
model_input, t_mixed,
self_cond_cfg_scale=self_cond_cfg_scale,
decoder_step_active=decoder_step_active, # (B,) per-row gate
)
# CE / L2 各自用 mask 分流
loss_mask_f = loss_mask.to(ce_per_token.dtype)
ce_mask = loss_mask_f * decoder_mask_B1
l2_mask = loss_mask_f * (1.0 - decoder_mask_B1)
# 关键:单一分母归一化(不是两个分母)
total_sum = (ce_per_token * ce_mask).sum() + (l2_per_token * l2_mask).sum()
loss = total_sum / torch.clamp(loss_mask_f.sum(), min=1.0)
# loss_mask_f = ce_mask + l2_mask,所以这等价于 sum/(ce_mask.sum()+l2_mask.sum())
# 注意:pad_token=="pad" 时 loss_mask 屏蔽 padding;pad_token=="eos" 时全 1
论文 Algs 3+4 是两个独立 step,按 (1−p):p 比例轮换(p = decoder_prob = 0.2)。
PyTorch port 是同一个 step 内 per-example 抽 mode。先注意:
每个 example 内所有 token 共享同一个 mode 抽样结果(不是 token-wise i.i.d.)。
记 b ∈ (0, 1) 为 example 的 mode 指示(1 = decode),B0 = denoiser 行数,B1 = decoder 行数。
固定 batch、固定 loss denominator M = loss_mask.sum(),PyTorch port 的 loss:
Lcode = [ Σb=1 行 Σtoken CE + Σb=0 行 Σtoken L2 ] / M
对 mode 抽样取期望(每行 Bernoulli(p)):
𝔼[Lcode] = (p · 𝔼row[Σ CE | decode] + (1−p) · 𝔼row[Σ L2 | denoise]) · (B / M)
= 𝔼[Lpaper] · (per-example token 数加权和) — 等价于 paper Algs 3+4 的固定 batch-size 加权期望。
注意:
条件生成时还需要另一个 classifier-free guidance,针对 input condition(不是 self-cond)。 10% 概率把 cond sequence 直接 drop(zero out),让模型学到 p(x | ∅) 分布:
# src/train_step.py — label drop for input-condition CFG
if config.label_drop_prob > 0:
drop = label_drop_mask.to(dtype=torch.float32).reshape(-1, 1, 1) # (B, 1, 1) 0/1
cond_mask = cond_seq_mask # (B, S)
# block_mask: 1 仅在 (non-cond row, cond col)
# 目的:让 target token 看不到 cond token
block_mask = (1 - cond_mask).unsqueeze(-1) * cond_mask.unsqueeze(1)
encoder_attention_mask = encoder_attention_mask * (1 - drop * block_mask)
label drop 实际两阶段:
encoder_attention_mask,让 target token 看不到 cond token
(block_mask 只在 non-cond row × cond col 上为 1) — 这样 T5 encode 出来的 x₀ 本身就不含条件信息denoiser_z 和 x₀ 在 cond 位置清零
(torch.where(drop & cond_seq_mask, zeros, denoiser_z))— 匹配 paper "zeroing condition embeddings"| 类别 | 参数 | 默认值 |
|---|---|---|
| Optimizer & Schedule | Optimizer | Muon(2D 参数走 Newton-Schulz)+ Nesterov-AdamW(其余) |
| LR (peak) | 0.002(公式:blr=0.001 × global_batch / 256) | |
| LR schedule | constant after warmup | |
| Warmup | 0.5 epoch(5 epochs ≈ 86K steps:45.2B ÷ (512×1024);对应 ~8.6K warmup steps) | |
| Weight decay | 0(关闭 — Muon 自带 shape scaling) | |
| Grad clip | 1.0 (norm) | |
| Batching | Global batch size | 512 |
| Sequence length | 1024 | |
| Grad accumulation | 1 (硬件够大不用) | |
| Diffusion | Denoiser P_mean / P_std / noise scale | −1.5 / 0.8 / 2.0 |
| Decoder P_mean / P_std / noise scale | 0.8 / 0.8 / 5.0 (OWT) | |
| Decoder prob (per example) | 0.2 (denoiser 0.8) | |
| Self-cond prob | 0.5(denoiser only;decoder 永远 0) | |
| CFG | SC-CFG range | [0.5, 5],power-bias sample 偏小值 |
| SC-CFG tokens | 4 | |
| Label drop prob (cond only) | 0.1(XSum/WMT),0(OWT) | |
| Numerics | Precision | bf16 autocast;输出头强制 fp32 |
| EMA decay | 0.9999 | |
| Random seed | 42(per-rank seed + rank offset,让噪声 desync) | |
| 训练量 | OWT epochs | ELF-B: 5 ELF-M: 4 ELF-L: 3(大模型收敛快) |
| OWT 总 tokens | ≈ 45.2B(OWT 数据集约 9.04B × 5 ep) | |
| Hardware | TPU v5p × 64,1.5h/epoch(ELF-B) |
| Method | Base training | Distillation training | Effective tokens | Ratio vs ELF |
|---|---|---|---|---|
| MDLM | 512 × 1M × 1024 | — | 524.3B | 11.6× |
| Duo | 512 × 1M × 1024 | — | 524.3B | 11.6× |
| MDLM + SDTT | 512 × 1M × 1024 | 512 × 10K × 5 × 1024 | 550.5B | 12.2× |
| Duo + DCD | 512 × 1M × 1024 | 512 × 10K × 5 × 1024 | 550.5B | 12.2× |
| FLM | 512 × 1M × 1024 | — | 524.3B | 11.6× |
| FMLM | 512 × 1M × 1024 | 512 × 100K × 1024 | 576.7B | 12.8× |
| LangFlow | 512 × 1M × 1024 | — | 524.3B | 11.6× |
| ELF (ours) | 5 × 9.04B | — | 45.2B | 1.0× |
Baseline 估算公式:batch_size × n_steps × seq_length。ELF 用 OWT 总 token × epochs。
精确 ratio 是 11.6× / 12.2× / 12.8×,paper 文字简称"约 12×"。
注意:ELF 的"45.2B"不包括 T5-small 预训练(Google 用了 1T+ tokens 训 T5)—— 这是 ELF 数据效率 claim 的主要 caveat。
Muon = "Momentum + Newton-Schulz orthogonalization",2024 Keller Jordan 推。核心思想:
nn.Linear.weight、bare nn.Parameter 矩阵),
先算 Nesterov-momentum 平均的梯度,再用 5 步 Newton-Schulz 把它正交化(近似 SVD ⟨U V⊤⟩)sqrt(max(1, fan_out/fan_in)) 形状缩放ELF 的 PyTorch port 用的是 PyPI muon-optimizer 包,加几层 wrapper / patches:
(a) Nesterov-bias-corrected Adam update(替换上游 adam_update);
(b) NS5 强制 fp32 + eps 1e-8(替换上游 bf16 + 1e-7);
(c) Muon update 重写 + shape scaling layout 修正(区分 nn.Linear 和 bare Parameter);
(d) _SafeMuonAuxAdam subclass:zero-fill missing grads + distributed all_gather padding 修复。
全部为了匹配 JAX optax.contrib.muon。
论文 App C.5 ablation:Muon vs AdamW,SDE 采样下差距尤其大(paper 只定性写"more pronounced under SDE",没给数字)。
latent_std=0.2 硬编码,不在 paper Tab 4 里,等于把 T5 raw embedding 放大 5× 后再加噪声。换数据集时需要重新估计,否则 SNR 漂移。γ=2.0, SC-CFG=3seed=42;paper Tab 6 是 6-seed 平均。当前 CLI 没有 --seeds 参数,需要多次运行用 --config_override seed=Nmuon-optimizer>=0.1.0 未 pin,多处 wrapper/patches,上游一改就风险静默漂移。建议 pin 死版本get_sc_guided_v 用了 3 forward 但只对第 2 forward 反传,其余 no_grad。如果 fork 改代码忘了 .detach(),梯度会污染 target 路径| Model | Depth | Hidden | Heads | head_dim | MLP ratio | Bottleneck | Params (DiT) | OWT Epochs |
|---|---|---|---|---|---|---|---|---|
| ELF-B | 12 | 768 | 12 | 64 | 4× | 128 | 105M | 5 |
| ELF-M | 24 | 1056 | 16 | 66 | 4× | 128 | 342M | 4 |
| ELF-L | 32 | 1280 | 16 | 80 | 4× | 128 | 652M | 3 |
三个模型共享同一份 frozen T5-small encoder (~35.3M 参数,encoder-only,不参与训练)。条件生成(XSum / WMT14)时 T5-small encoder 也用来编码 source context — 因此条件 ELF-B 报作 "105M+35M"。 Decoder 没有独立训练 — denoise 和 decode 用同一份 Transformer 权重,靠 model-mode token 切换。
t5-small)| 项 | 值 |
|---|---|
| Vocab size | 32128(SentencePiece vocabulary) |
| Layers | 6 (encoder only) |
| Hidden d_model | 512 |
| Attention heads | 8 (head_dim 64) |
| FFN d_ff | 2048 |
| Activation | ReLU (非 gated) |
| Params | ~35M |
| 训练时 | 整个 encoder 都 requires_grad_(False),作 frozen embedder |
训练 ELF 时永远只 forward 这个 encoder 一次(per batch),bf16 autocast 跑,last_hidden_state
shape [B, L, 512]。然后做归一化:(x − latent_mean) / latent_std,
默认 latent_std = 0.2 等于把 raw T5 embedding 放大 5×。
不在 vocabulary 里,是可学的 nn.Parameter + 输入相关 embedding 的加和:
| 类别 | 个数 | 编码 | 作用 |
|---|---|---|---|
t_emb_tokens | 4 | learnable_tokens[1,4,d] + TimestepEmbedder(t)[B,d].unsqueeze(1) | 把当前扩散时间 t∈[0,1] 注入 |
self_cond_cfg_tokens | 4 | learnable_tokens[1,4,d] + TimestepEmbedder(ω)[B,d].unsqueeze(1) | 把 self-cond CFG scale ω∈[0.5, 5] 注入(训练时随机抽,推理时固定) |
mode_tokens | 4 | learnable_tokens[1,4,d] × active_gate | denoise mode → gate=0(token 被乘 0);decode mode → gate=1(token 激活) |
三组合起来共 12 个 prefix tokens。最终序列顺序是
[time(4) + sc_cfg(4)] + [mode(4)] + [main_x(L)] ——
代码先 cat([mode, main]),再 prepend [time + sc_cfg](见 5.12 forward 第 4-5 步)。
重要细节:RoPE 在所有 prefix(12 个) 上 cos=1, sin=0(不旋转),
主序列从位置 0 开始正常 RoPE 编码。这样添加 prefix 不会破坏 main token 的相对位置。
T5 embedding 512-d → bottleneck → DiT hidden。直觉:clean 文本数据其实在 低维流形上。 论文 App C.2 sweep 了 bottleneck ∈ (32, 128, 512):
class BottleneckTextProj(nn.Module):
def __init__(self, text_encoder_dim, hidden_size, bottleneck_dim):
super().__init__()
self.proj1 = nn.Linear(text_encoder_dim, bottleneck_dim, bias=False) # 512 → 128
self.proj2 = nn.Linear(bottleneck_dim, hidden_size, bias=True) # 128 → 768
def forward(self, x): # [B, L, 512]
return self.proj2(self.proj1(x)) # [B, L, 768]
class TextRotaryEmbeddingFast(nn.Module):
def __init__(self, dim, pt_seq_len=512, ft_seq_len=None,
theta=10000.0, num_empty_token=0):
super().__init__()
ft_seq_len = ft_seq_len or pt_seq_len
# 标准 RoPE 频率
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[:dim//2].float() / dim))
# 位置缩放支持 fine-tune 时 seq_len 与 pretrain 不同
pos = torch.arange(ft_seq_len).float() / ft_seq_len * pt_seq_len
freqs_main = torch.einsum('..., f -> ... f', pos, freqs)
freqs_main = repeat(freqs_main, '... n -> ... (n r)', r=2)
# prefix 的 cos=1, sin=0 → 不旋转
if num_empty_token > 0:
cos_prefix = torch.ones((num_empty_token, freqs_main.shape[-1]))
sin_prefix = torch.zeros_like(cos_prefix)
freqs_cos = torch.cat([cos_prefix, torch.cos(freqs_main)], dim=0)
freqs_sin = torch.cat([sin_prefix, torch.sin(freqs_main)], dim=0)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
def forward(self, t): # [B, n_heads, L_total, head_dim]
cos = self.freqs_cos.to(t.dtype) # 显式 dtype cast 防 bf16 精度漂移
sin = self.freqs_sin.to(t.dtype)
return t * cos + rotate_half(t) * sin
class TimestepEmbedder(nn.Module):
# t (scalar in [0,1]) -> hidden vector (e.g., 768)
# Init: MLP weights normal(0.02), biases zero
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp_0 = nn.Linear(frequency_embedding_size, hidden_size, bias=True)
self.mlp_2 = nn.Linear(hidden_size, hidden_size, bias=True)
@staticmethod
def timestep_embedding(t, dim=256, max_period=10000):
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(half).float() / half)
args = t[:, None].float() * freqs[None]
return torch.cat([torch.cos(args), torch.sin(args)], dim=-1) # [B, 256]
def forward(self, t): # [B]
emb = self.mlp_0(self.timestep_embedding(t, 256)) # [B, 768]
return self.mlp_2(F.silu(emb)) # [B, 768]
class Attention(nn.Module):
def __init__(self, dim=768, num_heads=12, qkv_bias=True, qk_norm=True,
attn_drop=0.0, proj_drop=0.0):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads # 64 for ELF-B
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = RMSNorm(head_dim) if qk_norm else nn.Identity()
self.k_norm = RMSNorm(head_dim) if qk_norm else nn.Identity()
self.proj = nn.Linear(dim, dim, bias=True)
def forward(self, x, rope_fn, attention_mask=None):
B, N, C = x.shape # N = 1036 在 ELF-B 训练时
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
qkv = qkv.permute(2, 0, 3, 1, 4) # [3, B, n_heads, N, head_dim]
q, k, v = qkv[0], qkv[1], qkv[2]
q = self.q_norm(q) # qk-norm(论文与官方实现都开)
k = self.k_norm(k)
if rope_fn is not None: # 应用 RoPE(含 prefix-no-rotation)
q = rope_fn(q)
k = rope_fn(k)
# 实际 layers.py 里包了一层 wrapper:int/float mask -> bool mask 再传 SDPA
x = scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
return self.proj(x.permute(0, 2, 1, 3).reshape(B, N, C))
注意三件事:(1) qk_norm=True 是 ELF 默认开(提升 bf16 训练稳定性,从 Henry et al. EMNLP'20);
(2) 注意 F.scaled_dot_product_attention 内部用 PyTorch 2.x flash kernel;
源码 wrapper 在 mask 是 2D/3D 时 reshape 加 head 维并 cast 为 bool;
(3) ELF 训练/采样调用模型时都不传 attention_mask(cond+target 全双向 attention);
条件生成时 source/target 的不对称处理由 ELF 数据管线在拼接时构造,不是 T5 encoder 自带(标准 T5 encoder 本身是双向 self-attention)。
class SwiGLUFFN(nn.Module):
def __init__(self, dim, hidden_dim, drop=0.0):
super().__init__()
# SwiGLU 标准做法:把 hidden 缩到 2/3 保持 param count
hidden_dim_eff = int(hidden_dim * 2 / 3) # 768*4 = 3072 → 2048
self.w12 = nn.Linear(dim, 2 * hidden_dim_eff, bias=True) # 768 → 4096
self.w3 = nn.Linear(hidden_dim_eff, dim, bias=True) # 2048 → 768
def forward(self, x): # [B, N, 768]
x1, x2 = self.w12(x).chunk(2, dim=-1) # 各 [B, N, 2048]
return self.w3(F.silu(x1) * x2) # [B, N, 768]
class FinalLayer(nn.Module):
# Last layer that maps hidden 768 -> embedding output 512.
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = RMSNorm(hidden_size)
# 关键: kernel + bias 都用 0 初始化(DiT 标配)
# → 开局时模型预测的 clean x_pred ≡ 0;
# velocity 由后处理 v=(x_pred - z)/(1 - t) 计算
self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels, bias=True)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, x): # [B, L, 768]
return self.linear(self.norm_final(x)) # [B, L, 512]
class ELFBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = Attention(hidden_size, num_heads, qkv_bias=True, qk_norm=True)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
self.mlp = SwiGLUFFN(hidden_size, int(hidden_size * mlp_ratio))
def forward(self, x, rope_fn, attention_mask=None):
x = x + self.attn(self.norm1(x), rope_fn, attention_mask)
x = x + self.mlp(self.norm2(x))
return x
注意是 Pre-Norm + 残差(RMSNorm 在 attn/mlp 之前)。这是 LLaMA / DiT / GPT-J 等长 bf16 训练的常见稳定选择。
class ELF(nn.Module):
# 源码默认 num_model_mode_tokens=0, vocab_size=0
# train.py 实际从 config / tokenizer 把 4 / 32128 传进来
def __init__(self, text_encoder_dim=512, max_length=1024,
hidden_size=768, depth=12, num_heads=12,
mlp_ratio=4.0, bottleneck_dim=128,
num_time_tokens=4, num_self_cond_cfg_tokens=4,
num_model_mode_tokens=4, vocab_size=32128):
super().__init__()
# Self-conditioning projection: [z; x_pred] (2×512) -> 512
self.self_cond_proj = nn.Linear(2 * text_encoder_dim, text_encoder_dim, bias=True)
# Bottleneck text projection (512 -> 128 -> 768)
self.text_proj = BottleneckTextProj(text_encoder_dim, hidden_size, bottleneck_dim)
# Prefix tokens + their (input-dependent) embedders
self.t_embedder = TimestepEmbedder(hidden_size)
self.t_emb_tokens = nn.Parameter(torch.empty(1, num_time_tokens, hidden_size))
self.self_cond_cfg_embedder = TimestepEmbedder(hidden_size)
self.self_cond_cfg_tokens = nn.Parameter(
torch.empty(1, num_self_cond_cfg_tokens, hidden_size))
self.mode_tokens = nn.Parameter(torch.empty(1, num_model_mode_tokens, hidden_size))
# RoPE with prefix no-rotation
head_dim = hidden_size // num_heads
prefix_total = num_time_tokens + num_self_cond_cfg_tokens + num_model_mode_tokens
self.feat_rope = TextRotaryEmbeddingFast(
dim=head_dim, pt_seq_len=max_length, num_empty_token=prefix_total)
# 12 ELFBlocks
self.blocks = nn.ModuleList([
ELFBlock(hidden_size, num_heads, mlp_ratio) for _ in range(depth)])
# Flow-matching output head (zero-init)
self.final_layer = FinalLayer(hidden_size, patch_size=1, out_channels=text_encoder_dim)
# Factored decoder unembedding: 768 -> 512 (GELU) -> vocab
self.proj_kernel = nn.Parameter(torch.empty(hidden_size, text_encoder_dim))
self.proj_bias = nn.Parameter(torch.empty(text_encoder_dim))
self.unembed_kernel = nn.Parameter(torch.empty(text_encoder_dim, vocab_size))
self.unembed_bias = nn.Parameter(torch.empty(vocab_size))
def forward(self, x, t, self_cond_cfg_scale=None, decoder_step_active=None):
# x: [B, L, 512] if no self-cond, OR [B, L, 1024] if self-cond cat([z, x_pred])
# t: [B] 扩散时间 ∈ [0,1]
# self_cond_cfg_scale: [B] or None ω 用 SC-CFG token 编码
# decoder_step_active: None | True/False | Tensor[B] 控制 model-mode token gate
B = x.shape[0]
# ====== 1. self-cond projection: 2C -> C ======
if x.shape[-1] == 2 * self.text_encoder_dim: # 训练 & 推理都走这里
x = self.self_cond_proj(x.float()) # [B, L, 1024] -> [B, L, 512]
# ====== 2. bottleneck text projection ======
x = self.text_proj(x.float()) # [B, L, 512] -> [B, L, 768]
# ====== 3. build prefix context tokens ======
time_emb = self.t_embedder(t) # [B, 768]
prefix = self.t_emb_tokens.expand(B, -1, -1) + time_emb.unsqueeze(1) # [B, 4, 768]
if self_cond_cfg_scale is not None:
sc_emb = self.self_cond_cfg_embedder(self_cond_cfg_scale) # [B, 768]
prefix2 = self.self_cond_cfg_tokens.expand(B, -1, -1) + sc_emb.unsqueeze(1)
prefix = torch.cat([prefix, prefix2], dim=1) # [B, 8, 768]
# ====== 4. model-mode tokens with per-example gating ======
if decoder_step_active is None:
gate = 0.0 # 默认 denoise: tokens 被乘 0
elif isinstance(decoder_step_active, torch.Tensor):
gate = decoder_step_active.view(-1, 1, 1) # [B, 1, 1] per-example
else:
gate = float(decoder_step_active) # 1.0 in final decode
mode_tokens = self.mode_tokens.expand(B, -1, -1) * gate # [B, 4, 768]
x = torch.cat([mode_tokens, x], dim=1) # [B, L+4, 768]
model_mode_offset = self.num_model_mode_tokens # = 4
# ====== 5. prepend (time + sc-cfg) tokens ======
# 最终顺序: [time(4), sc-cfg(4), mode(4), main(L)]
x = torch.cat([prefix, x], dim=1) # [B, L+12, 768]
prefix_len = prefix.shape[1] # = 8 (time + sc-cfg)
# ====== 6. 12 ELFBlocks with RoPE ======
for block in self.blocks:
x = block(x, rope_fn=self.feat_rope, attention_mask=None) # full bidi
# ====== 7. strip prefix (动态 = prefix_len + model_mode_offset) ======
x = x[:, prefix_len + model_mode_offset:] # [B, L, 768]
# ====== 8. flow-matching head (always computed) ======
flow_output = self.final_layer(x.float()) # [B, L, 512]
# ====== 9. decoder head (only if decoder_step_active) ======
decoder_logits = None
if decoder_step_active is not None:
x_f32 = x.float()
hidden = F.gelu(x_f32 @ self.proj_kernel + self.proj_bias, approximate="tanh")
decoder_logits = hidden @ self.unembed_kernel + self.unembed_bias # [B, L, 32128]
return flow_output, decoder_logits
| 步骤 | 张量 | shape | 说明 |
|---|---|---|---|
| 0 | input_ids | [512, 1024] | tokenized text,int64 |
| 1 | T5 encoder output | [512, 1024, 512] | frozen contextual embedding |
| 2 | normalized x₀ | [512, 1024, 512] | 除以 latent_std=0.2(相当于 ×5) |
| 3 | noisy z_t | [512, 1024, 512] | z = t·x₀ + (1−t)·ε·2.0;t = sigmoid(N(−1.5, 0.8²)) logit-normal |
| 4 | self-cond input | [512, 1024, 1024] | cat([z_t, x_pred]),channel-wise concat,2×512 |
| 5 | self_cond_proj output | [512, 1024, 512] | 线性投影回 512 |
| 6 | bottleneck proj1 | [512, 1024, 128] | 512 → 128(无 bias,强约束) |
| 7 | bottleneck proj2 | [512, 1024, 768] | 128 → 768 |
| 8 | time emb | [512, 768] | sinusoidal(256) → MLP(768) |
| 9 | time prefix tokens | [512, 4, 768] | learnable 加上 time_emb 广播 |
| 10 | sc-cfg prefix tokens | [512, 4, 768] | learnable 加上 ω embedding |
| 11 | mode tokens (gated) | [512, 4, 768] | per-example gate 0 (denoise) 或 1 (decode) |
| 12 | concat: mode + main | [512, 1028, 768] | 中间态:mode 暂时在最前 |
| 13 | concat: prefix + above | [512, 1036, 768] | 最终顺序 [time(4) + sc-cfg(4) + mode(4) + main(1024)] |
| 14 | RoPE 应用于 q,k | [512, 12, 1036, 64] | n_heads=12, head_dim=64;prefix 位置 cos=1, sin=0 |
| 15 | each ELFBlock output | [512, 1036, 768] | 共 12 个 block,shape 不变 |
| 16 | strip prefix | [512, 1024, 768] | 取后 1024 个 token |
| 17 | FinalLayer (flow head) | [512, 1024, 512] | RMSNorm + Linear 768→512,预测 clean x₀ |
| 18a | (decode only) proj | [512, 1024, 512] | x_f32 @ proj_kernel + proj_bias,再 GELU(tanh) |
| 18b | (decode only) logits | [512, 1024, 32128] | hidden @ unembed_kernel + unembed_bias |
| 组件 | 计算 | 参数 |
|---|---|---|
| self_cond_proj | 1024×512 + 512 | 524,800 |
| BottleneckTextProj.proj1 | 512×128 (no bias) | 65,536 |
| BottleneckTextProj.proj2 | 128×768 + 768 | 99,072 |
| TimestepEmbedder (×2: time, sc-cfg) | (256×768+768) + (768×768+768) ≈ 788K | ×2 = 1,575,936 |
| t_emb_tokens + sc_cfg_tokens + mode_tokens | 3 × (4×768) | 9,216 |
| 每个 ELFBlock | ~7.09M(见下) | — |
| RMSNorm × 2 | 2 × 768 | 1,536 |
| Attention.qkv | 768 × 2304 + 2304 | 1,771,776 |
| Attention.q_norm + k_norm | 2 × 64 | 128 |
| Attention.proj | 768 × 768 + 768 | 590,592 |
| SwiGLU.w12 | 768 × 4096 + 4096 | 3,149,824 |
| SwiGLU.w3 | 2048 × 768 + 768 | 1,573,632 |
| 子总(一个 block) | 7,087,488 | |
| 12 个 ELFBlock | 12 × 7.09M | 85,049,856 |
| FinalLayer (norm + linear) | 768 + 768×512 + 512 | 394,496 |
| proj_kernel + proj_bias (decoder) | 768×512 + 512 | 393,728 |
| unembed_kernel + unembed_bias | 512×32128 + 32128 | 16,481,664 |
| 合计(trainable) | 各行精确相加 | 104,594,304 (~105M, 假设 vocab=32128) |
| + T5-small encoder (frozen) | ~35M |
关键观察:decoder unembedding 占 16.5M(~16%),这是 32128 词表的主要开销。 而 12 个 transformer block 占 85M(~81%)—— 真正的核心。 self_cond_proj 只占 0.5%,但它是 self-conditioning trick 能 work 的关键 plumbing。
给定 (B, L) 和采样配置(method, n_steps, cfg, sc_cfg, γ):
# src/utils/generation_utils.py — _generate_samples_single_batch (简化)
@torch.no_grad()
def _generate_samples_single_batch(model, generator, z, t_steps,
cond_seq, cond_seq_mask,
config, sampling_config,
cfg_scale, self_cond_cfg_scale):
method = sampling_config.sampling_method # 'ode' or 'sde'
B, L, d = z.shape # d = 512
if cond_seq is None: # OWT 无条件生成
cond_seq = torch.zeros((B, L, d), dtype=z.dtype, device=z.device)
cond_seq_mask = torch.zeros((B, L), dtype=z.dtype, device=z.device)
# 在条件位置上把 z 设回 clean cond_seq(cond 不去噪)
z = restore_cond(z, cond_seq, cond_seq_mask)
x_pred = restore_cond(torch.zeros_like(z), cond_seq, cond_seq_mask)
n = t_steps.shape[0] # = n_steps + 1
sde_gamma = getattr(sampling_config, "sde_gamma", 0.0)
use_bf16 = config.use_bf16 and z.is_cuda
with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=use_bf16):
# n_steps - 2 个中间步用 ODE/SDE
for i in range(n - 2):
t = t_steps[i].item()
t_next = t_steps[i + 1].item()
if method == "sde":
z, x_pred = _sde_step(z, t, t_next, x_pred, gamma=sde_gamma, ...)
else: # 'ode'
z, x_pred = _ode_step(z, t, t_next, x_pred, ...)
# 最后一步强制 ODE — t 接近 1 不再注入新噪声
z, x_pred = _ode_step(z, t_steps[-2], t_steps[-1], x_pred, ...)
return z # [B, L, 512]
@torch.no_grad()
def _dlm_decode_batch(z, model, t_final_val, config, self_cond_cfg_scale):
# 最终一步把 latent z (≈ clean embedding) 映射回 token IDs
B = z.shape[0]
t_final = torch.full((B,), float(t_final_val), dtype=z.dtype, device=z.device)
sc_batch = (torch.full((B,), float(self_cond_cfg_scale), dtype=z.dtype, device=z.device)
if config.num_self_cond_cfg_tokens > 0 else None)
# 推理时 self-cond 输入永远 = zeros(与训练 decoder 分支一致)
z_input = torch.cat([z, torch.zeros_like(z)], dim=-1) if config.self_cond_prob > 0 else z
with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=config.use_bf16):
_, decoder_logits = model(
z_input, t_final,
self_cond_cfg_scale=sc_batch,
decoder_step_active=True, # ← mode token gate = 1
)
return decoder_logits.argmax(dim=-1) # [B, L]
关键点:主循环 + 最终 decode 是两次独立 forward。前者把噪声 ε transport 到接近 clean x; 后者把 clean embedding 投影到 vocab。
def _ode_step(model, z, t, t_next, x_pred_prev,
config, cfg_scale, self_cond_cfg_scale,
cond_seq, cond_seq_mask):
t_batch = torch.full((z.shape[0],), float(t), dtype=z.dtype, device=z.device)
v_pred, x_pred = _forward_sample(
model=model, z=z, t_batch=t_batch, x_pred_prev=x_pred_prev,
config=config, cfg_scale=cfg_scale, self_cond_cfg_scale=self_cond_cfg_scale,
cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
)
return z + (t_next - t) * v_pred, x_pred # Euler: z_{i+1} = z_i + dt · v
def _sde_step(model, z, t, t_next, x_pred_prev,
config, cfg_scale, self_cond_cfg_scale,
cond_seq, cond_seq_mask, gamma, generator):
h = float(t_next - t)
alpha = max(0.0, min(1.0, 1.0 - gamma * h)) # 信号保留比例 ∈ [0,1]
t_back = alpha * float(t) # 时间往回拉到 α·t
eps = torch.randn(z.shape, dtype=z.dtype, device=z.device) * config.denoiser_noise_scale
z_back = restore_cond(alpha * z + (1.0 - alpha) * eps, cond_seq, cond_seq_mask)
t_batch = torch.full((z.shape[0],), t_back, dtype=z.dtype, device=z.device)
v_pred, x_pred = _forward_sample(
model=model, z=z_back, t_batch=t_batch, x_pred_prev=x_pred_prev,
config=config, cfg_scale=cfg_scale, self_cond_cfg_scale=self_cond_cfg_scale,
cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
)
# 从 backtracked state 用 Euler 一步推到 t_next
return z_back + (t_next - t_back) * v_pred, x_pred
三个直觉:
γ = 0: alpha = 1, t_back = t, z_back = z — 退化为 ODE Eulerγ > 0: alpha < 1,z 被 α 缩小并注入新噪声 (1-α)·ε — 相当于把"已经去噪一些"的状态拉回到更早时刻,再 denoise 一次# 函数签名默认参数是 P_mean=-0.8(旧默认)
# 但 caller 在 generation.py 中传入 config.denoiser_p_mean = -1.5(OWT 训练值)
def get_sampling_steps(n_steps, time_schedule="logit_normal",
P_mean=-0.8, P_std=0.8, device=None, dtype=torch.float32):
if time_schedule == "uniform":
return torch.linspace(0.0, 1.0, n_steps + 1, dtype=dtype, device=device)
# logit-normal:从训练同分布抽 n_steps - 1 个点
z = torch.randn((n_steps - 1,), dtype=dtype, device=device) * P_std + P_mean
steps = torch.sigmoid(z)
steps = torch.sort(steps).values # 升序排列
# 强制端点 0 / 1
lo = torch.zeros((1,), dtype=dtype, device=steps.device)
hi = torch.ones((1,), dtype=dtype, device=steps.device)
return torch.cat([lo, steps, hi], dim=0) # [n_steps + 1]
论文 App B.2:"smaller intervals when t is close to 0 and larger intervals as t approaches 1"。 P_mean=−1.5 使 sigmoid(N(−1.5, 0.64)) 的密度向 t 较小的 noisy 区集中, 所以中间 sample 点偏小 → 排序后noisy 区间隔细密、clean 区间隔粗,与训练 t 同分布匹配。
真实的每步 forward要处理两层 CFG:
self_cond_cfg_scale token 编码 ω# sampling_utils.py — 双层 CFG 嵌套
def _forward_sample(model, z, t_batch, x_pred_prev, config,
cfg_scale, self_cond_cfg_scale, cond_seq, cond_seq_mask):
# 内层:带 cond 的 forward(条件 token + SC-CFG)
v_cond, x_cond = _forward_sample_self_cond(
model, z, t_batch, x_pred_prev, config,
self_cond_cfg_scale=self_cond_cfg_scale,
cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
)
if cfg_scale == 1.0:
return v_cond, x_cond # 无 input-cond CFG,直接返回
# 外层:input-cond CFG → 再跑一次 uncond
z_uncond = restore_cond(z, torch.zeros_like(z), cond_seq_mask)
x_pred_prev_uncond = (None if x_pred_prev is None else
restore_cond(x_pred_prev, torch.zeros_like(x_pred_prev), cond_seq_mask))
v_uncond, x_uncond = _forward_sample_self_cond(
model, z_uncond, t_batch, x_pred_prev_uncond, config,
self_cond_cfg_scale=self_cond_cfg_scale,
cond_seq=torch.zeros_like(cond_seq), cond_seq_mask=cond_seq_mask,
)
# 标准 CFG 线性组合
v_out = v_uncond + cfg_scale * (v_cond - v_uncond)
x_out = x_uncond + cfg_scale * (x_cond - x_uncond)
return restore_vx(v_out, x_out, cond_seq, cond_seq_mask)
每步代价(worst case,cond 生成):2 次模型 forward(cond + uncond)。
OWT 无条件生成 cfg_scale=1,每步只1 次 forward。SC-CFG 被烤进训练所以不用 ×2。
| 方法 | 每步 forward 数 | 32 步总 forward | 说明 |
|---|---|---|---|
| 传统 inference-time CFG 方法(启用时) | 2 (cond + uncond) | — | 常用于 MDLM/LLaDA/Duo 等启用 CFG 的配置 |
| ELF OWT 无条件 (SC-CFG baked-in, cfg=1) | 1 | 32 + 1 decode | SC-CFG 单次 forward 已含 ω 信息 |
| ELF 条件 (input-cond CFG=2) | 2 (cond + uncond) | 64 + 1 decode | SC-CFG=1 不额外 ×2,input-cond CFG 才 ×2 |
论文 page 26 Table 6(6 个 evaluation seed 平均 ± SE):
| Steps | SC-CFG | γ | Gen-PPL ↓ | Entropy ↑ |
|---|---|---|---|---|
| 8 | 3 | 2.0 | 67.32 ± 2.25 | 5.14 ± 0.085 |
| 16 | 3 | 2.0 | 33.66 ± 1.09 | 5.16 ± 0.026 |
| 32 | 3 | 1.5 | 24.08 ± 0.16 | 5.15 ± 0.002 |
三个观察:
论文 page 26 Table 7(下表 PPL 均为 Gen-PPL)。三个 size × 两种采样器 × CFG sweep。SDE 全方位优于 ODE。 表内标 ⁱ 的灰格是 CFG>3 后 PPL 反转上升的 poor-generation cell(详见表后脚注;论文用 entropy<5.0 或 PPL>300 划红区):
| Sampler | SC-CFG | ELF-B 105M (PPL/Ent) | ELF-M 342M (PPL/Ent) | ELF-L 652M (PPL/Ent) |
|---|---|---|---|---|
| SDE (γ=1.0) | 0.5 | 36.77 / 5.28 | 39.21 / 5.35 | 37.50 / 5.41 |
| 1.0 | 29.50 / 5.23 | 33.45 / 5.30 | 31.82 / 5.37 | |
| 1.5 | 25.25 / 5.18 | 28.42 / 5.26 | 28.72 / 5.35 | |
| 2.0 | 22.53 / 5.14 | 25.34 / 5.23 | 26.47 / 5.32 | |
| 3.0 | 19.72 / 5.10 | 21.69 / 5.18 | 23.31 / 5.28 | |
| 3.5 | 37.56 / 5.30 ⁱ | 36.48 / 5.34 ⁱ | 22.28 / 5.27 | |
| 4.0 | 36.50 / 5.29 ⁱ | 34.93 / 5.33 ⁱ | 21.37 / 5.26 | |
| ODE | 0.5 | 104.29 / 5.51 | 88.51 / 5.51 | 68.27 / 5.52 |
| 1.0 | 65.30 / 5.40 | 62.47 / 5.44 | 49.72 / 5.45 | |
| 1.5 | 44.85 / 5.31 | 46.71 / 5.37 | 39.97 / 5.40 | |
| 2.0 | 34.65 / 5.23 | 37.66 / 5.32 | 33.72 / 5.36 | |
| 3.0 | 26.62 / 5.15 | 28.80 / 5.24 | 26.57 / 5.29 |
ⁱ = 论文表中标灰的 cell:CFG > 3 后 ELF-B / ELF-M 出现 PPL 反转/上升, 不是 entropy < 5.0;它们 entropy 仍 ≥5.29。论文 App C 同时用 entropy < 5 或 PPL > 300 作"poor generation"红区(entropy 指生成 token 分布的熵,越低越接近复读/退化,故太低也不算 valid)。
关键观察(数字均来自 Tab 7 valid 区间内):
src/configs/sampling_configs/cond_sampling_configs.yml:
- sampling_method: ode
num_sampling_steps: [64]
cfgs: [2] # input-cond CFG = 2
self_cond_cfg_scales: [1] # SC-CFG = 1(已烤进训练)
time_schedule: logit_normal
条件生成默认使用 ODE(paper/config 默认;论文未明确解释,可理解为条件任务在标准 CFG=2 下已经稳定,不需要 SDE 额外随机性)。 SC-CFG=1 因为推理时不需要额外 ω 调整;input-cond CFG=2 是标准 image diffusion 推理时 CFG。
README 承诺:PyTorch port 在 8× L40S / H200 上跑应与 paper TPU v5p-64 数字漂移 ≲ 1 Gen-PPL 或 < 0.5 BLEU/ROUGE。漂移来源:
optax.contrib.muonREADME 强调用 use_bf16=true(匹配训练 precision)和 use_compile=true(torch.compile ~3-4× speedup)作为推荐 eval flag。
论文 Alg 6 写 α = 1 − γ·dt。但 dt 不是 1/N(uniform 间隔),是 logit-normal 抽出来的——
不同步的 dt 跨度差很大(t 靠近 0 时 dt 可能 0.01,靠近 1 时可能 0.4)。
所以同一个 γ 在不同 step 里实际"重置强度"完全不同。代码用 clip(α, 0, 1),仅在 γ·dt ≥ 1 时才把 α clip 到 0:
比如 32-step 默认 γ=1.5,最大 dt 通常 ~0.4,γ·dt = 0.6 不会 clip;
但 8-step γ=2.0 时如果 dt ≥ 0.5 就会 clip。
这是 ELF 实现的关键细节,建议念 paper 时跟代码对照(src/utils/sampling_utils.py:226-251)。
Tab 7 各 size 内自身最低 valid PPL(valid = 落在 entropy 合理区): ELF-B SDE CFG=3 → 19.72;ELF-M SDE CFG=3 → 21.69;ELF-L SDE CFG=4 → 21.37。 绝对最低是 ELF-B 的 19.72;ELF-L 21.37 比 ELF-M 21.69 更低。 但 paper 强调 frontier 而非单点最低:scaling 整体向更优方向推进(同 entropy 下更低 PPL,同 PPL 下更高 entropy)。
ELF-B 32 步达到 Gen-PPL ≈ 24。从 Fig 7(a) 视觉读数:ELF-B 32-step 已接近或优于若干 baseline 在高 step(如 1024)下的水平, 推理时间 substantially less than prior methods(paper §4.2 原文)。
三种蒸馏过的 few-step variant:
这些都需要额外蒸馏阶段(10K-100K extra steps),但 32 步 PPL 还是不如未蒸馏的 ELF-B。 即在这套系统配置下,ELF 不加额外 distillation 仍然超过这些 distilled baselines("架构层面的优势"是我对这个现象的解读,非论文逐字 claim)。
详见 4.9 节 Tab 5:ELF 45.2B 总 token,所有 baseline 都在 524-577B 区间, 其中蒸馏 variant 因为还要加蒸馏 epoch,token 更多。精确 ratio 11.6× / 12.2× / 12.8×(paper 简称 ~12×)。 Tab 5 只统计 ELF/baseline 自身训练与蒸馏 token,不包含外部 encoder 预训练成本。
| Model | Size | De-En BLEU ↑ | XSum R-1 ↑ | R-2 ↑ | R-L ↑ |
|---|---|---|---|---|---|
| AR (Transformer) | 99M | 25.2 | 30.5 ± 0.13 | 10.2 ± 0.11 | 24.4 ± 0.12 |
| MDLM | 99M | 18.4 | 33.4 ± 0.11 | 11.6 ± 0.10 | 25.8 ± 0.10 |
| Duo | 170M+35M | 21.3 ‡ | 31.4 ± 0.12 | 10.1 ± 0.10 | 25.0 ± 0.12 |
| E2D2 | 99M | 24.8 | 28.4 ± 0.11 | 8.3 ± 0.09 | 22.0 ± 0.10 |
| SeqDiffuSeq | — | 21.3 | 19.3 † | 1.7 † | 14.1 † |
| CDCD | — | 24.9 | — | — | — |
| ELF-B (ours) | 105M+35M | 26.4 | 36.0 ± 0.13 | 12.2 ± 0.11 | 27.8 ± 0.12 |
† = 直接取自该方法原 paper(De-En 数据的默认来源); ‡ = ELF 团队用公开 codebase 重跑(XSum 数据的默认来源);Duo De-En 在 ELF 团队的对比里也是 ‡(重跑)。
重要观察:
论文 App C (pages 20-23) 系统 ablate 7 个设计选择。这是"为什么 ELF 能 work"的实证支撑:
| Ablation | Sweep | 结论 / 默认 | 差距 |
|---|---|---|---|
| C.1 Prediction target | x-pred / v-pred / ε-pred | x-pred 全 dim 稳定 | ε-pred 全 dim 都 collapse(512/768/1024);v-pred 在 512 dim ok,越高越差 |
| C.2 Bottleneck dim | 32 / 128 / 512 | 128 最佳 frontier | 32 偏低 entropy;512 偏高 PPL;128 balance |
| C.3 Denoiser mode prob | 0.2 / 0.5 / 0.8 | 0.8 (denoise) / 0.2 (decode) | 0.5 / 0.2 (denoise) PPL/entropy frontier 都明显劣化 |
| C.4 Conditioning style | in-context tokens / adaLN-Zero | in-context 略优 + 省 43M 参数 | 性能 ≈ adaLN,但 ELF-B 148M → 105M |
| C.5 Optimizer | Muon / AdamW | Muon 全面优于 AdamW | SDE 下差距最显著(paper 定性 "more pronounced under SDE") |
| C.6 Sampler + time grid | ODE / SDE × uniform / logit-normal | SDE + logit-normal | SDE 通常降低 PPL(幅度随 model/CFG 变化,CFG=2 时 21-35%);logit-normal 在各 step 都更优,few-step 时尤其 |
| C.7 Cond CFG scale | 1 / 2 / 3 / 4 | CFG=2 最佳 | 1→2 substantially improves;3、4 逐步下降。注:这是条件生成的 input-cond CFG,与 §6.8/Tab 7 主结果里的 SC-CFG(self-cond CFG,最优在 3/4)是不同量 |
三种 prediction 是数学等价的,但训练 signal 完全不同。论文用三个 encoder size (T5-small/base/large = 512/768/1024 dim) sweep:
解释:clean 文本数据在 embedding 空间是低维流形。x-pred 预测的就是这个流形上的点; ε-pred 预测的是高维 Gaussian,模型必须学一个全维度等熵分布——更难。 这条 finding 支持"continuous DLM 的关键不是连续,是 x-prediction" 的 framing。
很反直觉:如果 decoder 占比上升到 0.5,按理说 decoder 应该学得更好——但实际整体 frontier 都退化。 解释:decoder 共享 transformer 主干。如果训练时频繁切换 mode,主干被两个目标拉扯;只占 20% 时 decoder 学到的是 "在已经 transport 到 clean 的 embedding 上做最后一步映射"——比例小但效果反而好。
Logit-normal time grid 让 noisy 区间更密——8-step 时这极其重要。
Uniform 8 步 PPL 远高于 logit-normal 8 步。32-step 之后两者差距收窄。
γ sweep(paper 选默认值):8/16 步默认 γ=2.0;32 步 γ=1.5;64 步 γ=1.0。
论文文字说 γ 控制 PPL/entropy trade-off,paper 默认 γ=1.0 作为各 step budget 的 balance;
8/16 步用更大 γ=2.0 是因为粗步长需要更多 stochasticity 修正噪声累积。
| 表 | 覆盖 | 最佳 valid Gen-PPL |
|---|---|---|
| Tab 6 (system-level, ELF-B, 6 seeds) | 8/16/32 step | 32-step SDE γ=1.5 SC-CFG=3: 24.08 ± 0.16 |
| Tab 7 (scaling, 64-step) | B/M/L × ODE/SDE × CFG 0.5-4 | 各 size 内最低 valid PPL:ELF-B 19.72 (CFG=3) / ELF-M 21.69 (CFG=3) / ELF-L 21.37 (CFG=4)。CFG>3 部分 cell 标灰 |
use_bf16 + use_compile这张图回答了"continuous DLM 到底在做什么"——它在 embedding 空间里描出一条平滑轨迹, 最后一步才把轨迹终点投影到 token 词表。和离散 DLM 每步都做 vocab argmax 完全是两套范式。
Cola-DLM(ByteDance Seed, arXiv 2605.06548, 2026 年 5 月) 是和 ELF 几乎同时(2 周内)冒出来的同类工作。两边都是 continuous DLM,但设计哲学几乎相反: ELF 求简,Cola 求强。下面是我从 Cola 的公开 blog 和 arXiv 摘要整理的对比。
ELF = "把 encoder 冻结,diffusion 只学 transport"——最小架构、最高数据效率,刷 OWT Gen-PPL。
Cola-DLM = "diffusion 不应该恢复 noisy token observation,应该建模 semantic latent prior"—— 两阶段训练(VAE pre-train → joint VAE+DiT)+ block-wise 推理,~2B 参数,刷 reasoning task average。
| 维度 | ELF (MIT) | Cola-DLM (ByteDance) |
|---|---|---|
| 核心对象 | Contextual embedding 上 Flow Matching | Text VAE latent 上 block-causal FM |
| Encoder | Frozen T5-small (35M) | Learnable Text VAE (~500M) |
| Latent space | Token-aligned, 512-d (bottleneck 128) | Explicit z, d=16 (默认) |
| Diffusion 目标 | 恢复 clean contextual embedding | 把噪声 transport 到 learned latent prior |
| Decoder | 共享 Transformer (final-step 切换 mode) | 独立 VAE decoder + KV cache |
| Backbone | DiT,全双向 attention | Block-causal DiT (intra-block 双向、inter-block 因果) |
| 训练 | 单阶段 80/20 mix | 两阶段训练 (VAE pretrain → joint VAE+DiT) + block-wise 推理 |
| 损失 | 80% MSE + 20% CE | Stage 2: λVAE·LVAE + λFM·LFM + λref·KL(q‖qref) |
| 参数 | 105M / 342M / 652M | ~2.3B 总 (1.8B DiT + 500M VAE) |
| 采样 | 32-64 步 ODE/SDE Euler | 8-16 步 / block, block-causal + KV cache |
| 评测 | Gen-PPL, BLEU, ROUGE | Task Avg (LAMBADA/MMLU/SIQA/RACE/...) |
| 对标 | MDLM/Duo/FLM/LangFlow | AR + LLaDA at 2B |
显式 latent 边际化p(x)=∫p(x|z)p(z)dz | 无(直接对 contextual embedding 做 FM,最后一步 decode 到 token) | 有(VAE 提供 p(x|z),diffusion 学 p(z)) |
| 训练中间步 token-space loss | 不做(中间全 MSE on embedding,仅末步混 20% CE) | 不做(diffusion 在 latent 空间,CE 由 VAE 在两阶段训练分担) |
| Scaling ceiling | 论文未给更大 encoder 的 scaling 曲线(C.1 ablation 只到 T5-large 1024-d) | blog 报 curve still rising 至 2000 EFLOPs |
p(x)=∫p(x|z)p(z)dz,diffusion 建 pψ(z)、VAE decoder 建 pθ(x|z)Cola 不像 ELF 那样单阶段端到端训练。两个训练阶段 + 独立的推理流程:
| 阶段 | 训练对象 | 损失 | 目标 |
|---|---|---|---|
| Stage 1 — VAE pretraining | Text VAE encoder + decoder | LVAE = −𝔼[log pθ(x|z0)] + β · KL(qφ‖pbase) + λmask·Lmask (带 BERT-style masking loss) |
学一个稳定的 text↔latent 映射,避免 semantic collapse / decoder shortcutting |
| Stage 2 — Joint VAE + DiT | VAE + block-causal DiT | Lstage2 = λVAE·LVAE-like + λfm·LFM + λref·KL(qφ‖qφ_ref) (reference KL 抑制 latent drift) |
DiT 在已稳定的 latent 上学 flow matching prior;reference KL 不让 VAE 漂移 |
| Inference (非训练阶段) | — | — | Block-wise prior transport(DiT 生成 latent block)+ VAE decoder(latent → tokens, KV-cached) |
关键超参(来自 blog RQ4 sweet spot 表,released checkpoint 配置):
Cola DiT 在序列维度上做块因果(block-causal)分解:
这种设计的好处:
Block size sweep(来自他们 RQ2/RQ3):
Cola 不报 Gen-PPL。他们用"generative few-shot"协议把多选题转成生成任务, 和 AR + LLaDA 在 ~2B 同 scale 下对比。来自 ByteDance-Seed/Cola-DLM GitHub model card 的 released 数字(注意多选任务 4 选 1 随机基线 ≈ 25,下表 HellaSwag/RACE/OBQA 等接近或低于随机,反映 latent 模型当前在多选 few-shot 上偏弱):
| Task | Cola @ 2000 EFLOPs | 说明 |
|---|---|---|
| LAMBADA | 50.80 | 段落补全 |
| MMLU | 19.30 | 57 类多选 — 数字仍低于 AR 同 scale |
| SIQA | 28.90 | 社会场景推理 |
| RACE | 19.60 | 阅读理解 |
| Story Cloze | 30.77 | 故事结尾选择 |
| OBQA | 23.00 | 选择题常识问答 (OpenBookQA) |
| HellaSwag | 10.70 | 常识 NLI |
| SQuAD | 30.90 | 抽取式问答 |
| Task Avg | 26.75 | 8 任务平均 |
注意:blog 内 RQ2/RQ3 ablation 表给出更小的训练 budget 下数字(LAMBADA 31.1-34.6 / MMLU 5.4-10.1 / SIQA 11.1-23.6 等), 但这些是消融区间,不是 headline。上面才是 released checkpoint 数字。
他们的 scaling 实验跑到 ~2000 EFLOPs。Blog 说 "Task Avg 在 ~2000 EFLOPs 还在 rising"—— 官方称仍有 headroom,没看到饱和。
| 消融 | 结论 |
|---|---|
| Fixed-VAE vs Joint-VAE training | Joint 在大算力下 win。冻结 VAE 训练 (像 ELF 那样) 在小算力 ok,但 scaling 时无法继续涨 |
| All-Scratch baseline | 从零训 VAE + DiT 始终不如先 Stage 1 pretraining |
| Interval freezing (Stage 1.5) | VAE 在中间阶段冻结一段再放开 — 比一直 joint 差 |
| Sampling steps | 少步数明显不足;~8-10 步基本恢复;16-32 步饱和 |
| Patch size 2 (压缩 latent) | 整体比 patch 1 差;但当 prompt length 对齐 patch 边界时反而略好(18.12 vs 17.31)— "Token-level segmentation 不一定最优" |
这里有个有意思的对比:Cola 通过 ablation 证明 joint VAE+DiT 训练在大算力下更优;ELF 通过 ablation 证明 frozen encoder 在小数据小算力下更优。两个结论不矛盾—— 它们对应不同的训练规模、不同的"哪个组件更值得优化容量"的判断。
ELF 与 Cola 是互补的两条路线:ELF 用最小架构(frozen encoder + 共享 decoder)在小尺度把生成质量(Gen-PPL)推到极致;Cola 用 explicit latent + block-causal DiT 在 ~2B 尺度上换来 reasoning task 与 AR-style serving。完整技术差异已汇总在 §9.1 的对比表——选哪条,取决于你要的是数据效率,还是大尺度可扩展性 + 推理能力。
聊 ELF 时我们反复绕回一个问题:要在 token 上做 continuous diffusion,那个连续表征坐标系到底从哪来?ELF 的答案是「外借」——拿一个 frozen 的 T5 contextual embedding 当现成坐标系;Cola 的答案是「自己学一个」——联合训练一个 VAE latent,把离散 token 压进一个学出来的连续隐空间。本节要介绍的 DSL(Discrete Stochastic Localization for Non-autoregressive Generation)给出了第三种答案。但它真正吸引人的地方不在「坐标系怎么来」,而在一个训好的网络能同时扮演 continuous diffusion、discrete masked diffusion、autoregressive 三种角色。这条「统一」的主线,是理解 DSL 的钥匙。
论文一作是 吴云舒(Yunshu Wu),合作者包括 Jiayi Cheng、Longxuan Yu、Partha Thakuria、Rob Brekelmans、Evangelos E. Papalexakis、Greg Ver Steeg(UC Riverside 等)。
先把名字拆清楚。"Discrete" 指的是建模对象是离散 token,而建模本身是连续的——DSL 是一个 continuous-state framework,每个 token 被嵌到一个 unit-sphere(单位球面)上。所以它不是「离散扩散」的又一个变体,而是把离散符号搬进连续球面几何里去做的框架。abstract 点出了 DSL 的核心性质——Bayes-optimal denoiser 在 localization channel 下对名义 SNR 不变;但没写出具体 channel 形式与对应 SDE 方程,本节也不去编它的数学机制,只讲它带来的可观察现象。
DSL 的统一能力,根子上来自一条理论性质:它的 Bayes-optimal denoiser 在 localization channel 下对名义 SNR 不变(invariant to the nominal SNR under the localization channel)。直觉上的后果是——同一个训好的网络能支持一整族 per-token SNR path(an entire family of per-token SNR paths):序列里不同 token 可以各自拥有自己的 SNR 路径设计。
而这正是「统一」的第一块拼图:masked diffusion 不是另一个并列的范式,而是 DSL 的一个特例。论文把它表述为 "endpoint masked-diffusion paths as a special case"——当 SNR path 取到端点形态时,per-token SNR 框架就退化成我们熟悉的 masked diffusion。换句话说,masked diffusion 是这一族路径里的一条特殊曲线,而不是一个对手。
「一整族 SNR path」还带来一个很漂亮的副产品,也是 DSL 和 ELF / Cola 在表征哲学上最不同的地方:DSL 的表征几何没有一个固定形态,而是随 SNR 从粗到细(coarse-to-fine)动态涌现。同一个 token 在不同 SNR 下「是什么」会变:
SNR = 0 ───────▶ finite SNR ───────▶ high SNR
(零信息) (软证据) (硬锚点)
解码为 [MASK] soft evidence 靠近 hard token anchor
什么都还没说 带方向性的软证据 贴近一个具体 token 锚点
原文:"zero information decodes as [MASK], finite SNR carries
soft evidence, and high SNR localizes near hard token anchors."
└── 几何不是事先给定,也不是训练后冻结,而是沿 SNR 一层层「长」出来 ──┘
这条轨迹读起来很像一种层次化的细化:低 SNR 时模型只能给出最粗的判断,SNR 升高后逐步靠近某个具体词的锚点。表征几何既不是预先外借的(ELF),也不是学一个固定 latent(Cola),而是由信噪比这条轴自然组织出来的——也可以把它当作 "stochastic localization" 这个名字的一种直觉读法。
因为 denoiser 对 SNR 免疫、又内含一整族 per-token 路径,DSL 得以让同一个训好的 checkpoint在推理时切换成不同的生成模式,而不需要重训:
这才是「统一 continuous / discrete / AR」这句话的具体兑现:不是三个模型、三套权重,而是一个 checkpoint 在推理时长出三张脸。
一个很实用的副产物:可以从现成的 MDLM checkpoint 直接 fine-tune。由于 masked diffusion 本就是 DSL 路径族的端点特例,论文指出可以从一个 pretrained MDLM checkpoint 直接 fine-tune("without retraining",无需从头重训),提升 distributional faithfulness(分布层面的逼真度)。对已经手握 masked diffusion 模型的人来说,不必从零训练就能迁移进这套框架。
效率上,DSL 报告可以用少至 T=48 步完成生成("as few as T=48 total steps—without distillation",注意是不经蒸馏的),论文同时测试了 T=128–1024 的更长步数范围。评测指标用的是 MAUVE,衡量在 OpenWebText 上生成分布与真实分布的 distributional faithfulness。
口径提醒(别横比)。MAUVE 衡量的是「生成分布有多像真实分布」,和我们之前讨论 ELF 时用的 Gen-PPL 不是同一个东西,数值不可直接对照。本节也不去声称 DSL「打败」了 ELF 或 Cola——三者指标不同、setting 不同、规模不同,把它们排个名次没有意义。DSL 在这条 blog 里的价值,是提供了处理 token 表征的第三种思路,以及一套「一个 checkpoint 统一三范式」的视角。
把 ELF、Cola、DSL 放在一起,最清楚的对比维度就是「连续表征从哪来、是不是固定的」:
| 维度 | ELF | Cola | DSL |
|---|---|---|---|
| 表征坐标系怎么来 | 外借现成的 frozen T5 contextual embedding | 联合训练一个 VAE latent | 在 token 的 unit-sphere 几何上做 stochastic localization |
| 几何是否固定 | 固定(预先给定、冻结) | 固定(训练后定下来的 latent) | 不固定——随 SNR 动态涌现(coarse-to-fine) |
| 与 masked diffusion 的关系 | 独立路线,与 masked diffusion 正交 | 独立路线,与 masked diffusion 正交 | masked-diffusion path 是其路径族的端点特例 |
| 一句话定位 | 坐标系「外借」 | 坐标系「自己学一个固定的」 | 坐标系「随信噪比长出来」,一个 checkpoint 统一三范式 |
前两条都在回答「给 continuous diffusion 一个好坐标系」,区别只是借还是学;DSL 把问题换了个问法——表征几何随 SNR 动态涌现,而不是预先外借或训练成一个固定 latent。顺着这个改写,它把 masked diffusion 收成特例,也把 continuous / discrete / AR 三种采样方式统进了同一个 denoiser。如果说 ELF 与 Cola 回答的是「该把 token 放进哪个连续坐标系」,DSL 提供的,是理解 continuous DLM 表征的第三个有用视角。
前面 §10 把 DSL(Discrete Stochastic Localization,Wu et al. 2026)讲成了一条统一主线:离散 token 的「mask / unmask」可以被还原成一个连续的「信噪比逐渐升高」过程,于是 discrete diffusion 与 continuous diffusion 在同一套数学框架下握手。但 §10 主体里还留了一句话没展开——「可以从一个现成的 discrete dLLM checkpoint 出发,通过 continue-pretrain 给它赋予 continuity」。这句话听起来很美好,可它真能在 8B 这种规模上兑现,还是只是小模型上的玩具结论?
DSL-LLaDA: Scaling Continuous Denoising to 8B Masked Diffusion LMs(Longxuan Yu 与 吴云舒(Yunshu Wu)共同一作,equal contribution;另有 Yu Fu、Siheng Xiong、Rob Brekelmans、Hui Liu、Yue Dong,通讯作者 Greg Ver Steeg;单位 UC Riverside / Georgia Tech / Microsoft;arXiv 2606.01024)给出的答案相当干脆:不用从头重训,只 continue-pretrain 1000 步(在 FineWeb-Edu 上),就能把现成的 LLaDA-8B-Instruct 从「二值 mask 去噪器」改造成「连续 embedding 空间去噪器」。它的定位很清楚:不是一个新框架,而是 §10 DSL 框架的直接 8B scale-up 实证——这正是 §10 主线最想要的那块拼图。
传统 masked dLLM(LLaDA、MDLM 这一类)看世界只有两档:一个 token 要么被完全 mask(信息为零),要么是 gold token(信息完整)。DSL-LLaDA 的训练改造很集中——把 0/1 的硬 mask,换成 §10 里 DSL 的连续 per-token 高斯软 mask(推理侧相应改用 SDE sampler,见 §10.6.2)。对每个 clean token x_i 采一个信噪比 γ_i,构造一个带噪 embedding(DSL 把 token 的 noise embedding 约束在 §10 开头提到的单位球面上,高斯噪声就加在这个球面表示上——这也补上了 §10 里 DSL abstract 未展开的具体加噪形式):
对序列里每个位置 i:
采样 SNR γ_i ~ schedule # γ 是这个 token「已经露出多少信息」的旋钮
取 clean token 的 embedding e_{x_i}
采高斯噪声 ε_i ~ N(0, I)
构造 noisy embedding:
z_i = γ_i · e_{x_i} + sqrt(γ_i) · ε_i
# 三档连续过渡(不再是 0/1 两档):
# γ_i → 0 : z_i = 0, 不含 clean 信息 → 等价于「被 mask」 (unknown)
# γ_i 适中 : 信息半遮半掩 → 像一个随机 token (unreliable)
# γ_i 很大 : 噪声被淹没 → 基本等于 gold token (clear)
注意 γ_i = 0 时 z_i = 0(不含任何 clean-token 信息),正好对应原模型「全 mask」那一档;γ_i 增大时信息越来越清晰。也就是说,原来的二值 mask 只是这条连续谱的两个端点,DSL-LLaDA 只是把中间那段连续地填了进来。一个轻量的 softmax converter 负责把这些 noisy embedding 映回 backbone 的输入空间,于是每个位置就沿 SNR 轴自然走过三个阶段:unknown(低 SNR,像 mask)→ unreliable(中 SNR,像随机 token)→ clear(高 SNR,gold token)。
有两个设计上的便利,正是这套改造能「1000 步搞定」的根本原因:
对比 §10 主线就很清楚了:DSL 原文提出框架并在较小规模上验证,DSL-LLaDA 做的就是「scale 到 8B」这一步,证明那句「从 discrete checkpoint continue-pretrain 赋予 continuity」不是小模型特例,在 8B 上花 1000 步就能落地。
训练给了模型「连续噪声」的概念,推理这边对应的就是一个 SDE-based continuous sampler:沿着一个逐渐升高的 SNR schedule 做 Heun 步,并且把 hard token commitment(真正把某个位置落子成离散 token)一直推迟到最后一步。直观地说,中间过程全程在连续 embedding 空间里「软着陆」,避免了离散采样里那种「早早把一个位置钉死、后面发现错了也改不动」的尴尬。
输入: 8B backbone (continue-pretrain 1000 步), 升序 SNR schedule {γ_1 < γ_2 < ... < γ_T}
初始: 所有位置处于低 SNR (unknown / mask-like) 状态
for t = 1 .. T: # SNR 由低到高
用 softmax converter 把当前 noisy state 映回输入空间
backbone 预测各位置的 clean-token 分布
沿 SDE 做一个 Heun step(提升有效 SNR, 逐渐进入 clear 阶段)
# 注意: 此处不做 hard token commitment, 各位置仍可被后续步骤修正
最后一步:
才把各位置 commit 成离散 token ← 延迟 hard commitment
离散 dLLM (iterative unmask):
step1 step2 step3 ... 每步硬选若干位置 → 定了就不能改
[M M x M] → [M y x M] → [t y x M] → ... ← 早期错误会被"焊死"
DSL-LLaDA (continuous SDE):
沿 SNR 升高方向做 Heun 步, 全程在连续 embedding 空间软演化
γ↑ ───────────────────────────────────▶ 仅最后一步 commit 成离散 token
(整段都可被后续步骤修正, 延迟落子)
「延迟硬性敲定」这个特性,正是下面几个有趣行为的来源。
这篇工作最有意思的地方在于:仅仅把二值 mask 换成连续噪声,模型就涌现出三种在本文比较的 discrete iterative-unmasking baseline 中都没有观察到的定性行为:
下表是 zero-shot summarization 的 ROUGE-1 数字(越高越好,低 NFE 设置)。论文报告 DSL-LLaDA-SDE 在 NFE ≤ 16 上四个 benchmark 均最佳;下表并列论文 Table 2 给出的离散 LLaDA 基线(NFE=8/16)。注意在更高 NFE(32/64)上离散 LLaDA 在 PubMed/arXiv 反超——SDE 的优势集中在低 NFE。
| Benchmark | DSL-LLaDA-SDE (NFE=8) |
DSL-LLaDA-SDE (NFE=16) |
LLaDA 对照 (NFE=8 / 16) |
|---|---|---|---|
| XSum | 28.4 | 30.4 | 25.2 / 29.0 |
| CNN-DM | 28.1 | 33.0 | 23.1 / 28.2 |
| PubMed | 29.4 | 32.2 | 11.9 / 20.2 |
| arXiv | 27.3 | 28.6 | 10.6 / 16.2 |
注:数字取自论文 Table 2(每 benchmark 1000 samples,seed=42)。另有 LLM-as-judge(GPT-5.4)评测,NFE=16 下评审偏好 SDE 输出 57.9%、偏好离散 LLaDA 26.6%。注意 ROUGE-1 与 perplexity 是不同口径的指标;且这是 DSL-LLaDA 自己 summarization setting 下的对照,与 §10 其它工作(如 ELF、Cola)的 GenPPL / MAUVE 不在同一任务上、不宜直接横比,本节不主张「打败」谁。
一个自然的质疑是:上面这些好处,会不会只是「在 LLaDA 上又 continue-pretrain 了一会儿」的功劳,跟连续噪声没关系?论文专门设了两个 同等 compute 的对照来排除这个解释:
结果是——这两个对照都没有上面那三个行为(低复读、低 NFE 摘要领先、选择性鲁棒)。所以差异不该归因于「额外训练量」,而更合理地归因于continuous-noise exposure 本身:是「把 binary mask 换成连续高斯软 mask」这一件事,让模型学会了在连续空间里延迟 commit、选择性纠错、稳定生成。
把它放回 §10 的统一视角:DSL 给出了「离散 mask = 连续 SNR 谱的端点」这套框架,DSL-LLaDA 则把它 scale 到 8B 并坐实了一个很实用的结论——一个现成的 masked dLLM,不必从头训练,只用 1000 步 continue-pretrain,把二值 mask 换成连续高斯软 mask,就能获得连续 embedding 空间去噪能力,并解锁连续扩散那一侧的好处(延迟 commitment、low-NFE 推理)。配合 MDM-CPT / XDLM 两个 same-compute 对照,作者把这些增益归因到「连续噪声暴露」本身,而非额外算力或随机 token 训练。对整章而言,这是「从 discrete dLLM checkpoint 低成本赋予 continuity」这条主线,目前最直接的一份 8B 真实证据。
延伸阅读:
真正的来源是 frozen T5-small 的 contextual embedding 已经包含了语言的几何先验。 ELF 不学"语言是什么",只学"如何在这个空间里 transport"。等效于一种 transfer learning,T5 的预训练成本 (Google C4 上约 1T tokens)没被算进 ELF 的 45.2B tokens 里。Tab 5 只比较了 ELF vs baseline 的自身训练 token,不计 encoder pretrain。 这是 ELF 数据效率 claim 最主要的 caveat。
合理的回应:可以把 ELF 看作"在 T5 表示上的 transfer-learning DLM"。MDLM/Duo 也间接用了 word-level tokenizer 的语言先验 (虽然程度不同)。但承认这是"有限的 12×",不是免费午餐。
大概率是的,但 paper 没做这个实验(App C.1 只 sweep 了 T5-small / base / large 三个 size)。 这是 ELF 最大的 follow-up 机会,也是它最脆弱的 claim—— ELF 的天花板可能就是 encoder 的天花板。如果有人用 LLaMA-3-8B hidden state + ELF flow 训一个 105M model,能不能逼近 LLaMA 自身 quality? 这个实验没人做过,是 obvious next step。
Paper App C.1 ablate 了:x-prediction 在所有 embedding 维度都稳定,v-prediction 在高维 degrade,ε-prediction 全面崩溃。 背后假设(paper 引用 Li & He 2511.13720):clean 文本数据在 embedding 空间是低维流形。 中间步加 CE 等价于"把噪声大的 z_t 强行映射回 token",这等于把 quantization wall 偷偷搬回来—— 破坏了 ELF 的"only final step rounding"核心 framing。
可以,但训练时把 SC-CFG 烤进 vtarget 后,推理只需要一次 forward,省一半算力。 代价是训练时3 次 forward(2 个 no-grad + 1 个 gradient-tracked)。 推理跑无数次,训练跑一次。划算。 这个 trick 是 image diffusion 圈 Chen et al. "Visual Generation without Guidance" (ICML 2025) 的方法,ELF 直接搬过来。 注意 ELF 仍然保留另一个推理时 CFG(input-cond CFG=2,用于 XSum/WMT),跟 SC-CFG 是两个独立机制。
论文用的就是 frozen GPT-2 Large 当 judge。Fig 7 里的 "Dataset" 虚线就是 OWT 真实文本在 GPT-2 Large 下的 reference PPL。 ELF-B 32 步 SDE 24.08 接近这条 reference;ELF-M 在 64 步 SDE / SC-CFG=3 下 21.69,更接近该 reference。 注意 "Dataset reference" 不是严格的理论下限—— 低于它也可能伴随低 entropy(即重复但流畅)。 而且这只是 GPT-2 Large judge 下的"流畅性"指标,不是通用质量下限: 换更强 judge(GPT-4 / Llama-3)数字会变。
App C.2 sweep 了 {32, 128, 512}(不含 256/64,论文没测)。 128 是 ODE / SDE 双采样下的 frontier balance 点: 32-d 在 SDE 下 PPL 最低但 entropy < 5(多样性不够),512-d 把 PPL 推高了。 背后假设:clean text 在 embedding space 是低维流形,128-d 就足够"覆盖"这个流形。 对比 image diffusion 也常用类似 bottleneck(DiT、SD 都有)。256 这个中间值 paper 没测,按 frontier 趋势插值应该 fine。
App C.5。Muon 对 2D 参数用 Newton-Schulz orthogonalization 把梯度先正交化再 step, 这抑制 ill-conditioned 方向,相当于 implicit second-order。 SDE 采样需要更精确的 v 预测(噪声重新注入会放大 v 的误差), Muon 训出来的模型在 v 上更"光滑",所以 SDE 推理时优势更大。 Paper 只定性说 "more pronounced under SDE",没给具体数字。 工程上 Muon 还有 fp32 NS5 + Nesterov bias correction 等 patches(详 §4.10)。
论文实验最长 seq=1024(OWT)和 1088(XSum)。当前 checkpoint 和 config 没有验证 8K 以上长度。
Architecture 上没有 causal LM 那种生成方向限制,但:
(a) RoPE buffer 按 max_length 预构建(要扩 8K 需重建或加 RoPE scaling);
(b) 全双向 self-attention 是 O(L²) 复杂度。
关于 "能 scale 到 8K context 吗"——architecture 上理论可行但 paper 没测;
可能需要 RoPE scaling + linear/sparse attention 改造,是 obvious extension。
VAE 路线的代价是 encoder 也得训,要解决 posterior collapse、latent drift、reconstruction trade-off 等额外问题。 ELF 用 frozen T5 把 encoder 部分外包给现成的预训练模型,所以 paper 短、ablation 干净、可以专心 ablate 7 个 ELF 自身的超参。 Cola 必须同时管理一大堆 VAE 超参(latent dim, block size, logSNR, KL ratio, anti-drift KL, multi-stage scheduling, ...)。 两边的代价不同:ELF 把复杂度外包给 Google 的 T5 训练;Cola 把复杂度内化到自己的训练 pipeline。
对,这是 generative few-shot 协议下的数字(把多选题转成生成)。 按 Cola 自己的 model card 与同协议对照,Cola 的 LAMBADA 接近部分 AR baseline,MMLU 明显偏弱。 但 Cola 论文强调的是 scaling 趋势:曲线还在涨,没饱和。 需要说明的是,Cola 当下的绝对分数确实不高——但它的卖点是架构可行性 + scaling shape,而非当前数字。
ELF 的 DiT 是全双向 attention,每步 forward 都要重新算所有位置的 attention。 KV cache 需要 causal mask(前面 K/V 缓存供后面 query 用),ELF 没这个结构。 这是 ELF 推理慢于 AR 的根本原因之一: 即使 ELF 32 步 << AR 1024 token,每步的 attention 复杂度还是 O(L²)。 Cola 用 block-causal 解决了这个——这是 Cola 的核心架构卖点。
都没有。但 ELF 更接近"clean scientific demonstration",Cola 更接近"engineering system"。 具体看:
| 资源 | 链接 |
|---|---|
| ELF paper PDF | arXiv 2605.10938 |
| ELF GitHub (官方 JAX) | lillian039/ELF |
| ELF PyTorch port | pytorch_elf branch |
| ELF HF checkpoints | embedded-language-flows |
| Cola-DLM paper | arXiv 2605.06548 |
| Cola-DLM GitHub | ByteDance-Seed/Cola-DLM |
| Cola-DLM blog | hongcanguo blog |
| Cola-DLM HF | ByteDance-Seed/Cola-DLM |
本文涉及论文的 arXiv 链接(2026 新工作的编号经 cross-model 核对)。
FM 家族(§3,含主角)
第三条路线 / Cola-DLM(§9–§10)
离散 DLM 基础
Continuous DLM 旧路(embedding / latent)
Factorized-gap 修补 / 流形 / trick