ARIS (Auto Research in Sleep) + ARIS-in-AI-Offer 工作流生成 · Continuous DLM Frontier · 2026-05

A Survey on Continuous DLM
Representation Perspective —— 连续表征该放在哪
2026 上半年,连续扩散语言模型沿这条线一步步演进(同名 v1 综述的表征视角详细版)

全词表直接扩散 → 借用冻结 encoder(ELF)→ 专用 VAE latent(Cola)→ 统一框架(DSL)。
沿途文章:CFM · FLM/FMLM · DFM · LangFlow · ELF · Cola-DLM · DSL / DSL-LLaDA 等 2026 上半年同期工作。

Ruofeng Yang(杨若峰) · Shanghai Jiao Tong University · 2026-05 · ARIS(Auto Research in Sleep)作者
Interest:Diffusion Models (Theory / Image / Video / Post-training) · DLLM · SSL · RL · Auto Research · Long-horizon Agent
锚论文:ELF arXiv 2605.10938 (Hu, Qiu, Lu, Zhao, Li, Kim, Andreas, He · MIT · 2026-05-11,非本文作者)· Code: lillian039/ELF · 代码审过 PyTorch port @ pytorch_elf

🎯 一张图先建立直觉(给熟悉图像扩散的读者):2026 上半年的 continuous DLM,可以粗略挂到图像扩散的两条熟悉坐标轴上 ——

空间轴:在什么表示空间里做扩散?压缩 / 抽象程度递增 →
全空间
DDPM图像 ↕ 文本FM 家族

全像素 ↔ 全词表 simplex / one-hot(§3

借来的表征
"借表征" latent diffusion图像 ↕ 文本ELF

借现成 encoder 特征 ↔ frozen T5 embedding(§2

专用 latent
Stable Diffusion / LDM图像 ↕ 文本Cola-DLM

专用 VAE latent ↔ learned VAE + DiT(§9

▽ 框架轴(正交,贯穿上面三格):不换扩散空间,换"扩散过程怎么参数化、能否统一多个范式"
EDM / Stochastic Interpolants图像 ↔ 文本DSL(§10

怎么读:前三格在空间轴(全空间 → 借来的表征 → 专用 VAE latent,从左到右压缩递增);DSL 在正交的框架轴,是贯穿三格的"统一框架"而非"更压缩的第四格"。这是帮你快速入门的直觉类比,不是算法同构:DDPM↔FM 只对应"原始全维、不压缩";ELF 的 latent 是 token-aligned、不降序列的 frozen 通用 encoder 特征,≠ 图像 VAE 压缩 latent;Cola 的 VAE 与 DiT 联合训练、不压序列(≠ SD 那种先训好再冻结的固定 VAE);DSL 强调的是 per-token SNR path 如何把多种范式收成特例。

📎 姊妹篇:图像 / 视频扩散侧「借表征 / 用流形」的完整脉络(SSL · Consistency Models · REPA · RAE · JiT · V-JEPA2 …),见同一作者的 《扩散模型 × 表征学习 × 流形学习》——本文「借 frozen T5 表征 / x-prediction」正是那条脉络在语言上的落点。

📌 TL;DR

这篇怎么读:以 ELF 为锚(§2、§4–§8)→ §3 把它放进同期 5 家 Flow-Matching 工作里横向定位 → §9–§10 沿「连续表征坐标系从哪来」这条轴,拉出三条并列路线:ELF(借 frozen T5 embedding)/ Cola-DLM(字节,学一个 VAE latent,§9)/ DSL(UC Riverside,token 单位球面上 stochastic localization、几何随 SNR 动态涌现,§10;含 8B 实证 DSL-LLaDA §10.6)。

ELF(Embedded Language Flows, MIT; Hu/Qiu et al., senior 作者 Kim/Andreas/He, arXiv 2605.10938, 2026-05-11) = 在冻结的 T5-small contextual embedding 空间里做连续时间 Flow Matching,只在最后一步离散化; denoiser 和 final-step decoder 共享同一个 Transformer 权重。 一句话定位:"continuous DLM 的瓶颈可能不在'连续'本身——而在过去工作把 encoder 放进训练联合学。ELF 把它冻结,diffusion 只学如何在已有几何中 transport,在其 ablation 中优于 learnable encoder 的选项。"

关键数字(完整对比见 §3 / §7):ELF-B 仅 105M 参数(+35M frozen T5)、OWT Gen-PPL 24.08(32 步、无蒸馏),训练只用 45.2B tokens——约同期 MDLM/Duo/FLM 的 1/12

Gen-PPL = 模型自采样经 GPT-2 Large 打分的 perplexity,越低越像自然语言。

三个值得记住的点(关于这条研究线,不只关于 ELF):

历史定位:2026 上半年 continuous-DLM 出现 6+ 篇集中工作。FM 家族按发布时间排: CFM (Feb)、FLM/FMLM (Feb)、DFM (Apr)、LangFlow (Apr)、ELF (May); 外加 ByteDance Cola-DLM (May 末) 的 latent-VAE 路线,以及 UC Riverside 的 DSL / DSL-LLaDA 的 unit-sphere stochastic-localization 路线。 三条路线各占一个位置:ELF 在 FM 家族里 size 最小(105M)、结构最简(无 distillation / 无 latent VAE),32 步无蒸馏即接近 dataset reference PPL;Cola 走最大尺度、reasoning-focused 的 latent-VAE 路线;DSL 把离散 masked diffusion 收成连续 SNR 谱的特例,统一 continuous / discrete / AR。 详见 §3 (FM 家族 5-way)、§9 (Cola 对比) 和 §10 (DSL 第三条路线 + 8B 实证)。

1 · 离散 DLM 的已有问题

1.1 离散 DLM 五年演进路线

年份工作主要贡献评测尺度
2021D3PM (Austin et al. NeurIPS)定义离散扩散框架:mask / uniform / embedding transitiontext8, LM1B 小尺度
2024SEDD (Lou et al. ICML Best Paper)Score entropy loss 给离散空间一个 clean 损失OWT scale
2024MDLM (Sahoo et al. NeurIPS)Masked diffusion ELBO ≡ weighted CE — 极大简化训练OWT scale
2025LLaDA (Nie et al.)第一个 8B-scale 离散 DLM,证明 scaling 可行8B params
2025Dream 7B (Ye et al.)大规模 diffusion LM(具体机制细节见原文)7B params
2026外部 landscapediscrete:BD3-LM (semi-AR block) / ReMDM / PRISM(test-time remasking);continuous-latent:VADD / LADD / HDLM / CADD;inference-time coupling:CoDD-style PC layer各种 scale

1.2 共同特征 + 共同瓶颈

这套路线都把 D-LLM 定义为:mask/uniform 离散腐蚀 → 逐位置独立 softmax reverse。 随着工作越做越多,发现它撞上四堵墙:

① 表示层 — Quantization Wall

Token 在嵌入前是孤立范畴点,几何邻近性需要从零学。 对比连续扩散在 image / video 上,模型可以利用像素几何("红"和"暗红"邻近), 但离散 token 空间"猫"和"狗"是几何上无关的两个 one-hot vector。 模型必须把所有词之间的语义距离从训练数据慢慢"学"出来。
结果:参数效率低 — 一个 7B 离散 DLM 学完 token 几何后剩下的容量才用来学语言。

② 建模层 — 因式分解瓶颈(Factorization Gap)

标准 D-LLM 的 reverse 是每个 token 独立的 softmax(可类比为图模型的 fully factorized reverse head,treewidth-0)。 真实联合后验 P(X0|Xt) 的所有 token 间依赖被一刀切掉。典型症状(示意):

  • 重复:"the the the"、"sample sample"
  • 退化句(repetitive/degenerate generations):开头一句生成完美,接下来一段全是重复模板
  • 长序列上累积误差:序列越长 factorization 误差越容易累积

③ 优化层 — 参数化别扭

Categorical score / transition 在数学上没有干净的 x / v / ε 对应物。 Image diffusion 圈两年发展出来的工具——self-conditioning, classifier-free guidance, flow matching, rectified-flow, EDM-style noise schedule——没有一个能直接搬到离散空间。 每个都要重新设计,速度被拖慢。

④ 推理层 — 离散错误累积

每步都要 unmask / 取 argmax / round;离散误差不可微地累积。 经验上需要很多步才能 recover:按 ELF 论文的 step-efficiency 对比,MDLM / Duo 等要用上千步(如 1024)才能接近 ELF-B 仅 32 步的 Gen-PPL 水平。 即使做 distillation(MDLM+SDTT, Duo+DCD),在极少步下仍难免 PPL 退化;ELF 的卖点是无需蒸馏就在 32 步取得低 Gen-PPL。 作为对比,image diffusion 已经发展出 consistency model / mean-flow / flow-distillation 等 1-2 步直采路线。

1.3 过去一年的"修"路线

方法在哪个接口"修"问题
CoDD用 tractable probabilistic / circuit layer 替换或增广 factorized output训练时还可能是 factorized;inference-time coupling 不改根本
CoDARContinuous latent diffusion + fixed encoder + separate / contextual AR decoder
(ELF Tab 2 把它列在 latent diffusion 类)
引入 AR decoder,部分丢了 DLM 全并行性
E2D2 / 类 block diffusionSemi-autoregressive:block 内 joint denoising,block 间 ARblock-AR 牺牲并行性
VADD / LADD / HDLM / CADD (外部)加 latent / hierarchical 结构架构越来越复杂,没有 clean theoretical story
ReMDM / PRISM-styletest-time remasking / search训练阶段不变,只动推理

这些都是 ——保持"离散 + factorized"这个根基不变,在边缘 patching。 ELF 的回答更激进:跳出离散空间。如果 token 嵌入已经被 T5 学好了, 为什么 diffusion 还要在离散符号上跑?为什么不直接在 T5 那个连续 + 有几何结构的空间里跑 flow?

2 · ELF 的核心想法 — 一个 "已知未知" 的连续空间

ELF Fig 1 hero plot
Fig 1 (paper p.1): ELF 在 OWT 上 32 步达到 Gen-PPL≈24,明显优于 MDLM/Duo/FLM/LangFlow,且无蒸馏、训练 token 远少。

2.1 两个"连续",一个"冻结"

ELF 在两个意义上是 continuous(paper §2 明确说 "continuous in two senses"):

但 ELF 还有一个关键的"冻结":encoder 不学。这是它和过去同类工作的最大区别。

💡 ELF 到底 denoise 在哪个 tensor 上?—— 位置 vs 词表 vs 维度

很容易混淆"每个 token 位置对应一个 vector"是不是指"每个词有一个 vector"。不是。先理清三个数字:

数字含义
1024序列长度——位置数量("句子里有 1024 个 token 位")
512T5-small contextual embedding 维度——每个位置的 vector 有多少维
32128词表大小——T5 SentencePiece 总共 32128 种不同的 token

ELF 的 tensor 流(OWT, L=1024, D=512):

input_ids:        [B, 1024]               ← 1024 个 token ID(每个是 0~32127 之间的整数)
                       ↓ T5-small frozen encoder
T5 contextual emb: [B, 1024, 512]         ← 每个"位置"得到一个 512-d vector
                       ↓ + noise
z_t (noisy):       [B, 1024, 512]         ← ELF denoise 就在这个 tensor 上
                       ↓ 32 步 Flow Matching denoising
clean embedding:   [B, 1024, 512]         ← 终点 ≈ 干净的 contextual embedding
                       ↓ 再 forward 一次 → factored decoder head → 32128 logits
vocab logits:      [B, 1024, 32128]       ← 只在 t=1 这一步才出现
                       ↓ per-position argmax
output token IDs:  [B, 1024]

关键澄清

对比离散 DLM(MDLM / LLaDA)

离散 DLM (MDLM)ELF
Denoise 状态token IDs [B, 1024](离散)contextual embedding [B, 1024, 512](连续)
每步输出空间32128 vocab 上的 softmax 分布512-d 连续 velocity
词表 32128 出现频率1024 步每步都做 vocab classification只在 t=1 一次

这就是 §2.4 / §9 里说的"ELF 跳到 continuous embedding,但仍然每个位置一个 vector,最后做 per-position CE"—— ELF 的连续性已经改了(中间不再 round 到 vocab),但 factorization 在 final CE 那一步仍然 per-position 独立(每个位置独立 softmax,没有 joint posterior 跨位置耦合)。

2.2 为什么"frozen contextual embedding"是这条路的关键

过去 5 年的 continuous DLM 路线(paper Tab 2 p.15 给了完整 landscape)分两条线:

共同问题:

  1. "鸡和蛋":embedding 学习和 diffusion 学习互相依赖,早期训练不稳定
  2. 多目标拉扯:embedding 既要满足 reconstruction,又要满足 diffusion topology,又要不 collapse — 容易陷入 trivial 解
  3. Per-step token CE:Tab 2 显示许多老工作在中间步就强加 token-level CE 监督("Train per-step discr.: Yes"),这相当于把"连续 transport"和"离散 classification"两件事强行耦合,部分把问题①(quantization wall)带回来了

ELF 在 Tab 2 里是唯一一行同时满足「frozen encoder + 训练/推理都加 per-step token CE + 另起 separate decoder」的——架构上最干净的设计点。

ELF 的 framing:把"语义几何"和"transport 动力学"两件事在架构上分离

💡 类比:摄像机标定 vs 拍电影

过去的方法:拍电影 + 现场标定摄像机 + 同时调灯光 — 三件事一起做,一砸就一片乱。
ELF:用已标定好的摄像机 (T5),专心拍电影 (diffusion transport)。 "语言什么样"的问题已经被 Google 用 T5 预训练(~1T tokens)解决了, ELF 不再重复学这个,把所有训练算力都集中在 transport 上。 45B vs 524B 训练 token 的差距就是这个 framing 的直接结果。

⚠️ 但字节 Cola-DLM 给了对立答案:joint training 更好

ELF 上面这套"冻结 encoder 才稳"的论证,在 2 周后字节的 Cola-DLM(2026-05 末)那里被正面反驳。 Cola 选择joint training——VAE encoder 和 DiT prior 一起从 Stage 2 开始联合优化,冻结 encoder。 他们的 RQ2/RQ3 ablation 关键发现:

所以 frozen-vs-joint 这个设计选择没有一个干净的答案。两边都对,只是 regime 不同:

规模谁的 ablation 占优解读
~100M-650M 参数 + ~45B tokensELF frozenencoder 容量已够,再 unfreeze 反而拖慢 transport 学习
~2B 参数 + ~2000 EFLOPsCola joint大算力下 encoder 也有 headroom,joint co-adapt 释放更多潜力

注意两边都没测对方的 regime——ELF 没在 2B+ 规模 sweep joint,Cola 也没在 100M 规模测 frozen。所以这个 tension 严格说是open question。详细分析见 §9 Cola-DLM 对比

2.3 T5-small encoder 为什么是好选择

方面T5-small (ELF 用)其它候选
大小35M (encoder-only)BERT-base 110M、RoBERTa 125M、Sentence-BERT 110M
训练任务Span-corruption denoising 预训练;后续 text-to-text multi-task transferMLM (BERT) / 对比 (Sentence-BERT)
表示性质contextual(同一个 token 在不同句子里不同向量)BERT 也 contextual;static embedding (word2vec/GloVe) 则不
几何性质span corruption 训练让模型学到"什么 span 合理",几何比较平滑BERT MLM 也类似
VocabSentencePiece 32128BERT WordPiece 30K
是否 generativeencoder-only 用法

论文没有 sweep encoder 选择(只 ablate T5-small / base / large 三个 size,App C.1)。 这是 ELF 设计中最值得追问的一点:能不能换成 BERT? Sentence-BERT? CLIP text? 一个 LLaMA-7B hidden state? 猜测:T5 的 span corruption 让 hidden state 几何特别平滑,恰好适合 flow matching。 但这是猜测,不是 paper claim。

2.4 "Contextual" vs "Static" embedding —— 借 T5 的语义流形当坐标系

一句话总结:ELF 没有自己学语言的几何,它直接搭便车坐在 T5 已经画好的"文本流形"上做 transport。

2.4.1 Static (非 contextual) embedding

最早的词向量做法(word2vec / GloVe / Diffusion-LM 用的 learned embedding matrix):

所以无论是 "I deposited money in the bank"(金融机构)还是 "We walked along the river bank"(河岸),`bank` 都被映射到同一个固定向量。模型自己没法从 embedding 那一层看出"这个 bank 是哪个意思"——只能靠后面的 transformer 自己消歧。

2.4.2 Contextual embedding

预训练 encoder(BERT / T5 encoder 等)的做法:跑一遍 transformer encoder over the whole sequence,每个 token 位置的输出向量都受整句话影响。

input_ids: [B, 1024]    一个句子,1024 个位置
   ↓ T5 encoder (6 层 self-attention + MLP)
output:    [B, 1024, 512]    每个位置一个 512-d 向量,
                              且这个向量是 attention over 所有位置算出来的

所以同样是 "bank":

2.4.3 几何对比

Static embeddingContextual embedding
存储lookup table [V, D]encoder forward [L, D] per sequence
"bank" 在不同句子里同一个向量不同向量
信息量V 个固定点(vocab=32128)几乎无穷多个点(每句话每位置都不同)
几何结构学到的"词类别"聚类语义+句法消歧后的"上下文状态空间"
来源训练时学一张表跑一遍 frozen encoder

2.4.4 对 diffusion 来说为什么重要 —— 借坐标系,不画地图只走路

T5 预训练(~1T tokens)已经把文本数据组织成了一个有结构的 512-d 流形:语义相近的句子 / 位置在这个空间里几何上也相近,上下文消歧、词性、句法关系都被编码进了向量的位置和方向。

ELF 把 T5 encoder 冻住当"坐标系"——diffusion 的工作变成了"在这张已经画好的语义地图上做 transport",而不是"先画地图再 transport"。

对比之下:

2.4.5 这就是为什么 105M + 45B tokens 就能换来有竞争力的 Gen-PPL

因为 ELF 实际能用的"语言知识"远不止 45B,而是 45B + T5 那 1T tokens 的预训练迁移。这一点也常被点出:

"45B 'tokens' exclude pretrained T5-small prior; frozen encoder doing real work. What at 2B from scratch?"(一句常见的质疑)

代价:上限被 T5-small 锁死。换 7B encoder 还有效吗?scale 到 LLaMA-70B 行不行?这是 ELF 的命门——也有评论认为 ELF 更像"T5 的生成式插件而非通用架构"(详见 §11 与 ELF 的命门讨论)。

2.5 最后一步离散化 — 共享 Transformer 的 decoder head

整个 flow 从 ε(标准高斯 × noise_scale=2.0)→ clean embedding 都在连续空间。 但 t=1 时模型必须输出离散 token。ELF 的做法:

  1. 主 transformer forward 已经预测 clean embedding x̂([B, L, 512])
  2. 最后另外跑一次 forward,decoder_step_active=True,这次 transformer 输出经过 factored decoder head(backbone 768-d hidden → 512 GELU → 32128)映射到 vocab logits
  3. 取 argmax 得到 token

关键:denoise 主干 + 离散化 decoder 是同一份 transformer 权重,靠 4 个 mode token 切换。 论文 App C.4 显示这种 in-context conditioning 比 adaLN-Zero 略好且省 43M 参数(详见 §5)。

ELF method overview Fig 2
Fig 2 (paper p.2): 离散 DLM vs 连续 DLM 的对比图。离散版每步都做 vocab classification;连续版整条轨迹在 embedding 空间里平滑流动。

3 · Flow-Matching 家族:5 篇同期工作横向对比

2026 年 2-5 月,一共出现了 5 篇基于 Flow Matching 的语言模型工作, 其中 ELF 是 paper 自己 Tab 2 显示的最后一行。这一节把这 5 篇放在一起对比,明确 ELF 在 FM 家族里的独特定位

3.1 五篇 paper 速写

CFM — Categorical Flow Maps (Roos et al., UvA/Oxford, arXiv 2602.12233, Feb 2026)。 核心决策:把 categorical generation 写成"endpoint-prediction flow map", 模型预测 simplex 上的 endpoint 分布 πs,t,再用 self-distillation 做少步生成。 评测 Text8 NLL + LM1B Gen-PPL,单步 LM1B Gen-PPL 274.87

FLM / FMLM — Flow Map Language Models (Lee et al., KAIST + CMU, arXiv 2602.16813, Feb 2026)。 核心决策:在 token-wise one-hot 连续空间建模,simplex-valued denoiser,CE 后验预测; FMLM 是 FLM 蒸馏后的少步版本。LM1B + OWT 评测, FLM 1024-step LM1B Gen-PPL 96.91、OWT 62.23;FMLM 单步 LM1B 119.34、OWT 168.30

DFM — Discrete Flow Maps (Potaptchik et al., Harvard/Oxford/MIT/NYU, arXiv 2604.09784, Apr 2026)。 核心决策:把 flow map 从 average velocity 重参数化为 mean denoiser ψs,t, 让 off-diagonal flow map 自然落在 probability simplex 上。直接批判欧氏 L2 与概率几何的不匹配。 LM1B 评测,DFM-ESD 单步 Gen-PPL 68.11(entropy 3.79);DFM-PSD 单步 94.08(entropy 4.06)

LangFlow — Chen et al., UIUC, arXiv 2604.11748, Apr 2026。 核心决策:回到 learned embedding-space,用 Bregman-divergence 把 token CE 解释为 Flow-Matching posterior matching; 推出 ODE-based NLL upper bound,为 embedding-space DLM 提供了一个可信的 likelihood 估计。 LM1B PPL 30.0 / OWT PPL 24.6(注意:是 held-out NLL upper bound,不是 Gen-PPL); 对应 Gen-PPL 是 LM1B 92.2 / OWT 36.5。

ELF — Embedded Language Flows (Hu/Qiu et al., MIT, arXiv 2605.10938, May 2026, 主角)。 核心决策:在冻结的 T5-small contextual embedding里做 FM,只在 t=1 用共享 transformer 切到 decode mode。 80% MSE + 20% CE,无蒸馏。OWT 32-step SDE Gen-PPL 24.08 ± 0.16

3.2 八维结构对比表 (点击列名 → 高亮该列 + 弹论文卡片)

维度 CFM FLM/FMLM DFM LangFlow ELF
State space Probability simplex ΔK−1(endpoint predictor πs,t one-hot 连续,simplex-valued denoiser Mean denoiser ψs,t: ℝK → ΔK-1 Learned token embedding (V×D matrix) Frozen T5-small contextual embedding (512-d)
Interpolant 几何 Straight-line stochastic interpolant Gaussian/one-hot interpolant + simplex repar Linear interpolant (β-reparameterized) VP γ-path, deterministic ODE Rectified linear: zt = t·x + (1−t)·ε
Training loss CE diagonal + endpoint consistency (CSD/ECLD) FLM: CE posterior; FMLM: KL/CE flow-map distillation CE diagonal + PSD/LSD/ESD KL consistency CE-as-Bregman (FM posterior matching) 80% MSE on velocity + 20% CE on decode
默认步数 1-step 主推(NFE 1-64 sweep) FLM: 1024-step / FMLM: 1-step 1-2-4-8 step 128-step (LM1B) / 1024-step (OWT) 32-step SDE (headline) / 64-step (scaling)
是否蒸馏 ✅ Self-distillation (CSD/ECLD) ✅ FMLM 从 FLM teacher 蒸馏 ✅ Diagonal 1M + off-diagonal 200k/100k ❌ 无 distillation,明说留给未来 无 distillation
Time grid Logit-normal w/ 0.75 diagonal fraction Decoding-error-rate reparameterization argmax-linearized + convex mix β̃(t) Learnable Gumbel scheduler Logit-normal (P_mean=−1.5)
Endpoint 处理 Simplex endpoint π; argmax 或采样 Simplex posterior; argmax Simplex mean denoiser; softmax Argmax over token probability 共享 transformer t=1 decode mode + argmax
Headline eval LM1B 1-step Gen-PPL 274.87 FLM OWT 1024-step 62.23; FMLM 1-step 168.30 LM1B 1-step Gen-PPL 68.11 (ESD) OWT Gen-PPL 36.5 @ 1024-step (PPL upper-bound 24.6) OWT 32-step Gen-PPL 24.08 ± 0.16

3.3 State space 几何 — 什么叫"在 simplex / one-hot 上走 flow"

§3.2 表第一行的 State space 是 5 篇 paper 的核心分歧点。CFM / FLM / DFM 一类说自己"在 simplex 上走 flow", LangFlow 说自己"在 learned embedding 上走",ELF 说自己"在 contextual embedding 上走"——这些到底什么意思?把它讲透。

3.3.1 One-hot 是什么

一个 token 表示成 vocab-size 长的向量,只有一位是 1,其他全是 0:

vocab V = 32128
token "bank" (id=4321) → one_hot 长度 V 的向量
   = [0, 0, ..., 0, 1, 0, ..., 0]
              ↑ 第 4321 位

整个句子(L=64 个 token)→ 形状 [64, 32128],每行是一个 one-hot。

3.3.2 Simplex 是什么

V 维空间里所有"概率分布"组成的集合:长度 V 的非负向量,且各分量之和 = 1。直观理解:

V=3 时的 simplex(直观图):

      (1,0,0)    ← token A 的 one-hot
        /\
       /  \
      / .  \    ← 内部任意点 = 概率混合
     / .. . \
    /________\
(0,1,0)    (0,0,1)
 token B    token C

3.3.3 "走 flow" 在不同 state space 是什么意思

Flow matching 这一族的训练对象是一条从噪声 ε 到数据 x 的路径;最常见、也是 ELF 用的形式是 rectified-linear 插值 zt = t·x + (1−t)·ε(§3.2 表里其余几篇用的是别的插值)。 关键问题是:x 长什么样、ε 长什么样、在哪个空间里走

路线x 长什么样走 flow 的空间哪几篇
One-hot / SimplexV 维 one-hot(或 simplex 内的概率向量)L × V(V ≈ 32k)CFM, FLM/FMLM, DFM
Learned embeddingD 维 static 向量(embedding lookup)L × D(D ≈ 128)LangFlow, Diffusion-LM
Contextual embeddingD 维 context-aware 向量(encoder 输出)L × D(D = 512)ELF

3.3.4 几何含义

想象 V=32128 维空间里数据点 x 长什么样:

路线数据 manifold 形状
ELF(contextual)T5 已组织好的"语义流形",相近词在附近
LangFlow / Diffusion-LM(static)学一张 lookup 表,V 个固定点的离散点云
CFM / FLM / DFM(one-hot/simplex)V 个相互垂直的"角"(汉明距离 = 2,谁离谁都一样远)

最关键的一点:one-hot 空间里,"bank" 和 "dog" 的欧氏距离 = "bank" 和 "shore" 的欧氏距离 = √2。没有任何语义几何——token 之间的相似性必须由 Transformer 自己从头学出来。

3.3.5 4-token 直观例子

假设词表只有 4 个 token:river / bank / money / shore

3.3.6 为什么说 simplex/one-hot 派"连地图都没有,纯靠 Transformer 后处理"

回到 §2.4 那个比喻:

这就解释了为什么 CFM / FLM / DFM 都需要更多 trick:

  1. 大模型容量:Transformer 自己要把"one-hot 空间 → 内部 hidden 语义空间"的映射学出来
  2. 蒸馏 / 重参数化:FLM 蒸馏到 1 步,CFM self-distill,DFM 改 mean denoiser——都是为了弥补"几何信息缺失"
  3. 必须用 CE/KL on simplex:纯 L2 在 one-hot 上不合理(详见 §3.4 ③),所以 loss 必须搬到概率几何上

对照之下 ELF 用 MSE 是合法的——因为 contextual embedding 不是概率分布,是带几何结构的语义向量,欧氏距离反映语义距离。

3.4 四个聚类轴

① State space

  • 纯 simplex 派:CFM + DFM — endpoint / mean denoiser 都约束在 ΔK-1
  • one-hot 派:FLM/FMLM — 状态是 V 维 one-hot,但学习对象重参数到 simplex posterior
  • Learned embedding 派:LangFlow — 学一份 V×D embedding matrix
  • Frozen contextual embedding 派ELF(独占) — 借 T5-small encoder 的语义几何

ELF 是唯一不学 embedding 也不约束在 simplex 上的方案。这是它最独特的设计点。

② Distillation vs base flow

  • 蒸馏派(少步压缩):CFM + FMLM + DFM — 目标都是 1-4 步生成
  • Base flow 派:FLM + LangFlow + ELF — 目标是基础质量,步数靠采样器调

ELF 选 "32-64 步直接采样" 这条路而不蒸馏,是反潮流的—— 当时所有 LM-FM 工作都在卷"少步"。

③ Loss geometry

  • 纯 CE/KL on simplex:CFM, FLM, DFM — "L2 regression 几何不对"
  • Bregman-CE 桥:LangFlow — "embedding 走 L2 / output 走 CE",靠 Bregman 连接
  • Euclidean MSE on embedding + final CEELF — 中间态不是概率分布,是 contextual vector,所以 MSE 合理

ELF 的 MSE 合法性建立在"embedding 不是分布"这个根本前提。这也是它和其它 4 篇的分水岭。

④ 评测范式

  • Gen-PPL(外部 LM 评生成质量):CFM, FLM/FMLM, DFM, ELF
  • Held-out NLL upper bound(likelihood / PPL)LangFlow 独占(同时也报 Gen-PPL)

注意:LangFlow LM1B 30.0 / OWT 24.6 是 ODE NLL upper bound,不是 Gen-PPL。 和 ELF 的 24.08(Gen-PPL)不可直接横比。可比的是 LangFlow OWT Gen-PPL 36.5(1024 步)vs ELF 24.08(32 步)。

3.5 两个 critical reading 问题

Q1: Euclidean vs simplex mismatch 怎么处理?

Q2: 能不能蒸馏到 1 步?

不对称的原因:CFM/FMLM/DFM 的状态在 simplex 或 one-hot 上,flow map 有 clean 的几何对象(mean denoiser)可学;ELF 的状态是 T5 embedding,要做 flow map distillation 需要先证明 T5 embedding 上的 flow map 也有 clean 形式——paper 没做。

3.6 ELF 对每个 sibling 的对比叙述

ELF vs CFM:两者都把离散文本放到连续 FM 里,但CFM 的连续对象是 simplex-valued endpoint,ELF 的是 T5 contextual embedding。 CFM 押注"少步 self-distillation"(单步 LM1B 274.87);ELF 押注"无蒸馏 32 步"(OWT 24.08)——注意两数跨不同数据集(LM1B vs OWT),不可直接横比,这里只对比路线取向。 两条完全互补的路线——蒸馏路 vs 采样路。(点击此标题 → 滚到 §3.2 表 + 高亮 CFM 列)

ELF vs FLM/FMLM:两者都做 continuous flow LM,也都在 OWT 上用 Gen-PPL; 但 FLM 用 V 维 one-hot(vocab 越大状态越大),ELF 用 512 维 T5 embedding + 128 bottleneck。 FLM 的卖点是天然可蒸馏到 FMLM;ELF 的卖点是无蒸馏即少步。 直接对比:FLM 1024 步 OWT Gen-PPL 62.23 vs ELF 32 步 24.08 —— ELF 在这一对照下明显更优,但代价是放弃了 vocab-level 状态空间的"天然性"。

ELF vs DFM:两者都认真做 ablation,但变量完全不同—— DFM 在变 PSD/ESD consistency 把 flow map 压到 few-step;ELF 在变 ODE/SDE/γ/SC-CFG 把 base flow 调到 32 步。 DFM 的强点是 1-4 NFE;ELF 的强点是无蒸馏 32 NFE。再次互补。

ELF vs LangFlow:两者都在 embedding-space 做 FM,也都报 OWT 数字;但 LangFlow 的 24.6 是 held-out NLL upper bound,ELF 的 24.08 是 Gen-PPL不能直接比。 可比项:LangFlow OWT Gen-PPL 36.5 (1024 步) vs ELF 24.08 (32 步) —— ELF 在生成质量上明显优。 反过来,LangFlow 有 likelihood story(可信 NLL bound + 4/7 zero-shot benchmark 超过 AR),ELF 目前没有对应 likelihood claim。 两边互补。

3.7 FM 家族一句话总结

这 5 篇本质在回答同一个问题:语言 FM 应该把"连续性"放在哪里? Simplex(CFM/DFM)、one-hot(FLM)、learned embedding(LangFlow),还是 frozen contextual embedding(ELF)
ELF 押注最后一种,并用无蒸馏 32 步 OWT Gen-PPL 24.08 证明这条路在 sample quality 上最有说服力。

从横轴到纵轴。§3 这条横轴把 ELF 放进同期 FM 家族里、确认了它"借 frozen contextual embedding"这一押注的独特性。但 ELF 借坐标系只是"连续表征从哪来"的一种答案——接下来 §4–§8 先把 ELF 这个锚讲透(训练、架构、采样、结果),然后 §9–§10 沿同一条轴拉出另外两条并列路线:字节 Cola-DLM 的"学一个 VAE latent"(§9),和 UC Riverside DSL 的"在 token 球面上做 stochastic localization、让几何随 SNR 动态涌现"(§10)。三条路线放在一起,才是这篇 blog 想给的完整图景。

4 · 训练 pipeline(ELF 算法核心)

💡 阅读前置:把 ELF 网络当黑盒

本节先把 ELF 网络当成一个抽象函数:

x̂ = netθ(z, t, c, ω, mode)

这一节不依赖具体架构——只要 netθ 是个能接受 (z, t, c, ω, mode) 的可学函数即可。具体 Transformer 实现(T5 encoder、DiT block、RoPE、bottleneck、factored decoder head 等)下一节 §5 展开。

💡 关键澄清:网络直接预测 clean embedding x̂,不是 velocity

ELF 用的是 x-prediction parameterization(App C.1 ablation 选出来的—— x-pred 在高维比 v-pred / ε-pred 稳定,因为"clean text 在低维流形上"):

# sampling_utils.py:115-127
def net_out_to_v_x(net_out, z, t, t_eps=5e-2):
    x = net_out                            # ← 网络输出直接就是 x̂(clean embedding pred)
    denom = torch.clamp(1.0 - t, min=t_eps)
    v = (x - z) / denom                    # ← velocity 是后处理算出来的
    return v, x

但这不代表跳过逐步去噪。推理仍然 32 步 Euler 迭代:

for step i in range(32):
    x̂  = net_θ(z_t, t, c, ω, "denoise")    # 每步预测 clean embedding
    v  = (x̂ - z_t) / (1 - t)               # 由 x̂ 换算瞬时 velocity
    z_t = z_t + dt · v                     # Euler 走一小步
# 最后一步:x̂ = net_θ(z_t≈x, t=1, c, ω, "decode") → unembed → token

三件事要分清:

真值
网络直接输出什么clean embedding x̂(一直在预测最终目标 x0
Flow Matching 框架下的 transport 量velocity v
推理过程仍然 32 步 Euler 迭代;每步用网络的 x̂ 算 v 再走一小步

为什么这么设计

  1. x-pred 的 target 固定(clean embedding x0),不随 t 变;v 和 ε 的 target 都随 t 变化,高维下难学
  2. App C.1 显示 ε-pred 在 768/1024 dim 直接 collapse
  3. 实际 MSE loss 形式上是 ‖v_pred − v_target‖²,但 v_pred / v_target 都从 x_pred / x0 换算,梯度其实在监督 x_pred → x0
App C.1 Fig 10 — Prediction targets
Fig 10 (paper p.21, App C.1): x-pred vs v-pred vs ε-pred 在三种 encoder 维度(T5-small 512 / T5-base 768 / T5-large 1024)下的 Gen-PPL ↔ entropy frontier。x-pred (橙色) 在全 dim 都稳定 + frontier 接近平行;v-pred 在 512 ok,768 退化,1024 严重退化;ε-pred 全 dim 都崩(落到红色 entropy < 5 或 PPL > 300 区)。

和 consistency model / mean flow 的区别:那些工作目标是真的 1 步从 noise 直接跳到 x(学一个"时间无关的 x-prediction")。ELF 没走那条路——仍是 32 步,但每步的局部预测对象是 x 而非 v。FMLM 就是 FLM 的 consistency-distilled 1-步版本;ELF 的 future work 也提了这个方向。

Training pipeline Fig 9
Fig 9 (paper p.16, App B.1): clean embedding x → corrupt → self-condition → add control → ELF net → MSE 或 CE loss。同一个 ELF 网络同时学 denoise 和 decode,靠 model-mode token 切换。
4.1 实际代码执行顺序(≠ 论文叙述顺序)

论文 App B.1 把 Alg 3 / 4 写得像两条独立 pipeline。PyTorch port 实际是这样执行(按 train_step.py):

  1. Label drop mask 应用到 T5 encoder attention mask(XSum/WMT 有,OWT 无)
  2. T5-encode + 归一化input_ids → x₀ ∈ [B, L, 512](x₀−μ)/0.2
  3. 为整个 batch 构造 denoiser corruption:抽 per-sequence t,加噪 z = t·x + (1−t)·ε·2.0
  4. Per-example 抽 decoder/denoiser gate(Bernoulli(0.2))
  5. Decoder 行另构造 per-token corruption:抽 per-token pz̃ = p·x + (1−p)·ε·5.0
  6. Mixz_mixed[row] = decoder_z if decoder_step_active[row]==1 else denoiser_z
  7. 第 1 次 no-grad forward:self-cond 输入 = zeros,用于构造 uncond reference
  8. 主 gradient-tracked forward:self-cond 输入 = stopgrad(uncond x_pred);CE + L2 都来自这次
  9. 第 2 次 no-grad forward:和第 8 步同输入,但用于构造 CFG cond reference
  10. CFG target 组装:v_target = v + (1−1/ω)(v_cond − v_uncond),.detach()
  11. Loss 合并(单一分母)+ grad clip 1.0 + Muon step + EMA

关键认识:CFG target 的第二个 no-grad forward 在主 gradient forward 之后。 数学目标不变(用 v_target 监督 v_pred),但实现顺序不是"两个 no-grad 然后一个 grad"。

4.2 论文 Algorithm 3 (denoiser) + Algorithm 4 (decoder)

💡 MSE 和 CE 在 ELF 里各扮演什么角色?(看 ELF 之前先看这段)

ELF 用两条不同的 loss 学两件不同的事。这是 ELF 整篇 paper 最容易被误解的点:

Loss学什么作用训练频率
LMSE(denoiser) 连续 embedding 空间里 transport 噪声到 clean embedding 学"动力学"——给定 noisy zt 和时间 t,怎么把 zt 沿 flow 推到 x。这是 Flow Matching 的核心。 80%(per-example Bernoulli(0.2) 抽 decoder 模式,剩下都是 denoiser)
LCE(decoder) 离散 token 空间里把 clean embedding 投影回 vocab 学"离散化"——把 flow 终点 (clean embedding ≈ x) 映射到 32128 个 token id。这是 ELF 唯一触及离散 token 的训练步骤。 20%

为什么必须有 CE? 你既然懂 MSE,问题就清楚了: MSE 只能让模型预测 clean embedding,但下游任务要的是生成 token。 如果只有 MSE,推理时拿到 clean embedding 后没办法回到 token(找最近的 T5 embedding 找最近邻?精度差且不可微)。 所以必须额外训练一个 decoder head,把 embedding 映射回 vocab logits —— 这个 head 必须用 CE 训练(因为 vocab 是离散的)。 ELF 的优雅之处是:decoder head 和 denoiser 共享同一份 transformer 权重,只用 4 个 mode token 切换 mode。

所以 ELF 的训练目标本质是:

Ltotal = 𝔼(s, c) [ (1−pdecode) · LMSE(transport) + pdecode · LCE(round-to-token) ]

其中 pdecode = 0.2 是 decoder 分支抽样概率(论文 Tab 4 默认)。

💡 在看 Algorithm 3 之前:什么是 self-conditioning(自条件)?

Self-conditioning(来自 Chen, Zhang, Hinton, "Analog Bits", ICLR 2023,ELF 引用 [9])是 diffusion / flow matching 的一个推理时迭代精修技巧:

普通 diffusion 每一步 forward 只看 ztt = net(zt, t)。 Self-cond 让网络额外接收上一步的预测作为输入:t = net(zt, t, x̂t−1)。 推理时这相当于"我已经有一个 partial 估计,refine 它",比从零预测更稳。

但训练时模型并没有"上一步预测" — 因此训练用以下 trick 模拟:

  1. 50% 概率跑两次 forward:第一次 self-cond 输入填 0,得到 x̂no_sc; 第二次 self-cond 输入 = stopgrad(x̂no_sc),对这次反传梯度。 这样模型学到"给 x̂prev 作 input 时怎么 refine"。
  2. 另外 50% 概率self-cond 输入直接填 0 — 让模型也学到"没有 prior 估计时怎么从 z 直接预测"(推理第 1 步用得到)。

为什么要 stopgrad?因为不希望梯度通过第一次 forward 回流—— 那会让训练目标变成"让 x̂no_sc 也参与优化",破坏 mathematical setup。 ELF 的 self-cond 完全沿用 Chen et al. 这个标准做法。

实现上 self-cond 把网络输入维度从 D 扩到 2D(concatenate [z, x̂prev]), 然后用一个 self_cond_proj: 2D → D 线性层压回 D。 所以你在 Alg 3 看到的 self_cond_proj(concat([z, ...], dim=-1)) 就是这个压回操作;不是 ELF 独创

💡 那 ELF 为什么要用 self-conditioning?为什么不就是 plain Flow Matching?

有 5 个理由,按重要性排序:

  1. 推理时迭代精修 — 同样步数下质量更好
    Plain FM 的 32 步采样 = 32 次独立预测,每步只看 zt。 带 self-cond:每步还看上一步的 x̂i−1,等价于"看着你之前的答案 refine"。 轨迹更平滑,达到同质量需要的总步数少 2-3 倍。这是 self-cond 的原始动机
  2. x-prediction 让 self-cond 特别契合
    ELF 每步直接预测 clean x̂(见上面"关键澄清")。自然的语义是"我这步的 x̂ 比上步更准吗"。 self-cond 显式把 x̂i−1 当 input,让网络有"refine"的语义把手。 v-prediction 没这个直觉——v 是局部量,前一步的 v 跟当前步关系弱。
  3. self-cond 是 SC-CFG (Eq 3) 的 prerequisite ← 这一条最被忽略。
    SC-CFG 的 "C" 就是 self-Conditioning。Eq 3 的 (1−1/ω)·(v_cond − v_uncond) 里:

    • "uncond" = self-cond 输入填 0
    • "cond" = self-cond 输入 = stopgrad(x̂_no_sc)

    没 self-cond 就没 SC-CFG,没 SC-CFG 就没有"推理只跑 1 次 forward"的训练时 CFG 优化(见 §4.4)。 所以整套 Eq 3 训练时 CFG trick 都建立在 self-cond 之上。

  4. Chen et al. 2023 证明 discrete-target diffusion 没 self-cond 不行
    "Analog Bits" 的原始动机:在 discrete data(token / quantized image)上做 diffusion, 没 self-cond 质量差一大截。即使 ELF 在 continuous embedding 空间跑,最终输出仍是离散 token —— ELF 也吃这个红利。
  5. Cost 不大,所以"为啥不用"反而需要理由
    训练 forward 数:1 → 最多 3 次(no-sc + sc + gradient),但训练只跑一次(5 epoch)。 推理 forward 数:完全不变(OWT 无条件 1 forward/step,因 SC-CFG 已 baked-in)。 代价主要是训练时间约 1.5×,换来更稳的轨迹 + 训练时 CFG 优化 + 离散 target 上更好的质量。

所以 Alg 3 里 self-cond 占大位置 ≠ 概念复杂

Alg 3 里 self-cond 看起来占大半篇幅是因为记账复杂,不是概念复杂。剥掉 plumbing 后本质就是 "plain FM + Chen 2023 的 self-cond + Eq 3 的训练时 CFG"。下面这 5 个 step 都是 implementation 细节,不是 5 个独立 ELF 创新:

Alg 3 里的步骤它在干嘛
x_no_sc = net(concat([z, 0]))no-cond reference,self-cond 输入填 0
stopgrad(x_no_sc)不让梯度回流第一次 forward(否则训练目标污染)
x_sc = net(concat([z, stopgrad(x_no_sc)]))self-cond reference forward
(1−1/ω)·(v_sc − v_no_sc)SC-CFG guidance 烤进 v_target(Eq 3)
where(self_cond_mask, ..., ...)50% 概率二选一(让模型也学"没 prior" 的情况,覆盖推理第 1 步)

一句话:self-cond 是 ELF 站在 Chen 2023 + DiT-style 训练时 CFG 这两个肩膀上的"工程接缝"。 拿掉 self-cond 也能跑 plain FM,但你会失去 (a) 32 步达到同质量的能力 (b) Eq 3 训练时 CFG 这两个 ELF 系统级卖点。

论文 App B 写成两个独立算法。PyTorch port 改成 per-example mix,数学上等价(见 §4.6)。两个算法的关键差异:

Denoiser (Alg 3)Decoder (Alg 4)
Mode token gate0(无 mode 信号)1(mode token 激活)
时间 tper-sample t = σ(N(Pm=−1.5, Ps²=0.8²))t = 1(始终终点)
Corruption ratioper-sample tper-token p = σ(N(0.8, 0.8²))(独立!)
Noise scale2.05.0 (OWT) / 1.0 (XSum, WMT)
Self-cond input50% stopgrad(x'); 50% zeros始终 zeros(不学 self-cond)
LossMSE on velocity(base + CFG-augmented target)CE per token
Output headFinalLayer (768→512 flow output)Factored decoder (768→512→32128)

Algorithm 3 伪代码(denoiser,paper App B.1,去 LaTeX):

# Algorithm 3: ELF denoiser training with conditioning and guidance
# net(z, t, c, w, mode): ELF network with in-context conditioning
# self_cond_proj(z): concat-to-original-dim projection
# self_cond_prob: 0.5
# s: discrete token sequence
# c: condition (optional, only for XSum/WMT)

x = encode(s)                               # T5 frozen forward
t = sample_t()                              # logit-normal scalar per sample
w = sample_sc_cfg_scale()                   # ω ∈ [0.5, 5], paper: power-biased
                                            # PyTorch port: shifted log-uniform
e = randn_like(x)                           # standard Gaussian(与 paper 一致)
z = t * x + (1 - t) * e                     # rectified-flow interpolant
v = x - e                                   # base velocity target(e 为 standard Gaussian = 论文 noise_scale=1 抽象写法;PyTorch port 实际 noise_scale=2.0,见 §4.3)

# (1) 不带 self-cond 的 forward (no_grad)
z_no_sc = self_cond_proj(concat([z, zeros_like(z)], dim=-1))
x_no_sc = net(z_no_sc, t, c, w, mode="denoise")
v_no_sc = (x_no_sc - z) / (1 - t)

# (2) 带 self-cond 的 forward (no_grad, stopgrad on x_no_sc)
z_sc = self_cond_proj(concat([z, stopgrad(x_no_sc)], dim=-1))
x_sc = net(z_sc, t, c, w, mode="denoise")
v_sc = (x_sc - z) / (1 - t)

# (3) CFG target: post-combination quantity
v_target = v + (1 - 1/w) * (v_sc - v_no_sc)

# Per-example self-cond mask
self_cond_mask = uniform(B) < self_cond_prob
v_pred   = where(self_cond_mask, v_sc, v_no_sc)
v_target = where(self_cond_mask, v_target, v)
v_target = stopgrad(v_target)

loss_denoise = mse_loss(v_pred, v_target)

逐行拆解:v_target 是怎么构造出来的?(Alg 3 最绕的 5 行)

看 Alg 3 最容易迷的就是 (3) 之后 wherestopgrad 那 5 行。它们做了 4 件事:

# (3) CFG target: post-combination quantity
v_target = v + (1 - 1/w) * (v_sc - v_no_sc)     # 计算"带 CFG"的 target

# Per-example self-cond mask
self_cond_mask = uniform(B) < self_cond_prob     # 每个 example 独立抽 Bernoulli(0.5)
v_pred   = where(self_cond_mask, v_sc, v_no_sc)  # gradient forward 走哪条
v_target = where(self_cond_mask, v_target, v)    # target 选哪个
v_target = stopgrad(v_target)                    # 截梯度

关键 insight:这是同一个 batch 同时训练两种模式。 mask 把 batch 切两半:50% 走 self-cond + CFG 路径,50% 走 plain FM 路径。 这两条路径必须用同一个 mask挑 v_pred 和 v_target,否则配对错乱。

mask 值输入有 self-cond?v_pred (gradient)v_target网络学到什么
True (50%) 有,self-cond 输入 = stopgrad(x̂_no_sc) v_sc(带 self-cond 的 forward) (x − ε) + (1 − 1/ω)·(v_sc − v_no_sc) "给 prior 估计 + ω,输出 CFG-amplified velocity"
False (50%) 没有,self-cond 输入 = 0 v_no_sc(不带 self-cond 的 forward) v = (x − ε)(base velocity) "没 prior,给 plain FM velocity"(推理第 1 步用得到)
三个不容易看出来的细节
  1. 为什么 v_pred 和 v_target 必须用同一个 mask?
    它们是配对的。mask=True 的 example:gradient forward 输入带 self-cond,那 target 也必须是"with-self-cond 想达到的样子"。 如果 mask 不对齐——比如 v_pred 是 v_sc 但 target 是 base v——网络会被告知"用 self-cond 输入预测 no-self-cond 的 target",等价于教网络不用 self-cond 输入也行,破坏整个 self-cond 机制。
  2. 为什么 CFG 只 apply 到 mask=True 的 example?
    SC-CFG 的本质是"放大 self-cond 信号"(cond = 有 prior,uncond = 没 prior,按 ω 外推)。 mask=False 的例子连 self-cond 输入都是 0,不存在 cond/uncond 的对偶,没什么可放大。 所以这种 example 的 target 就是 plain FM 的 base velocity (x − ε),没有 CFG 项。
  3. stopgrad 在防范什么?
    v_no_sc 和 v_sc 都在 no_grad 上下文里算出来的(已经无梯度),v = (x − ε) 来自 leaf tensor x 和 ε 也无梯度。 所以理论上 v_target 本身就没有 gradient flow 进网络。 stopgrad 是防御性的——(a) 文档上明确"v_target 是 supervision,不可微";(b) 防止有人 fork 代码后去掉 no_grad 时还能保住语义。 这是好的工程习惯。
那剩下的 loss_denoise 是什么?

就是 mse_loss(v_pred, v_target) 算 L2 距离:

这里 mse_loss 展开是什么?

就是 velocity 上的 L2 距离 + per-token mean,标准 Flow Matching 损失。具体计算:

# v_pred ∈ [B, L, 512],来自 ELF 主干 gradient-tracked forward 后转 velocity:
#   x_pred = ELF_transformer(z_self_cond, t, c, ω, decode_mode=False)
#   v_pred = (x_pred - z) / clamp(1 - t, t_eps)            # paper Eq 4

# v_target ∈ [B, L, 512],由 Eq 3 训练时 CFG 构造(前面 4.4 节详):
#   v_base    = (x_0 - z) / clamp(1 - t, t_eps) = (x - ε·noise_scale)
#   v_target  = v_base + (1 - 1/ω) · (v_cond - v_uncond)
#   v_target  = v_target.detach()                          # 梯度只走 v_pred

l2_per_token = ((v_pred - v_target) ** 2).mean(dim=-1)     # [B, L]  channel-wise mean
l2_per_token *= loss_mask                                  # 排除 padding + cond positions

数学上 ELF 的 MSE 等价于:

LMSE = 𝔼(x, c) 𝔼t, ε [ Σi ∈ valid ‖ vθ(zt, x'prev, t, c, ω)i − vtarget,i22 / D ]

其中:

符号含义
xfrozen T5-small encode 后的 contextual embedding,归一化后(× 1/0.2 = ×5)
ztnoisy embedding:zt = t·x + (1−t)·ε·noise_scale,其中 noise_scale = 2.0
tper-sequence corruption 时间,t ~ σ(N(−1.5, 0.64)),logit-normal
ε标准高斯 ∈ ℝD,D = 512(T5-small encoder dim)
vθELF 主干预测的速度场(实际预测 x_pred,再转 v = (x_pred − z) / (1−t))
vtarget训练时 CFG 烤进的 target velocity(Eq 3,.detach() 截梯度)
x'prevself-conditioning 输入(50% 概率 = stopgrad(uncond x_pred);50% = zeros)
‖·‖22 / Dchannel 维(512)取均值,等价于 per-channel MSE
valid 位置排除 padding + 条件生成的 cond positions

Algorithm 4 伪代码(decoder,paper App B.1):

# Algorithm 4: ELF decoder training with conditioning and guidance
x = encode(s)
p = sample_per_token_p()                    # logit-normal PER TOKEN (different from denoiser)
w = sample_sc_cfg_scale()
e = randn_like(x)                           # standard Gaussian(与 paper 一致)
z = p * x + (1 - p) * e                     # per-token corruption ratio

# decoder always uses zero self-cond input
z = self_cond_proj(concat([z, zeros_like(z)], dim=-1))
h = net(z, t=1, c, w, mode="decode")        # mode token gate = 1
s_pred = unembed(h)                         # factored decoder head

loss_decode = ce_loss(s_pred, s)

这里 ce_loss 展开是什么?

就是标准的 per-token cross-entropy,没有任何 ELF-specific 改造。具体计算(来自 src/train_step.py):

# decoder_logits ∈ [B, L, 32128],由 ELF 主干 forward + factored decoder head 给出:
#   hidden_768 = ELF_transformer(z̃, t=1, c, ω, decode_mode=True)[B, L, 768]
#   hidden_512 = GELU_tanh(hidden_768 @ proj_kernel + proj_bias)     # 768 → 512
#   decoder_logits = hidden_512 @ unembed_kernel + unembed_bias       # 512 → 32128

log_probs    = F.log_softmax(decoder_logits.float(), dim=-1)          # [B, L, 32128]
ce_per_token = -log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
# ⇔ -log p_θ(s_i | z̃_i, t=1, c, ω, mode=decode)                     # [B, L]

# 然后 mask 掉 padding 位置 + 条件生成的 cond 位置,再聚合
ce_per_token *= loss_mask   # loss_mask = (1 - cond_seq_mask) * attention_mask

数学上 ELF 的 CE 等价于:

LCE = − 𝔼(s, c) 𝔼pi, εi [ Σi ∈ valid log pθ(si | z̃i, t=1, c, ω, mode=decode) ]

关键变量含义:

符号含义
si第 i 个位置的真实 token id(ground truth)
i该位置的 per-token corrupted clean embedding: z̃i = pi·xi + (1−pi)·εi·noise_scale, 其中 pi ~ σ(N(0.8, 0.64)) 每个 token 独立抽,noise_scale = 5.0 (OWT) / 1.0 (cond)
t=1decoder 分支永远在终点;时间 token 编码 t=1
c条件输入(XSum / WMT 才有;OWT 无条件 c=∅)
ωSC-CFG scale ∈ [0.5, 5](虽然 decoder 分支不学 SC-CFG guidance,但 ω token 仍 prepend 作为输入)
mode=decode4 个 mode_token gate=1(denoiser 分支时 gate=0,token 被乘 0)
pθ(·)factored decoder head 输出的 softmax 分布(768 → 512 GELU → 32128 → softmax)
valid 位置排除 padding token + 条件生成的 cond positions(cond 不要求模型预测)

三个容易混淆的点

  1. decoder 分支永远 t=1。不像 denoiser 分支 t 从 logit-normal 抽,decoder 总是在"clean 端点"做 final-step decoding。 但 z̃i 仍然有 corruption——只是 corruption ratio pi 从一个独立的 logit-normal 分布抽(per token,不是 per sequence),让 decoder 见过 noisy embedding 也能还原 token。
  2. decoder 分支不带 SC-CFG guidance。Eq 3 的 self-conditioning CFG 只用于 denoiser 分支的 LMSE。 decoder 分支自己只跑一次 forward,输入 self-cond 部分填 0,直接算 CE。
  3. per-token corruption,不是 per-sequence。这是 decoder 和 denoiser 最大的差别—— denoiser 的 t 是 (B,) scalar per 序列;decoder 的 p 是 (B, L, 1) tensor per 位置。 意义:模拟推理时不同 token 位置 reconstruction 质量不同(有些 token 在 ODE 轨迹上接近 clean,有些差远了)。

4.3 Embedding corruption — 两套独立的 logit-normal schedule

两条分支用不同的时间/腐蚀分布。Paper App B.1 + 代码 src/configs/training_configs/train_owt_ELF-B.yml 默认:

分支分布P_meanP_stdNoise scale说明
denoiserper-sequence logit-normal−1.50.82.0t = σ(N(−1.5, 0.64)) → 偏向小 t(噪声多)
decoder (OWT)per-token logit-normal0.80.85.0p = σ(N(0.8, 0.64)) → 偏向大 p(接近 clean);每个 token 独立
decoder (XSum/WMT)per-token logit-normal0.80.81.0条件生成用更小 noise

核心思路:denoiser 训练时多接触噪声大的样本(学习如何 transport), decoder 训练时多接触接近 clean 的样本(学习如何 round 回 token)。 而且 decoder 是每个 token 独立抽 corruption ratio,模拟推理时不同位置的 reconstruction 质量。

💡 两个细节背后的设计意图:(1) "大 p = clean" 的约定 (2) per-sequence vs per-token 的 granularity 差

(1) 为什么"大 p = clean"?

ELF 用 Lipman et al. 2023 rectified-flow 标准插值公式:

z = p · x + (1 − p) · ε · noise_scale

p 值z 长什么样翻译
p = 0z = ε · 5.0 = 全是噪声完全 noisy
p = 0.5z = 0.5·x + 0.5·ε·5.0一半信号一半噪声
p = 1z = x = 完全 clean embedding完全 clean

所以 p 是"信号比例",1−p 是"噪声比例"。p 越大,z 里 clean x 成分越多。 denoiser 的 t 同理:z_t = t·x + (1−t)·ε·2.0,t=1 是 clean 端点,t=0 是噪声端点。

(2) 为什么 denoiser t 是 per-sequence,decoder p 是 per-token?

关键原则:训练分布必须匹配推理分布。两条分支推理时见到的样子完全不同:

denoiserdecoder
推理时跑在哪 整条 32 步 ODE 轨迹,所有 token 位置同步沿 t 推进 32 步完后跑一次 forward,把 zt=1 映射到 token
推理时各 token 位置的 corruption 状态 同一时刻 t,所有位置共享同一个 corruption level 不同位置 ODE 落地质量不均——有的 95% 干净,有的 70% 干净(attention 把噪声去除得不均衡)
训练 sample 单位 per-sequence t ~ logit-normal(−1.5) per-token pi ~ logit-normal(+0.8)
shape t: (B,) → broadcast 到 (B, 1, 1) p: (B, L, 1) 每个位置独立
动机 所有位置同 t 才能让 attention 学到正确的 transport 动力学 per-token p 训练 decoder 在per-position 残余噪声不均时仍能 round 回 token

Denoiser 不能 per-token:如果训练时同一序列内不同位置取不同 t,attention 看到"有的位置很 noisy 有的很 clean"——但推理时根本不会出现这种状态。模型会学到错误的 transport 模式。

Decoder 必须 per-token:推理时 32 步 ODE 落地的 zt=1 不是完美 clean,不同位置残余噪声幅度不同(attention 在 32 步里对各位置的去噪进度不同步)。 Per-token pi 训练 decoder看上下文判断该位置真正的 token——paper App B.1 原话: "encourages the shared-weight decoder mode to recover corrupted embeddings from their surrounding context, making final-step discretization more robust to imperfect embeddings produced by the denoiser at inference time."

注意 p 不作为网络输入

per-token pi 只通过 z̃i 的"信号/噪声混合比例"隐式传递给网络——网络直接知道某个位置的 pi 是多少。网络的时间输入是固定的 t=1 scalar,decoder 通过 attention 看周围上下文自动判断该位置可信度。

对比一图看清

denoiserdecoder
P_mean−1.5(偏小 t)+0.8(偏大 p)
σ(P_mean) 中位~0.18(多 noisy)~0.69(多 clean)
大部分样本落在[0.05, 0.4] noisy 区[0.5, 0.95] clean 区
Granularityper-sequence(match transport 同步性)per-token(match 终点 per-position 不均)
Noise scale2.0OWT 5.0 / 条件 1.0
# sampling_utils.py::add_noise (denoiser)
def add_noise(x0, noise, t, config, cond_seq_mask=None):
    t_expanded = t.reshape(-1, 1, 1)                       # [B, 1, 1]
    z = t_expanded * x0 + (1 - t_expanded) * noise * config.denoiser_noise_scale
    if cond_seq_mask is not None:                          # 条件生成时保留 cond token 不腐蚀
        z = cond_seq_mask * x0 + (1 - cond_seq_mask) * z
    return z

# train_step.py — decoder branch (per-token p sampling)
decoder_z_vals = (
    torch.randn((B * L,), dtype=dtype, device=device)
    * config.decoder_p_std + config.decoder_p_mean         # = N(0.8, 0.8²)
)
decoder_lambda_t = torch.sigmoid(decoder_z_vals).reshape(B, L, 1)  # [B, L, 1] per-token
decoder_noise = torch.randn(x0.shape, dtype=dtype, device=device) * config.decoder_noise_scale
decoder_z = decoder_lambda_t * x0 + (1 - decoder_lambda_t) * decoder_noise

4.4 训练时 CFG(Eq 3,全文最 tricky 的点)

💡 在看 Eq 3 之前:CFG 是什么?为什么 ELF 要做"训练时"版本而不是标准做法?

Classifier-Free Guidance (CFG)(Ho & Salimans, "Classifier-Free Diffusion Guidance", 2022) 是 image diffusion 的一个条件信号放大器

训练时让一个网络同时学有条件 vθ(z, t, c) 和无条件 vθ(z, t, ∅)(10% 概率把 c 置空)。 推理时按系数 ω > 1 把两者外推:

vfinal = vuncond + ω · (vcond − vuncond)

直觉:朝"和无条件的方向"反方向多走一步,等价于更强地遵从条件 c。ω=1 是原始条件 forward;ω=3 时条件信号被显著放大(细节更锐利、和 prompt 更贴)。 代价:推理每步要跑 2 次 forward(cond 一次 + uncond 一次),算力 ×2。 这是 DALL·E / Stable Diffusion / Imagen 等图像 diffusion 的标配。

ELF 的问题:每步 ×2 forward 太贵

ELF OWT 默认 32 步 × 1 forward = 32 forward。如果套标准 CFG → 64 forward,推理算力翻倍。 更要命的是:ELF 把 sampling step 从 baseline 的 1024 步压缩到 32 步是它的核心卖点,再 ×2 就失去了步数效率优势

解决方案:训练时 CFG(Chen et al. "Visual Generation without Guidance", ICML 2025)

这个 trick 是 Chen 等人 2025 年 image diffusion 的工作(ELF App C 引用)。核心想法: 不在推理时做 CFG 组合,而是在训练时就把 CFG 组合烤进 v_target。 让模型直接学 post-combination quantity v_θcfg,推理只需一次 forward 就拿到等价于 CFG 后的 velocity。

标准推理时 CFG训练时 CFG (Chen 2025 / ELF)
训练 forward 数1(只学 v_cond / v_uncond 之一)3(uncond + cond + grad-tracked)
推理 forward 数 / 步2(cond + uncond)1
32 步总推理 forward6432
训练 1.5× ↔ 推理 2× 谁划算训练 5 epoch 1 次,推理跑无数次 — 训练时 CFG 明显更划算

ELF 把训练时 CFG 应用在哪?— Self-cond CFG,不是 input-cond CFG

注意 ELF 的 c 在 Eq 3 里不是文本 condition,而是self-conditioning 输入(见前面 self-cond 那个 callout):

所以 ELF 实际上把两个独立的 CFG 机制分开处理:

机制训练时还是推理时放大什么OWT 用XSum/WMT 用
SC-CFG(Eq 3)训练时(baked-in)self-conditioning 信号ω=3ω=1(已 baked,不额外推理)
Input-cond CFG(标准)推理时文本 condition 信号—(无条件)ω=2(推理时 ×2 forward)

为什么不把 input-cond CFG 也烤进训练? 因为 input-cond CFG 的 ω 通常需要在推理时 sweep 不同值(找最佳质量),烤进训练就锁死了。 SC-CFG 的 ω 在 ELF 里是固定行为(不用 sweep),所以烤进训练划算。 这是一个 nuanced 的工程权衡。

一句话总结

Eq 3 = Chen 2025 训练时 CFG 这个 image diffusion trick,移植到 ELF 的 self-conditioning 维度, 让 32 步无条件采样的算力优势在加 CFG 后仍然保住。剥掉这个 trick, ELF 要么没 SC-CFG(质量低),要么推理 32 步 ×2 forward(失去步数效率卖点)。

Image diffusion 的 classifier-free guidance 通常是推理时跑两遍 forward 然后线性组合:

推理时 CFG: vfinal = vuncond + ω · (vcond − vuncond)

ELF 把它烤进训练——模型直接学 post-combination quantity,推理只需一次 forward。Paper Eq 3:

vtarget = (x − ε) + (1 − 1/ω) · (vθcfg(zt | t, c, ω) − vθcfg(zt | t, ∅, ω))

📐 Eq 3 的 (1 − 1/ω) 系数怎么推出来?

这个系数不是 ad hoc 写的,是从标准 inference-time CFG 5 步代数推导出来的。Chen et al. ICML 2025 "Visual Generation without Guidance" 的核心贡献就是这个 reparameterization。

Step 1 · 标准 CFG 公式(Ho & Salimans 2022 推理时 CFG):

vfinal(ω) = vu + ω·(vc − vu) = vc + (ω − 1)(vc − vu)

vc, vu 是"不加 CFG"时网络对 cond / uncond 的 logical base 输出。

Step 2 · 训练时 CFG 的网络已经学 post-combination quantity

ELF 网络输出 vθcfg(c, ω) 直接就是 vfinal(推理不再组合)。所以:

两者之差已经被 ω 预放大

vcfgc − vcfgu = [vu + ω(vc−vu)] − vu = ω·(vc − vu)

Step 3 · 反推 logical base 差

vc − vu = (1/ω)·(vcfgc − vcfgu)

Step 4 · 构造 training target

我们要让网络的 vcfgc 收敛到 vfinal。代入 Step 3:

vtarget = vc + (ω − 1)·(vc − vu)
= vc + (ω − 1)·(1/ω)·(vcfgc − vcfgu)
= vc + (1 − 1/ω)·(vcfgc − vcfgu)

所以 (1 − 1/ω) 等于 (ω − 1) × (1/ω) 化简的结果——前一项是标准 CFG"超越 cond 的放大量",后一项是把"已被 ω 预放大的差值"还原回 logical base。

Step 5 · FM target 替换 vc

base FM target 就是 vc = x − ε(rectified-flow 标准 target),代入得 Eq 3:

vtarget = (x − ε) + (1 − 1/ω)·(vcfgc − vcfgu)

边界检查

ω(1 − 1/ω)vtarget含义
10x − ε无 CFG,退化为 plain FM ✓
20.5(x−ε) + 0.5·(vcfgc − vcfgu)中等放大
30.667(x−ε) + 0.667·(...)ELF 默认 SC-CFG=3
50.8(x−ε) + 0.8·(...)SC-CFG 上限
1(x−ε) + (vcfgc − vcfgu)极限放大

ω=1 退化 ✓、ω→∞ 系数趋于 1 ✓、ω=3 对应 ELF OWT 默认 ✓。

这里的 trick 是 self-cond 的"condition"不是输入文本 c,而是 self-cond 输入 x'。 所以"uncond"就是 x'=0,"cond"就是 x'=stopgrad(net1)。CFG scale ω∈[0.5, 5] 是 ELF 自己也学的输入参数(4 个 SC-CFG token 编码)。

实现需要 3 次 forward。代码实际执行顺序是:

  1. v_uncond:no_grad,self-cond 输入 = zeros(最早跑,作 baseline)
  2. v_pred带梯度 forward,self-cond 输入 = stopgrad(uncond 的 x_pred)。L2 loss 用这个 v_pred
  3. v_cond:no_grad,与第 2 次同输入,用于构造 v_target(不是 v_pred)

梯度只过第 2 次。最终 L2 loss = ‖v_pred − stopgrad(v + (1−1/ω)(v_cond − v_uncond))‖²。

# src/train_step.py — Eq 3 的自条件 CFG target 构造(简化版)

def compute_shared_uncond(z, t_input, x_tokens):
    # forward #1: self-cond input = zeros
    z_uncond = restore_cond(torch.zeros_like(z), x_tokens, cond_seq_mask)
    z_input_uncond = torch.cat([z, z_uncond], dim=-1)
    with torch.no_grad(), autocast(bf16):
        net_out_uncond = model(z_input_uncond, t_input,
                               self_cond_cfg_scale=self_cond_cfg_scale)
    return net_out_uncond

def get_sc_cond_and_uncond(z, t_input, cond_mask, x_tokens, shared_net_out_uncond):
    v_uncond, x_uncond = net_out_to_v_x(shared_net_out_uncond, z, t_input, t_eps)
    x_uncond = restore_cond(x_uncond, x_tokens, cond_mask)

    # forward #2: self-cond input = stopgrad(x_uncond)
    z_input_cond = torch.cat([z, x_uncond], dim=-1)         # 注意 stop-grad on x_uncond
    with torch.no_grad(), autocast(bf16):
        net_out_cond = model(z_input_cond, t_input,
                             self_cond_cfg_scale=self_cond_cfg_scale)
    v_cond, _ = net_out_to_v_x(net_out_cond, z, t_input, t_eps)
    return v_cond, v_uncond

def get_sc_guided_v(z, t_input, base_v_target, x_tokens, shared_net_out_uncond):
    v_cond, v_uncond = get_sc_cond_and_uncond(...)
    sc_w = self_cond_cfg_scale.reshape(B, 1, 1)
    sc_guidance = (1 - 1 / sc_w) * (v_cond - v_uncond)      # ← Eq 3 第二项
    sc_guidance = torch.where(use_self_cond_mask.bool(),
                              sc_guidance,
                              torch.zeros_like(sc_guidance))
    return (base_v_target + sc_guidance).detach()           # ← .detach() 是关键

# forward #3 (gradient-tracked) 在主 batch forward 里:
net_out, decoder_logits = model(model_input, t_mixed,
                                self_cond_cfg_scale=self_cond_cfg_scale,
                                decoder_step_active=decoder_step_active)
v_pred, _ = net_out_to_v_x(net_out, denoiser_z, t, t_eps=0.05)   # gradient-tracked
v_final_target = get_sc_guided_v(denoiser_z, t, base_v_target=v_target, ...)
l2_per_token = ((v_pred - v_final_target) ** 2).mean(dim=-1)     # MSE

💡 为什么不直接推理时 CFG?

推理时 CFG 要每步跑两遍 forward(cond + uncond),32 步采样 = 64 次 forward。 ELF 把 SC-CFG 烤进训练后,推理只跑 1 次 forward / step,效率 ×2。代价:训练时 3 次 forward。但训练只跑一次(5 epochs),推理跑无数次。划算。

注意 ELF 还另外保留了输入条件的推理时 CFG(label drop 训练的)。两个机制独立: SC-CFG 用 Eq 3 的 self_cond_cfg_scale(推理时通常 = 1,因为已烤入); input-cond CFG 是标准推理时 cond+uncond 组合(XSum/WMT 默认 = 2)。 label drop 的训练信号只上游改变条件 embedding 状态,不进入 Eq 3 的 SC-CFG 公式。

4.5 Per-example branching — 工程实现 vs 论文算法

Paper Algs 3+4 写成两个独立 training step,按 0.8 / 0.2 概率轮换。PyTorch port 改成:

# src/train_step.py — 关键混合 forward
# 每行独立 Bernoulli(0.2) → decoder mode;否则 → denoiser mode
decoder_step_active = torch.bernoulli(
    torch.full((B,), config.decoder_prob, dtype=torch.float32),
    generator=gen,
).to(device=device, dtype=dtype)             # (B,) 1.0=decode 0.0=denoise
decoder_mask_B11 = decoder_step_active.view(-1, 1, 1)
decoder_mask_B1  = decoder_step_active.view(-1, 1)

# t、z 都按 per-example 混合
denoiser_t = t                                # logit-normal per-sample
decoder_t  = torch.ones_like(t)               # 永远 1
t_mixed = decoder_step_active * decoder_t + (1.0 - decoder_step_active) * t
z_mixed = decoder_mask_B11 * decoder_z + (1.0 - decoder_mask_B11) * denoiser_z

# 单次 forward — mode token gate 也是 per-example
net_out, decoder_logits = model(
    model_input, t_mixed,
    self_cond_cfg_scale=self_cond_cfg_scale,
    decoder_step_active=decoder_step_active,   # (B,) per-row gate
)

# CE / L2 各自用 mask 分流
loss_mask_f = loss_mask.to(ce_per_token.dtype)
ce_mask = loss_mask_f * decoder_mask_B1
l2_mask = loss_mask_f * (1.0 - decoder_mask_B1)

# 关键:单一分母归一化(不是两个分母)
total_sum = (ce_per_token * ce_mask).sum() + (l2_per_token * l2_mask).sum()
loss = total_sum / torch.clamp(loss_mask_f.sum(), min=1.0)
# loss_mask_f = ce_mask + l2_mask,所以这等价于 sum/(ce_mask.sum()+l2_mask.sum())
# 注意:pad_token=="pad" 时 loss_mask 屏蔽 padding;pad_token=="eos" 时全 1

4.6 等价性证明(简版)

论文 Algs 3+4 是两个独立 step,按 (1−p):p 比例轮换(p = decoder_prob = 0.2)。 PyTorch port 是同一个 step 内 per-example 抽 mode。先注意: 每个 example 内所有 token 共享同一个 mode 抽样结果(不是 token-wise i.i.d.)。

记 b ∈ (0, 1) 为 example 的 mode 指示(1 = decode),B0 = denoiser 行数,B1 = decoder 行数。 固定 batch、固定 loss denominator M = loss_mask.sum(),PyTorch port 的 loss:

Lcode = [ Σb=1 行 Σtoken CE + Σb=0 行 Σtoken L2 ] / M

对 mode 抽样取期望(每行 Bernoulli(p)):

𝔼[Lcode] = (p · 𝔼row[Σ CE | decode] + (1−p) · 𝔼row[Σ L2 | denoise]) · (B / M)

= 𝔼[Lpaper] · (per-example token 数加权和) — 等价于 paper Algs 3+4 的固定 batch-size 加权期望

注意

4.7 输入条件 CFG — Label drop 机制(XSum / WMT 才有)

条件生成时还需要另一个 classifier-free guidance,针对 input condition(不是 self-cond)。 10% 概率把 cond sequence 直接 drop(zero out),让模型学到 p(x | ∅) 分布:

# src/train_step.py — label drop for input-condition CFG
if config.label_drop_prob > 0:
    drop = label_drop_mask.to(dtype=torch.float32).reshape(-1, 1, 1)  # (B, 1, 1) 0/1
    cond_mask = cond_seq_mask                                          # (B, S)
    # block_mask: 1 仅在 (non-cond row, cond col)
    # 目的:让 target token 看不到 cond token
    block_mask = (1 - cond_mask).unsqueeze(-1) * cond_mask.unsqueeze(1)
    encoder_attention_mask = encoder_attention_mask * (1 - drop * block_mask)

label drop 实际两阶段

  1. 先改 T5 encoder 的 encoder_attention_mask,让 target token 看不到 cond token (block_mask 只在 non-cond row × cond col 上为 1) — 这样 T5 encode 出来的 x₀ 本身就不含条件信息
  2. Encode 之后再把 dropped 行的 denoiser_zx₀ 在 cond 位置清零torch.where(drop & cond_seq_mask, zeros, denoiser_z))— 匹配 paper "zeroing condition embeddings"
4.8 完整训练超参(论文 Table 4 + PyTorch port 实现细节)
类别参数默认值
Optimizer & Schedule OptimizerMuon(2D 参数走 Newton-Schulz)+ Nesterov-AdamW(其余)
LR (peak)0.002(公式:blr=0.001 × global_batch / 256)
LR scheduleconstant after warmup
Warmup0.5 epoch(5 epochs ≈ 86K steps:45.2B ÷ (512×1024);对应 ~8.6K warmup steps)
Weight decay0(关闭 — Muon 自带 shape scaling)
Grad clip1.0 (norm)
Batching Global batch size512
Sequence length1024
Grad accumulation1 (硬件够大不用)
Diffusion Denoiser P_mean / P_std / noise scale−1.5 / 0.8 / 2.0
Decoder P_mean / P_std / noise scale0.8 / 0.8 / 5.0 (OWT)
Decoder prob (per example)0.2 (denoiser 0.8)
Self-cond prob0.5(denoiser only;decoder 永远 0)
CFG SC-CFG range[0.5, 5],power-bias sample 偏小值
SC-CFG tokens4
Label drop prob (cond only)0.1(XSum/WMT),0(OWT)
Numerics Precisionbf16 autocast;输出头强制 fp32
EMA decay0.9999
Random seed42(per-rank seed + rank offset,让噪声 desync)
训练量 OWT epochsELF-B: 5 ELF-M: 4 ELF-L: 3(大模型收敛快)
OWT 总 tokens≈ 45.2B(OWT 数据集约 9.04B × 5 ep)
HardwareTPU v5p × 64,1.5h/epoch(ELF-B)
4.9 训练 token 用量对比(Table 5)— ELF 的 12× 数据效率 claim 来源
MethodBase trainingDistillation trainingEffective tokensRatio vs ELF
MDLM512 × 1M × 1024524.3B11.6×
Duo512 × 1M × 1024524.3B11.6×
MDLM + SDTT512 × 1M × 1024512 × 10K × 5 × 1024550.5B12.2×
Duo + DCD512 × 1M × 1024512 × 10K × 5 × 1024550.5B12.2×
FLM512 × 1M × 1024524.3B11.6×
FMLM512 × 1M × 1024512 × 100K × 1024576.7B12.8×
LangFlow512 × 1M × 1024524.3B11.6×
ELF (ours)5 × 9.04B45.2B1.0×

Baseline 估算公式:batch_size × n_steps × seq_length。ELF 用 OWT 总 token × epochs。 精确 ratio 是 11.6× / 12.2× / 12.8×,paper 文字简称"约 12×"。 注意:ELF 的"45.2B"不包括 T5-small 预训练(Google 用了 1T+ tokens 训 T5)—— 这是 ELF 数据效率 claim 的主要 caveat。

4.10 Muon 优化器简介

Muon = "Momentum + Newton-Schulz orthogonalization",2024 Keller Jordan 推。核心思想:

ELF 的 PyTorch port 用的是 PyPI muon-optimizer 包,加几层 wrapper / patches: (a) Nesterov-bias-corrected Adam update(替换上游 adam_update); (b) NS5 强制 fp32 + eps 1e-8(替换上游 bf16 + 1e-7); (c) Muon update 重写 + shape scaling layout 修正(区分 nn.Linear 和 bare Parameter); (d) _SafeMuonAuxAdam subclass:zero-fill missing grads + distributed all_gather padding 修复。 全部为了匹配 JAX optax.contrib.muon

论文 App C.5 ablation:Muon vs AdamW,SDE 采样下差距尤其大(paper 只定性写"more pronounced under SDE",没给数字)。

⚠️ 复现注意事项

5 · 架构、参数量、shape(最详细一节)

5.1 三种 ELF 大小(论文 Tab 3)
ModelDepthHiddenHeadshead_dimMLP ratioBottleneckParams (DiT)OWT Epochs
ELF-B127681264128105M5
ELF-M2410561666128342M4
ELF-L3212801680128652M3

三个模型共享同一份 frozen T5-small encoder (~35.3M 参数,encoder-only,不参与训练)。条件生成(XSum / WMT14)时 T5-small encoder 也用来编码 source context — 因此条件 ELF-B 报作 "105M+35M"。 Decoder 没有独立训练 — denoise 和 decode 用同一份 Transformer 权重,靠 model-mode token 切换。

5.2 T5-small encoder 规格(来自 HuggingFace t5-small
Vocab size32128(SentencePiece vocabulary)
Layers6 (encoder only)
Hidden d_model512
Attention heads8 (head_dim 64)
FFN d_ff2048
ActivationReLU (非 gated)
Params~35M
训练时整个 encoder 都 requires_grad_(False),作 frozen embedder

训练 ELF 时永远只 forward 这个 encoder 一次(per batch),bf16 autocast 跑,last_hidden_state shape [B, L, 512]。然后做归一化:(x − latent_mean) / latent_std, 默认 latent_std = 0.2 等于把 raw T5 embedding 放大 5×。

5.3 控制 token 设计(12 个,prepend 到序列前)

不在 vocabulary 里,是可学的 nn.Parameter + 输入相关 embedding 的加和

类别个数编码作用
t_emb_tokens4learnable_tokens[1,4,d] + TimestepEmbedder(t)[B,d].unsqueeze(1)把当前扩散时间 t∈[0,1] 注入
self_cond_cfg_tokens4learnable_tokens[1,4,d] + TimestepEmbedder(ω)[B,d].unsqueeze(1)把 self-cond CFG scale ω∈[0.5, 5] 注入(训练时随机抽,推理时固定)
mode_tokens4learnable_tokens[1,4,d] × active_gatedenoise mode → gate=0(token 被乘 0);decode mode → gate=1(token 激活)

三组合起来共 12 个 prefix tokens。最终序列顺序是 [time(4) + sc_cfg(4)] + [mode(4)] + [main_x(L)] —— 代码先 cat([mode, main]),再 prepend [time + sc_cfg](见 5.12 forward 第 4-5 步)。 重要细节:RoPE 在所有 prefix(12 个) 上 cos=1, sin=0(不旋转), 主序列从位置 0 开始正常 RoPE 编码。这样添加 prefix 不会破坏 main token 的相对位置。

5.4 Bottleneck 设计(App C.2 ablation)

T5 embedding 512-d → bottleneck → DiT hidden。直觉:clean 文本数据其实在 低维流形上。 论文 App C.2 sweep 了 bottleneck ∈ (32, 128, 512):

App C.2 Fig 11 — Bottleneck ablation
Fig 11 (paper p.21, App C.2): bottleneck ∈ {32, 128, 512} 在 ODE 和 SDE 下的 Gen-PPL ↔ entropy 曲线。32-d 在 SDE 下能压到最低 PPL 但落到红色 entropy < 5 的退化区;512-d 维持高 entropy 但 PPL 飙升;128-d 是 frontier 平衡点,所以是 ELF 默认。
class BottleneckTextProj(nn.Module):
    def __init__(self, text_encoder_dim, hidden_size, bottleneck_dim):
        super().__init__()
        self.proj1 = nn.Linear(text_encoder_dim, bottleneck_dim, bias=False)  # 512 → 128
        self.proj2 = nn.Linear(bottleneck_dim, hidden_size,    bias=True)     # 128 → 768

    def forward(self, x):                       # [B, L, 512]
        return self.proj2(self.proj1(x))        # [B, L, 768]
5.5 RoPE — prefix 不旋转的 trick
class TextRotaryEmbeddingFast(nn.Module):
    def __init__(self, dim, pt_seq_len=512, ft_seq_len=None,
                 theta=10000.0, num_empty_token=0):
        super().__init__()
        ft_seq_len = ft_seq_len or pt_seq_len
        # 标准 RoPE 频率
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[:dim//2].float() / dim))
        # 位置缩放支持 fine-tune 时 seq_len 与 pretrain 不同
        pos   = torch.arange(ft_seq_len).float() / ft_seq_len * pt_seq_len
        freqs_main = torch.einsum('..., f -> ... f', pos, freqs)
        freqs_main = repeat(freqs_main, '... n -> ... (n r)', r=2)

        # prefix 的 cos=1, sin=0 → 不旋转
        if num_empty_token > 0:
            cos_prefix = torch.ones((num_empty_token, freqs_main.shape[-1]))
            sin_prefix = torch.zeros_like(cos_prefix)
            freqs_cos = torch.cat([cos_prefix, torch.cos(freqs_main)], dim=0)
            freqs_sin = torch.cat([sin_prefix, torch.sin(freqs_main)], dim=0)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, t):                       # [B, n_heads, L_total, head_dim]
        cos = self.freqs_cos.to(t.dtype)        # 显式 dtype cast 防 bf16 精度漂移
        sin = self.freqs_sin.to(t.dtype)
        return t * cos + rotate_half(t) * sin
5.6 TimestepEmbedder — sinusoidal + 2-layer MLP
class TimestepEmbedder(nn.Module):
    # t (scalar in [0,1]) -> hidden vector (e.g., 768)
    # Init: MLP weights normal(0.02), biases zero
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp_0 = nn.Linear(frequency_embedding_size, hidden_size, bias=True)
        self.mlp_2 = nn.Linear(hidden_size, hidden_size, bias=True)

    @staticmethod
    def timestep_embedding(t, dim=256, max_period=10000):
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(half).float() / half)
        args  = t[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)   # [B, 256]

    def forward(self, t):                       # [B]
        emb = self.mlp_0(self.timestep_embedding(t, 256))   # [B, 768]
        return self.mlp_2(F.silu(emb))                       # [B, 768]
5.7 Attention — qk-norm + RoPE 都加上
class Attention(nn.Module):
    def __init__(self, dim=768, num_heads=12, qkv_bias=True, qk_norm=True,
                 attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim       = dim // num_heads                  # 64 for ELF-B
        self.qkv       = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm    = RMSNorm(head_dim) if qk_norm else nn.Identity()
        self.k_norm    = RMSNorm(head_dim) if qk_norm else nn.Identity()
        self.proj      = nn.Linear(dim, dim, bias=True)

    def forward(self, x, rope_fn, attention_mask=None):
        B, N, C = x.shape                                   # N = 1036 在 ELF-B 训练时
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)                    # [3, B, n_heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = self.q_norm(q)                                   # qk-norm(论文与官方实现都开)
        k = self.k_norm(k)
        if rope_fn is not None:                              # 应用 RoPE(含 prefix-no-rotation)
            q = rope_fn(q)
            k = rope_fn(k)
        # 实际 layers.py 里包了一层 wrapper:int/float mask -> bool mask 再传 SDPA
        x = scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
        return self.proj(x.permute(0, 2, 1, 3).reshape(B, N, C))

注意三件事:(1) qk_norm=True 是 ELF 默认开(提升 bf16 训练稳定性,从 Henry et al. EMNLP'20); (2) 注意 F.scaled_dot_product_attention 内部用 PyTorch 2.x flash kernel; 源码 wrapper 在 mask 是 2D/3D 时 reshape 加 head 维并 cast 为 bool; (3) ELF 训练/采样调用模型时都不传 attention_mask(cond+target 全双向 attention); 条件生成时 source/target 的不对称处理由 ELF 数据管线在拼接时构造,不是 T5 encoder 自带(标准 T5 encoder 本身是双向 self-attention)。

5.8 SwiGLU FFN
class SwiGLUFFN(nn.Module):
    def __init__(self, dim, hidden_dim, drop=0.0):
        super().__init__()
        # SwiGLU 标准做法:把 hidden 缩到 2/3 保持 param count
        hidden_dim_eff = int(hidden_dim * 2 / 3)             # 768*4 = 3072 → 2048
        self.w12 = nn.Linear(dim, 2 * hidden_dim_eff, bias=True)   # 768 → 4096
        self.w3  = nn.Linear(hidden_dim_eff, dim, bias=True)       # 2048 → 768

    def forward(self, x):                                    # [B, N, 768]
        x1, x2 = self.w12(x).chunk(2, dim=-1)                # 各 [B, N, 2048]
        return self.w3(F.silu(x1) * x2)                      # [B, N, 768]
5.9 FinalLayer — zero-init
class FinalLayer(nn.Module):
    # Last layer that maps hidden 768 -> embedding output 512.
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = RMSNorm(hidden_size)
        # 关键: kernel + bias 都用 0 初始化(DiT 标配)
        # → 开局时模型预测的 clean x_pred ≡ 0;
        #   velocity 由后处理 v=(x_pred - z)/(1 - t) 计算
        self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels, bias=True)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):                                    # [B, L, 768]
        return self.linear(self.norm_final(x))               # [B, L, 512]
5.10 ELFBlock — 标准 Pre-Norm Transformer block
class ELFBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn  = Attention(hidden_size, num_heads, qkv_bias=True, qk_norm=True)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        self.mlp   = SwiGLUFFN(hidden_size, int(hidden_size * mlp_ratio))

    def forward(self, x, rope_fn, attention_mask=None):
        x = x + self.attn(self.norm1(x), rope_fn, attention_mask)
        x = x + self.mlp(self.norm2(x))
        return x

注意是 Pre-Norm + 残差(RMSNorm 在 attn/mlp 之前)。这是 LLaMA / DiT / GPT-J 等长 bf16 训练的常见稳定选择。

5.11 ELF.__init__ — 完整参数声明(删减为关键部分)
class ELF(nn.Module):
    # 源码默认 num_model_mode_tokens=0, vocab_size=0
    # train.py 实际从 config / tokenizer 把 4 / 32128 传进来
    def __init__(self, text_encoder_dim=512, max_length=1024,
                 hidden_size=768, depth=12, num_heads=12,
                 mlp_ratio=4.0, bottleneck_dim=128,
                 num_time_tokens=4, num_self_cond_cfg_tokens=4,
                 num_model_mode_tokens=4, vocab_size=32128):
        super().__init__()
        # Self-conditioning projection: [z; x_pred] (2×512) -> 512
        self.self_cond_proj = nn.Linear(2 * text_encoder_dim, text_encoder_dim, bias=True)

        # Bottleneck text projection (512 -> 128 -> 768)
        self.text_proj = BottleneckTextProj(text_encoder_dim, hidden_size, bottleneck_dim)

        # Prefix tokens + their (input-dependent) embedders
        self.t_embedder      = TimestepEmbedder(hidden_size)
        self.t_emb_tokens    = nn.Parameter(torch.empty(1, num_time_tokens, hidden_size))

        self.self_cond_cfg_embedder = TimestepEmbedder(hidden_size)
        self.self_cond_cfg_tokens   = nn.Parameter(
            torch.empty(1, num_self_cond_cfg_tokens, hidden_size))

        self.mode_tokens = nn.Parameter(torch.empty(1, num_model_mode_tokens, hidden_size))

        # RoPE with prefix no-rotation
        head_dim     = hidden_size // num_heads
        prefix_total = num_time_tokens + num_self_cond_cfg_tokens + num_model_mode_tokens
        self.feat_rope = TextRotaryEmbeddingFast(
            dim=head_dim, pt_seq_len=max_length, num_empty_token=prefix_total)

        # 12 ELFBlocks
        self.blocks = nn.ModuleList([
            ELFBlock(hidden_size, num_heads, mlp_ratio) for _ in range(depth)])

        # Flow-matching output head (zero-init)
        self.final_layer = FinalLayer(hidden_size, patch_size=1, out_channels=text_encoder_dim)

        # Factored decoder unembedding: 768 -> 512 (GELU) -> vocab
        self.proj_kernel    = nn.Parameter(torch.empty(hidden_size, text_encoder_dim))
        self.proj_bias      = nn.Parameter(torch.empty(text_encoder_dim))
        self.unembed_kernel = nn.Parameter(torch.empty(text_encoder_dim, vocab_size))
        self.unembed_bias   = nn.Parameter(torch.empty(vocab_size))
5.12 ELF.forward — 完整 forward(删减+加注释)
def forward(self, x, t, self_cond_cfg_scale=None, decoder_step_active=None):
    # x: [B, L, 512]  if no self-cond, OR  [B, L, 1024]  if self-cond cat([z, x_pred])
    # t: [B]                                            扩散时间 ∈ [0,1]
    # self_cond_cfg_scale: [B] or None                  ω 用 SC-CFG token 编码
    # decoder_step_active: None | True/False | Tensor[B]   控制 model-mode token gate
    B = x.shape[0]

    # ====== 1. self-cond projection: 2C -> C ======
    if x.shape[-1] == 2 * self.text_encoder_dim:       # 训练 & 推理都走这里
        x = self.self_cond_proj(x.float())             # [B, L, 1024] -> [B, L, 512]

    # ====== 2. bottleneck text projection ======
    x = self.text_proj(x.float())                      # [B, L, 512] -> [B, L, 768]

    # ====== 3. build prefix context tokens ======
    time_emb = self.t_embedder(t)                      # [B, 768]
    prefix   = self.t_emb_tokens.expand(B, -1, -1) + time_emb.unsqueeze(1)  # [B, 4, 768]
    if self_cond_cfg_scale is not None:
        sc_emb  = self.self_cond_cfg_embedder(self_cond_cfg_scale)          # [B, 768]
        prefix2 = self.self_cond_cfg_tokens.expand(B, -1, -1) + sc_emb.unsqueeze(1)
        prefix  = torch.cat([prefix, prefix2], dim=1)                       # [B, 8, 768]

    # ====== 4. model-mode tokens with per-example gating ======
    if decoder_step_active is None:
        gate = 0.0                                     # 默认 denoise: tokens 被乘 0
    elif isinstance(decoder_step_active, torch.Tensor):
        gate = decoder_step_active.view(-1, 1, 1)      # [B, 1, 1]  per-example
    else:
        gate = float(decoder_step_active)              # 1.0 in final decode
    mode_tokens = self.mode_tokens.expand(B, -1, -1) * gate                 # [B, 4, 768]
    x = torch.cat([mode_tokens, x], dim=1)                                  # [B, L+4, 768]
    model_mode_offset = self.num_model_mode_tokens     # = 4

    # ====== 5. prepend (time + sc-cfg) tokens ======
    # 最终顺序: [time(4), sc-cfg(4), mode(4), main(L)]
    x = torch.cat([prefix, x], dim=1)                                       # [B, L+12, 768]
    prefix_len = prefix.shape[1]                       # = 8 (time + sc-cfg)

    # ====== 6. 12 ELFBlocks with RoPE ======
    for block in self.blocks:
        x = block(x, rope_fn=self.feat_rope, attention_mask=None)           # full bidi

    # ====== 7. strip prefix (动态 = prefix_len + model_mode_offset) ======
    x = x[:, prefix_len + model_mode_offset:]                                # [B, L, 768]

    # ====== 8. flow-matching head (always computed) ======
    flow_output = self.final_layer(x.float())                               # [B, L, 512]

    # ====== 9. decoder head (only if decoder_step_active) ======
    decoder_logits = None
    if decoder_step_active is not None:
        x_f32  = x.float()
        hidden = F.gelu(x_f32 @ self.proj_kernel + self.proj_bias, approximate="tanh")
        decoder_logits = hidden @ self.unembed_kernel + self.unembed_bias   # [B, L, 32128]

    return flow_output, decoder_logits
5.13 完整 forward shape 走查(ELF-B, B=512, L=1024)
步骤张量shape说明
0input_ids[512, 1024]tokenized text,int64
1T5 encoder output[512, 1024, 512]frozen contextual embedding
2normalized x₀[512, 1024, 512]除以 latent_std=0.2(相当于 ×5)
3noisy z_t[512, 1024, 512]z = t·x₀ + (1−t)·ε·2.0;t = sigmoid(N(−1.5, 0.8²)) logit-normal
4self-cond input[512, 1024, 1024]cat([z_t, x_pred]),channel-wise concat,2×512
5self_cond_proj output[512, 1024, 512]线性投影回 512
6bottleneck proj1[512, 1024, 128]512 → 128(无 bias,强约束)
7bottleneck proj2[512, 1024, 768]128 → 768
8time emb[512, 768]sinusoidal(256) → MLP(768)
9time prefix tokens[512, 4, 768]learnable 加上 time_emb 广播
10sc-cfg prefix tokens[512, 4, 768]learnable 加上 ω embedding
11mode tokens (gated)[512, 4, 768]per-example gate 0 (denoise) 或 1 (decode)
12concat: mode + main[512, 1028, 768]中间态:mode 暂时在最前
13concat: prefix + above[512, 1036, 768]最终顺序 [time(4) + sc-cfg(4) + mode(4) + main(1024)]
14RoPE 应用于 q,k[512, 12, 1036, 64]n_heads=12, head_dim=64;prefix 位置 cos=1, sin=0
15each ELFBlock output[512, 1036, 768]共 12 个 block,shape 不变
16strip prefix[512, 1024, 768]取后 1024 个 token
17FinalLayer (flow head)[512, 1024, 512]RMSNorm + Linear 768→512,预测 clean x₀
18a(decode only) proj[512, 1024, 512]x_f32 @ proj_kernel + proj_bias,再 GELU(tanh)
18b(decode only) logits[512, 1024, 32128]hidden @ unembed_kernel + unembed_bias
5.14 ELF-B 参数量分解(实测 ~105M)
组件计算参数
self_cond_proj1024×512 + 512524,800
BottleneckTextProj.proj1512×128 (no bias)65,536
BottleneckTextProj.proj2128×768 + 76899,072
TimestepEmbedder (×2: time, sc-cfg)(256×768+768) + (768×768+768) ≈ 788K×2 = 1,575,936
t_emb_tokens + sc_cfg_tokens + mode_tokens3 × (4×768)9,216
每个 ELFBlock~7.09M(见下)
  RMSNorm × 22 × 7681,536
  Attention.qkv768 × 2304 + 23041,771,776
  Attention.q_norm + k_norm2 × 64128
  Attention.proj768 × 768 + 768590,592
  SwiGLU.w12768 × 4096 + 40963,149,824
  SwiGLU.w32048 × 768 + 7681,573,632
  子总(一个 block)7,087,488
12 个 ELFBlock12 × 7.09M85,049,856
FinalLayer (norm + linear)768 + 768×512 + 512394,496
proj_kernel + proj_bias (decoder)768×512 + 512393,728
unembed_kernel + unembed_bias512×32128 + 3212816,481,664
合计(trainable)各行精确相加104,594,304 (~105M, 假设 vocab=32128)
+ T5-small encoder (frozen)~35M

关键观察:decoder unembedding 占 16.5M(~16%),这是 32128 词表的主要开销。 而 12 个 transformer block 占 85M(~81%)—— 真正的核心。 self_cond_proj 只占 0.5%,但它是 self-conditioning trick 能 work 的关键 plumbing。

✅ 总结架构哲学

6 · 推理 / 采样 — 完整代码 + 时间网格 + Tab 6/7 全数字

6.1 推理主流程

给定 (B, L) 和采样配置(method, n_steps, cfg, sc_cfg, γ):

# src/utils/generation_utils.py — _generate_samples_single_batch (简化)

@torch.no_grad()
def _generate_samples_single_batch(model, generator, z, t_steps,
                                   cond_seq, cond_seq_mask,
                                   config, sampling_config,
                                   cfg_scale, self_cond_cfg_scale):
    method = sampling_config.sampling_method                # 'ode' or 'sde'
    B, L, d = z.shape                                       # d = 512

    if cond_seq is None:                                    # OWT 无条件生成
        cond_seq = torch.zeros((B, L, d), dtype=z.dtype, device=z.device)
        cond_seq_mask = torch.zeros((B, L), dtype=z.dtype, device=z.device)

    # 在条件位置上把 z 设回 clean cond_seq(cond 不去噪)
    z      = restore_cond(z, cond_seq, cond_seq_mask)
    x_pred = restore_cond(torch.zeros_like(z), cond_seq, cond_seq_mask)

    n = t_steps.shape[0]                                    # = n_steps + 1
    sde_gamma = getattr(sampling_config, "sde_gamma", 0.0)

    use_bf16 = config.use_bf16 and z.is_cuda
    with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=use_bf16):
        # n_steps - 2 个中间步用 ODE/SDE
        for i in range(n - 2):
            t       = t_steps[i].item()
            t_next  = t_steps[i + 1].item()
            if method == "sde":
                z, x_pred = _sde_step(z, t, t_next, x_pred, gamma=sde_gamma, ...)
            else:                                           # 'ode'
                z, x_pred = _ode_step(z, t, t_next, x_pred, ...)

        # 最后一步强制 ODE — t 接近 1 不再注入新噪声
        z, x_pred = _ode_step(z, t_steps[-2], t_steps[-1], x_pred, ...)
    return z                                                 # [B, L, 512]


@torch.no_grad()
def _dlm_decode_batch(z, model, t_final_val, config, self_cond_cfg_scale):
    # 最终一步把 latent z (≈ clean embedding) 映射回 token IDs
    B = z.shape[0]
    t_final = torch.full((B,), float(t_final_val), dtype=z.dtype, device=z.device)
    sc_batch = (torch.full((B,), float(self_cond_cfg_scale), dtype=z.dtype, device=z.device)
                if config.num_self_cond_cfg_tokens > 0 else None)

    # 推理时 self-cond 输入永远 = zeros(与训练 decoder 分支一致)
    z_input = torch.cat([z, torch.zeros_like(z)], dim=-1) if config.self_cond_prob > 0 else z

    with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=config.use_bf16):
        _, decoder_logits = model(
            z_input, t_final,
            self_cond_cfg_scale=sc_batch,
            decoder_step_active=True,                       # ← mode token gate = 1
        )
    return decoder_logits.argmax(dim=-1)                    # [B, L]

关键点:主循环 + 最终 decode 是两次独立 forward。前者把噪声 ε transport 到接近 clean x; 后者把 clean embedding 投影到 vocab。

6.2 ODE step — 标准 Euler(caller 在 generation.py 里先 init z = randn × denoiser_noise_scale)

def _ode_step(model, z, t, t_next, x_pred_prev,
              config, cfg_scale, self_cond_cfg_scale,
              cond_seq, cond_seq_mask):
    t_batch = torch.full((z.shape[0],), float(t), dtype=z.dtype, device=z.device)
    v_pred, x_pred = _forward_sample(
        model=model, z=z, t_batch=t_batch, x_pred_prev=x_pred_prev,
        config=config, cfg_scale=cfg_scale, self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
    )
    return z + (t_next - t) * v_pred, x_pred                # Euler: z_{i+1} = z_i + dt · v

6.3 SDE step — PyTorch branch(与论文伪码 Alg 6 update 基点略有差异:从 z_back Euler,不是从原 z)

def _sde_step(model, z, t, t_next, x_pred_prev,
              config, cfg_scale, self_cond_cfg_scale,
              cond_seq, cond_seq_mask, gamma, generator):
    h      = float(t_next - t)
    alpha  = max(0.0, min(1.0, 1.0 - gamma * h))            # 信号保留比例 ∈ [0,1]
    t_back = alpha * float(t)                               # 时间往回拉到 α·t

    eps = torch.randn(z.shape, dtype=z.dtype, device=z.device) * config.denoiser_noise_scale
    z_back = restore_cond(alpha * z + (1.0 - alpha) * eps, cond_seq, cond_seq_mask)

    t_batch = torch.full((z.shape[0],), t_back, dtype=z.dtype, device=z.device)
    v_pred, x_pred = _forward_sample(
        model=model, z=z_back, t_batch=t_batch, x_pred_prev=x_pred_prev,
        config=config, cfg_scale=cfg_scale, self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
    )
    # 从 backtracked state 用 Euler 一步推到 t_next
    return z_back + (t_next - t_back) * v_pred, x_pred

三个直觉:

  1. γ = 0: alpha = 1, t_back = t, z_back = z — 退化为 ODE Euler
  2. γ > 0: alpha < 1,z 被 α 缩小并注入新噪声 (1-α)·ε — 相当于把"已经去噪一些"的状态拉回到更早时刻,再 denoise 一次
  3. 因此 SDE 是用额外随机性"修正"早期 denoise 错误。Tab 7 同 CFG 下 SDE 比 ODE PPL 低 12-35%(详见 6.8)

6.4 时间网格 — 训练同分布的 logit-normal

# 函数签名默认参数是 P_mean=-0.8(旧默认)
# 但 caller 在 generation.py 中传入 config.denoiser_p_mean = -1.5(OWT 训练值)
def get_sampling_steps(n_steps, time_schedule="logit_normal",
                       P_mean=-0.8, P_std=0.8, device=None, dtype=torch.float32):
    if time_schedule == "uniform":
        return torch.linspace(0.0, 1.0, n_steps + 1, dtype=dtype, device=device)

    # logit-normal:从训练同分布抽 n_steps - 1 个点
    z = torch.randn((n_steps - 1,), dtype=dtype, device=device) * P_std + P_mean
    steps = torch.sigmoid(z)
    steps = torch.sort(steps).values                        # 升序排列
    # 强制端点 0 / 1
    lo = torch.zeros((1,), dtype=dtype, device=steps.device)
    hi = torch.ones((1,),  dtype=dtype, device=steps.device)
    return torch.cat([lo, steps, hi], dim=0)                # [n_steps + 1]

论文 App B.2:"smaller intervals when t is close to 0 and larger intervals as t approaches 1"。 P_mean=−1.5 使 sigmoid(N(−1.5, 0.64)) 的密度向 t 较小的 noisy 区集中, 所以中间 sample 点偏小 → 排序后noisy 区间隔细密、clean 区间隔粗,与训练 t 同分布匹配。

6.5 _forward_sample — 套两层 CFG

真实的每步 forward要处理两层 CFG:

# sampling_utils.py — 双层 CFG 嵌套

def _forward_sample(model, z, t_batch, x_pred_prev, config,
                    cfg_scale, self_cond_cfg_scale, cond_seq, cond_seq_mask):
    # 内层:带 cond 的 forward(条件 token + SC-CFG)
    v_cond, x_cond = _forward_sample_self_cond(
        model, z, t_batch, x_pred_prev, config,
        self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
    )
    if cfg_scale == 1.0:
        return v_cond, x_cond                       # 无 input-cond CFG,直接返回

    # 外层:input-cond CFG → 再跑一次 uncond
    z_uncond = restore_cond(z, torch.zeros_like(z), cond_seq_mask)
    x_pred_prev_uncond = (None if x_pred_prev is None else
                          restore_cond(x_pred_prev, torch.zeros_like(x_pred_prev), cond_seq_mask))
    v_uncond, x_uncond = _forward_sample_self_cond(
        model, z_uncond, t_batch, x_pred_prev_uncond, config,
        self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=torch.zeros_like(cond_seq), cond_seq_mask=cond_seq_mask,
    )
    # 标准 CFG 线性组合
    v_out = v_uncond + cfg_scale * (v_cond - v_uncond)
    x_out = x_uncond + cfg_scale * (x_cond - x_uncond)
    return restore_vx(v_out, x_out, cond_seq, cond_seq_mask)

每步代价(worst case,cond 生成):2 次模型 forward(cond + uncond)。 OWT 无条件生成 cfg_scale=1,每步只1 次 forward。SC-CFG 被烤进训练所以不用 ×2。

6.6 单步算力 vs Baseline

方法每步 forward 数32 步总 forward说明
传统 inference-time CFG 方法(启用时)2 (cond + uncond)常用于 MDLM/LLaDA/Duo 等启用 CFG 的配置
ELF OWT 无条件 (SC-CFG baked-in, cfg=1)132 + 1 decodeSC-CFG 单次 forward 已含 ω 信息
ELF 条件 (input-cond CFG=2)2 (cond + uncond)64 + 1 decodeSC-CFG=1 不额外 ×2,input-cond CFG 才 ×2

6.7 论文 Tab 6 — System-level @ OWT (6 seeds)

论文 page 26 Table 6(6 个 evaluation seed 平均 ± SE):

StepsSC-CFGγGen-PPL ↓Entropy ↑
832.067.32 ± 2.255.14 ± 0.085
1632.033.66 ± 1.095.16 ± 0.026
3231.524.08 ± 0.165.15 ± 0.002

三个观察:

  1. 32 步 SE 极小 (0.16)——结果非常稳定,不是单 seed 运气
  2. γ 从 2.0 → 1.5 — 步数多时不需要那么多噪声注入
  3. 从 8 到 32 步 PPL 砍 ⅔(67.32 → 24.08),但 entropy 几乎不变(~5.15)

6.8 论文 Tab 7 — Scaling × CFG × Sampler(64 步)

论文 page 26 Table 7(下表 PPL 均为 Gen-PPL)。三个 size × 两种采样器 × CFG sweep。SDE 全方位优于 ODE。 表内标 ⁱ 的灰格是 CFG>3 后 PPL 反转上升的 poor-generation cell(详见表后脚注;论文用 entropy<5.0 或 PPL>300 划红区):

SamplerSC-CFGELF-B 105M (PPL/Ent)ELF-M 342M (PPL/Ent)ELF-L 652M (PPL/Ent)
SDE (γ=1.0) 0.536.77 / 5.2839.21 / 5.3537.50 / 5.41
1.029.50 / 5.2333.45 / 5.3031.82 / 5.37
1.525.25 / 5.1828.42 / 5.2628.72 / 5.35
2.022.53 / 5.1425.34 / 5.2326.47 / 5.32
3.019.72 / 5.1021.69 / 5.1823.31 / 5.28
3.537.56 / 5.30 ⁱ36.48 / 5.34 ⁱ22.28 / 5.27
4.036.50 / 5.29 ⁱ34.93 / 5.33 ⁱ21.37 / 5.26
ODE 0.5104.29 / 5.5188.51 / 5.5168.27 / 5.52
1.065.30 / 5.4062.47 / 5.4449.72 / 5.45
1.544.85 / 5.3146.71 / 5.3739.97 / 5.40
2.034.65 / 5.2337.66 / 5.3233.72 / 5.36
3.026.62 / 5.1528.80 / 5.2426.57 / 5.29

ⁱ = 论文表中标灰的 cell:CFG > 3 后 ELF-B / ELF-M 出现 PPL 反转/上升, 不是 entropy < 5.0;它们 entropy 仍 ≥5.29。论文 App C 同时用 entropy < 5 或 PPL > 300 作"poor generation"红区(entropy 指生成 token 分布的熵,越低越接近复读/退化,故太低也不算 valid)。

关键观察(数字均来自 Tab 7 valid 区间内):

6.9 条件生成默认 (XSum / WMT)

src/configs/sampling_configs/cond_sampling_configs.yml

- sampling_method: ode
  num_sampling_steps: [64]
  cfgs: [2]                  # input-cond CFG = 2
  self_cond_cfg_scales: [1]  # SC-CFG = 1(已烤进训练)
  time_schedule: logit_normal

条件生成默认使用 ODE(paper/config 默认;论文未明确解释,可理解为条件任务在标准 CFG=2 下已经稳定,不需要 SDE 额外随机性)。 SC-CFG=1 因为推理时不需要额外 ω 调整;input-cond CFG=2 是标准 image diffusion 推理时 CFG

6.10 PyTorch port 与 JAX 原版的数值漂移承诺

README 承诺:PyTorch port 在 8× L40S / H200 上跑应与 paper TPU v5p-64 数字漂移 ≲ 1 Gen-PPL< 0.5 BLEU/ROUGE。漂移来源:

README 强调用 use_bf16=true(匹配训练 precision)和 use_compile=true(torch.compile ~3-4× speedup)作为推荐 eval flag。

💡 SDE γ 的微妙

论文 Alg 6 写 α = 1 − γ·dt但 dt 不是 1/N(uniform 间隔),是 logit-normal 抽出来的—— 不同步的 dt 跨度差很大(t 靠近 0 时 dt 可能 0.01,靠近 1 时可能 0.4)。 所以同一个 γ 在不同 step 里实际"重置强度"完全不同。代码用 clip(α, 0, 1),仅在 γ·dt ≥ 1 时才把 α clip 到 0: 比如 32-step 默认 γ=1.5,最大 dt 通常 ~0.4,γ·dt = 0.6 不会 clip; 但 8-step γ=2.0 时如果 dt ≥ 0.5 就会 clip。 这是 ELF 实现的关键细节,建议念 paper 时跟代码对照(src/utils/sampling_utils.py:226-251)。

7 · 主结果 — 全部数字

7.1 Scaling: 三个 size 都改善 Gen-PPL / Entropy frontier

Scaling Fig 6
Fig 6 (paper p.8): ELF-B/M/L 在 Gen-PPL ↔ Entropy 平面上的整条 frontier 都改善。同熵下大模型 PPL 更低;同 PPL 下大模型熵更高(多样性更高)。SDE 在所有尺度都比 ODE 更优。

Tab 7 各 size 内自身最低 valid PPL(valid = 落在 entropy 合理区): ELF-B SDE CFG=3 → 19.72;ELF-M SDE CFG=3 → 21.69;ELF-L SDE CFG=4 → 21.37。 绝对最低是 ELF-B 的 19.72;ELF-L 21.37 比 ELF-M 21.69 更低。 但 paper 强调 frontier 而非单点最低:scaling 整体向更优方向推进(同 entropy 下更低 PPL,同 PPL 下更高 entropy)。

7.2 系统级对比(Fig 7)— 三个独立维度均占优

System-level Fig 7
Fig 7 (paper p.8): (a) 步数效率 — ELF-B 32 步 ≈ baseline 1024 步;(b) 即使对比蒸馏过的 MDLM+SDTT / Duo+DCD / FMLM,ELF(无蒸馏)依然 win;(c) 训练 token 用量 — ELF 45B (1×) vs baselines 524–578B (12×)。

Fig 7(a) — Gen-PPL vs Sampling Steps

ELF-B 32 步达到 Gen-PPL ≈ 24。从 Fig 7(a) 视觉读数:ELF-B 32-step 已接近或优于若干 baseline 在高 step(如 1024)下的水平, 推理时间 substantially less than prior methods(paper §4.2 原文)。

Fig 7(b) — Distillation 后的 baseline 仍不及未蒸馏的 ELF

三种蒸馏过的 few-step variant:

这些都需要额外蒸馏阶段(10K-100K extra steps),但 32 步 PPL 还是不如未蒸馏的 ELF-B。 即在这套系统配置下,ELF 不加额外 distillation 仍然超过这些 distilled baselines("架构层面的优势"是我对这个现象的解读,非论文逐字 claim)。

Fig 7(c) — Training token 预算(柱状)

详见 4.9 节 Tab 5:ELF 45.2B 总 token,所有 baseline 都在 524-577B 区间, 其中蒸馏 variant 因为还要加蒸馏 epoch,token 更多。精确 ratio 11.6× / 12.2× / 12.8×(paper 简称 ~12×)。 Tab 5 只统计 ELF/baseline 自身训练与蒸馏 token,不包含外部 encoder 预训练成本

7.3 条件生成(Tab 1)— 2 任务 / 4 指标全部表内最佳

Tab 1 conditional results
Table 1 (paper p.9): ELF-B (105M+35M) 在 WMT14 De-En 拿 BLEU 26.4,XSum 拿 R-1 36.0 / R-2 12.2 / R-L 27.8,超过 MDLM/Duo/E2D2/SeqDiffuSeq/CDCD/AR baseline。
ModelSizeDe-En BLEU ↑XSum R-1 ↑R-2 ↑R-L ↑
AR (Transformer)99M25.230.5 ± 0.1310.2 ± 0.1124.4 ± 0.12
MDLM99M18.433.4 ± 0.1111.6 ± 0.1025.8 ± 0.10
Duo170M+35M21.3 ‡31.4 ± 0.1210.1 ± 0.1025.0 ± 0.12
E2D299M24.828.4 ± 0.118.3 ± 0.0922.0 ± 0.10
SeqDiffuSeq21.319.3 †1.7 †14.1 †
CDCD24.9
ELF-B (ours)105M+35M26.436.0 ± 0.1312.2 ± 0.1127.8 ± 0.12

† = 直接取自该方法原 paper(De-En 数据的默认来源); ‡ = ELF 团队用公开 codebase 重跑(XSum 数据的默认来源);Duo De-En 在 ELF 团队的对比里也是 ‡(重跑)。

重要观察:

7.4 关键 ablation 汇总(App C 全部 7 个 ablation)

论文 App C (pages 20-23) 系统 ablate 7 个设计选择。这是"为什么 ELF 能 work"的实证支撑:

AblationSweep结论 / 默认差距
C.1 Prediction targetx-pred / v-pred / ε-predx-pred 全 dim 稳定ε-pred 全 dim 都 collapse(512/768/1024);v-pred 在 512 dim ok,越高越差
C.2 Bottleneck dim32 / 128 / 512128 最佳 frontier32 偏低 entropy;512 偏高 PPL;128 balance
C.3 Denoiser mode prob0.2 / 0.5 / 0.80.8 (denoise) / 0.2 (decode)0.5 / 0.2 (denoise) PPL/entropy frontier 都明显劣化
C.4 Conditioning stylein-context tokens / adaLN-Zeroin-context 略优 + 省 43M 参数性能 ≈ adaLN,但 ELF-B 148M → 105M
C.5 OptimizerMuon / AdamWMuon 全面优于 AdamWSDE 下差距最显著(paper 定性 "more pronounced under SDE")
C.6 Sampler + time gridODE / SDE × uniform / logit-normalSDE + logit-normalSDE 通常降低 PPL(幅度随 model/CFG 变化,CFG=2 时 21-35%);logit-normal 在各 step 都更优,few-step 时尤其
C.7 Cond CFG scale1 / 2 / 3 / 4CFG=2 最佳1→2 substantially improves;3、4 逐步下降。注:这是条件生成的 input-cond CFG,与 §6.8/Tab 7 主结果里的 SC-CFG(self-cond CFG,最优在 3/4)是不同量

C.1 — 为什么 x-pred 才行?

三种 prediction 是数学等价的,但训练 signal 完全不同。论文用三个 encoder size (T5-small/base/large = 512/768/1024 dim) sweep:

App C.1 Fig 10 — Prediction targets
Fig 10 (paper p.21, App C.1): 三种 prediction target 在不同 encoder dim 下的 Gen-PPL ↔ entropy frontier。详细解读见 §4 黑盒 callout 那张同图。

解释:clean 文本数据在 embedding 空间是低维流形。x-pred 预测的就是这个流形上的点; ε-pred 预测的是高维 Gaussian,模型必须学一个全维度等熵分布——更难。 这条 finding 支持"continuous DLM 的关键不是连续,是 x-prediction" 的 framing。

C.3 — 80/20 denoise/decode 比例

很反直觉:如果 decoder 占比上升到 0.5,按理说 decoder 应该学得更好——但实际整体 frontier 都退化。 解释:decoder 共享 transformer 主干。如果训练时频繁切换 mode,主干被两个目标拉扯;只占 20% 时 decoder 学到的是 "在已经 transport 到 clean 的 embedding 上做最后一步映射"——比例小但效果反而好。

C.6 — Time schedule 在 few-step 时尤其关键

Logit-normal time grid 让 noisy 区间更密——8-step 时这极其重要。 Uniform 8 步 PPL 远高于 logit-normal 8 步。32-step 之后两者差距收窄。
γ sweep(paper 选默认值):8/16 步默认 γ=2.0;32 步 γ=1.5;64 步 γ=1.0。 论文文字说 γ 控制 PPL/entropy trade-off,paper 默认 γ=1.0 作为各 step budget 的 balance; 8/16 步用更大 γ=2.0 是因为粗步长需要更多 stochasticity 修正噪声累积。

7.5 Tab 6 + Tab 7 数据汇总(detail 详见 6.7-6.8)

覆盖最佳 valid Gen-PPL
Tab 6 (system-level, ELF-B, 6 seeds)8/16/32 step32-step SDE γ=1.5 SC-CFG=3: 24.08 ± 0.16
Tab 7 (scaling, 64-step)B/M/L × ODE/SDE × CFG 0.5-4各 size 内最低 valid PPL:ELF-B 19.72 (CFG=3) / ELF-M 21.69 (CFG=3) / ELF-L 21.37 (CFG=4)。CFG>3 部分 cell 标灰

7.6 数据来源 & 实验环境

8 · 定性效果 — 去噪轨迹

Denoising trajectory Fig 17
Fig 17 (paper p.28): 从 t=0 的 gibberish/repetitive token 开始('strength will building building...'),随 t 增长 ELF 渐进地形成语义有意义的短语,最终(t=1)解码为流畅句子。连续轨迹的好处:每一步都是几何渐变,不是离散跳变。

这张图回答了"continuous DLM 到底在做什么"——它在 embedding 空间里描出一条平滑轨迹, 最后一步才把轨迹终点投影到 token 词表。和离散 DLM 每步都做 vocab argmax 完全是两套范式。

9 · 和字节 Cola-DLM 对比 — Field Landscape

Cola-DLM(ByteDance Seed, arXiv 2605.06548, 2026 年 5 月) 是和 ELF 几乎同时(2 周内)冒出来的同类工作。两边都是 continuous DLM,但设计哲学几乎相反: ELF 求简,Cola 求强。下面是我从 Cola 的公开 blog 和 arXiv 摘要整理的对比。

9.0 一句话定位

ELF = "把 encoder 冻结,diffusion 只学 transport"——最小架构、最高数据效率,刷 OWT Gen-PPL。

Cola-DLM = "diffusion 不应该恢复 noisy token observation,应该建模 semantic latent prior"—— 两阶段训练(VAE pre-train → joint VAE+DiT)+ block-wise 推理,~2B 参数,刷 reasoning task average。

9.1 关键差异表

维度ELF (MIT)Cola-DLM (ByteDance)
核心对象Contextual embedding 上 Flow MatchingText VAE latent 上 block-causal FM
EncoderFrozen T5-small (35M)Learnable Text VAE (~500M)
Latent spaceToken-aligned, 512-d (bottleneck 128)Explicit z, d=16 (默认)
Diffusion 目标恢复 clean contextual embedding把噪声 transport 到 learned latent prior
Decoder共享 Transformer (final-step 切换 mode)独立 VAE decoder + KV cache
BackboneDiT,全双向 attentionBlock-causal DiT (intra-block 双向、inter-block 因果)
训练单阶段 80/20 mix两阶段训练 (VAE pretrain → joint VAE+DiT) + block-wise 推理
损失80% MSE + 20% CEStage 2: λVAE·LVAE + λFM·LFM + λref·KL(q‖qref)
参数105M / 342M / 652M~2.3B 总 (1.8B DiT + 500M VAE)
采样32-64 步 ODE/SDE Euler8-16 步 / block, block-causal + KV cache
评测Gen-PPL, BLEU, ROUGETask Avg (LAMBADA/MMLU/SIQA/RACE/...)
对标MDLM/Duo/FLM/LangFlowAR + LLaDA at 2B
显式 latent 边际化
p(x)=∫p(x|z)p(z)dz
无(直接对 contextual embedding 做 FM,最后一步 decode 到 token)有(VAE 提供 p(x|z),diffusion 学 p(z)
训练中间步 token-space loss不做(中间全 MSE on embedding,仅末步混 20% CE)不做(diffusion 在 latent 空间,CE 由 VAE 在两阶段训练分担)
Scaling ceiling论文未给更大 encoder 的 scaling 曲线(C.1 ablation 只到 T5-large 1024-d)blog 报 curve still rising 至 2000 EFLOPs

9.2 两边各自能给出而另一边给不出的 claim

Cola 能给出而 ELF 给不出

  • 显式层次化潜变量 p(x)=∫p(x|z)p(z)dz,diffusion 建 pψ(z)、VAE decoder 建 pθ(x|z)
  • Block-causal attention + KV-cache:扩散 serving 可以像 AR 一样按 block 增量推理
  • 2B 规模 + reasoning benchmark 数字(LAMBADA 50.80 / MMLU 19.30 @ 2000 EFLOPs)

ELF 能给出而 Cola 给不出

  • "中间步不需要 token-space supervision" 的干净对照(中间全程 MSE on contextual embedding,只最后一步 20% CE)
  • 无 distillation、无独立 VAE 栈下的小尺度数字(105M、45B tokens、32 步 Gen-PPL 24.08)
  • Encoder / decoder 主干完全共享(同一份 Transformer 参数 + mode token 切换)的可行性

9.3 Cola-DLM 两阶段训练 + 推理(来自他们 blog)

Cola 不像 ELF 那样单阶段端到端训练。两个训练阶段 + 独立的推理流程:

阶段训练对象损失目标
Stage 1 — VAE pretrainingText VAE encoder + decoder LVAE = −𝔼[log pθ(x|z0)] + β · KL(qφ‖pbase) + λmask·Lmask
(带 BERT-style masking loss)
学一个稳定的 text↔latent 映射,避免 semantic collapse / decoder shortcutting
Stage 2 — Joint VAE + DiTVAE + block-causal DiT Lstage2 = λVAE·LVAE-like + λfm·LFM + λref·KL(qφ‖qφ_ref)
(reference KL 抑制 latent drift)
DiT 在已稳定的 latent 上学 flow matching prior;reference KL 不让 VAE 漂移
Inference (非训练阶段) Block-wise prior transport(DiT 生成 latent block)+ VAE decoder(latent → tokens, KV-cached)

关键超参(来自 blog RQ4 sweet spot 表,released checkpoint 配置)

9.4 Block-Causal Attention(Cola 的核心架构选择)

Cola DiT 在序列维度上做块因果(block-causal)分解:

这种设计的好处:

  1. KV cache 兼容:因为 inter-block 是因果的,已经生成完的早期 block 的 K/V 可以缓存。下一个 block 不重新算前面的 attention,类似 AR 的推理加速。
  2. 训练时仍并行:teacher forcing 的常规做法,把 block-causal mask 加到 attention 矩阵,所有 block 一次性算完。
  3. 生成质量 vs AR 的赌注:保留 diffusion 的"全 batch 同时 denoise"特性(block 内并行),同时获得 AR 的推理时 KV cache 加速

Block size sweep(来自他们 RQ2/RQ3):

9.5 Cola-DLM benchmark 数据(released checkpoint @ 2000 EFLOPs)

Cola 不报 Gen-PPL。他们用"generative few-shot"协议把多选题转成生成任务, 和 AR + LLaDA 在 ~2B 同 scale 下对比。来自 ByteDance-Seed/Cola-DLM GitHub model card 的 released 数字(注意多选任务 4 选 1 随机基线 ≈ 25,下表 HellaSwag/RACE/OBQA 等接近或低于随机,反映 latent 模型当前在多选 few-shot 上偏弱):

TaskCola @ 2000 EFLOPs说明
LAMBADA50.80段落补全
MMLU19.3057 类多选 — 数字仍低于 AR 同 scale
SIQA28.90社会场景推理
RACE19.60阅读理解
Story Cloze30.77故事结尾选择
OBQA23.00选择题常识问答 (OpenBookQA)
HellaSwag10.70常识 NLI
SQuAD30.90抽取式问答
Task Avg26.758 任务平均

注意:blog 内 RQ2/RQ3 ablation 表给出更小的训练 budget 下数字(LAMBADA 31.1-34.6 / MMLU 5.4-10.1 / SIQA 11.1-23.6 等), 但这些是消融区间,不是 headline。上面才是 released checkpoint 数字。

他们的 scaling 实验跑到 ~2000 EFLOPs。Blog 说 "Task Avg 在 ~2000 EFLOPs 还在 rising"—— 官方称仍有 headroom,没看到饱和。

9.6 Cola 的关键 ablation 发现(他们 RQ2/RQ3)

消融结论
Fixed-VAE vs Joint-VAE trainingJoint 在大算力下 win。冻结 VAE 训练 (像 ELF 那样) 在小算力 ok,但 scaling 时无法继续涨
All-Scratch baseline从零训 VAE + DiT 始终不如先 Stage 1 pretraining
Interval freezing (Stage 1.5)VAE 在中间阶段冻结一段再放开 — 比一直 joint 差
Sampling steps少步数明显不足;~8-10 步基本恢复;16-32 步饱和
Patch size 2 (压缩 latent)整体比 patch 1 差;但当 prompt length 对齐 patch 边界时反而略好(18.12 vs 17.31)— "Token-level segmentation 不一定最优"

这里有个有意思的对比:Cola 通过 ablation 证明 joint VAE+DiT 训练在大算力下更优;ELF 通过 ablation 证明 frozen encoder 在小数据小算力下更优。两个结论不矛盾—— 它们对应不同的训练规模、不同的"哪个组件更值得优化容量"的判断。

9.7 一句话收尾

ELF 与 Cola 是互补的两条路线:ELF 用最小架构(frozen encoder + 共享 decoder)在小尺度把生成质量(Gen-PPL)推到极致;Cola 用 explicit latent + block-causal DiT 在 ~2B 尺度上换来 reasoning task 与 AR-style serving。完整技术差异已汇总在 §9.1 的对比表——选哪条,取决于你要的是数据效率,还是大尺度可扩展性 + 推理能力

10 · DSL:连续表征的第三条路线

聊 ELF 时我们反复绕回一个问题:要在 token 上做 continuous diffusion,那个连续表征坐标系到底从哪来?ELF 的答案是「外借」——拿一个 frozen 的 T5 contextual embedding 当现成坐标系;Cola 的答案是「自己学一个」——联合训练一个 VAE latent,把离散 token 压进一个学出来的连续隐空间。本节要介绍的 DSL(Discrete Stochastic Localization for Non-autoregressive Generation)给出了第三种答案。但它真正吸引人的地方不在「坐标系怎么来」,而在一个训好的网络能同时扮演 continuous diffusion、discrete masked diffusion、autoregressive 三种角色。这条「统一」的主线,是理解 DSL 的钥匙。

论文一作是 吴云舒(Yunshu Wu),合作者包括 Jiayi Cheng、Longxuan Yu、Partha Thakuria、Rob Brekelmans、Evangelos E. Papalexakis、Greg Ver Steeg(UC Riverside 等)。

先把名字拆清楚。"Discrete" 指的是建模对象是离散 token,而建模本身是连续的——DSL 是一个 continuous-state framework,每个 token 被嵌到一个 unit-sphere(单位球面)上。所以它不是「离散扩散」的又一个变体,而是把离散符号搬进连续球面几何里去做的框架。abstract 点出了 DSL 的核心性质——Bayes-optimal denoiser 在 localization channel 下对名义 SNR 不变;但没写出具体 channel 形式与对应 SDE 方程,本节也不去编它的数学机制,只讲它带来的可观察现象。

10.1 技术支点:一个对 SNR「免疫」的 denoiser

DSL 的统一能力,根子上来自一条理论性质:它的 Bayes-optimal denoiser 在 localization channel 下对名义 SNR 不变(invariant to the nominal SNR under the localization channel)。直觉上的后果是——同一个训好的网络能支持一整族 per-token SNR path(an entire family of per-token SNR paths):序列里不同 token 可以各自拥有自己的 SNR 路径设计。

而这正是「统一」的第一块拼图:masked diffusion 不是另一个并列的范式,而是 DSL 的一个特例。论文把它表述为 "endpoint masked-diffusion paths as a special case"——当 SNR path 取到端点形态时,per-token SNR 框架就退化成我们熟悉的 masked diffusion。换句话说,masked diffusion 是这一族路径里的一条特殊曲线,而不是一个对手。

10.2 顺带涌现的几何:表征是随 SNR「长出来」的

「一整族 SNR path」还带来一个很漂亮的副产品,也是 DSL 和 ELF / Cola 在表征哲学上最不同的地方:DSL 的表征几何没有一个固定形态,而是随 SNR 从粗到细(coarse-to-fine)动态涌现。同一个 token 在不同 SNR 下「是什么」会变:

  SNR = 0   ───────▶   finite SNR   ───────▶   high SNR
  (零信息)            (软证据)              (硬锚点)

  解码为 [MASK]        soft evidence          靠近 hard token anchor
  什么都还没说        带方向性的软证据        贴近一个具体 token 锚点

  原文:"zero information decodes as [MASK], finite SNR carries
        soft evidence, and high SNR localizes near hard token anchors."

  └── 几何不是事先给定,也不是训练后冻结,而是沿 SNR 一层层「长」出来 ──┘

这条轨迹读起来很像一种层次化的细化:低 SNR 时模型只能给出最粗的判断,SNR 升高后逐步靠近某个具体词的锚点。表征几何既不是预先外借的(ELF),也不是学一个固定 latent(Cola),而是由信噪比这条轴自然组织出来的——也可以把它当作 "stochastic localization" 这个名字的一种直觉读法。

10.3 统一的兑现:同一个 checkpoint,三种采样方式

因为 denoiser 对 SNR 免疫、又内含一整族 per-token 路径,DSL 得以让同一个训好的 checkpoint在推理时切换成不同的生成模式,而不需要重训:

这才是「统一 continuous / discrete / AR」这句话的具体兑现:不是三个模型、三套权重,而是一个 checkpoint 在推理时长出三张脸

一个很实用的副产物:可以从现成的 MDLM checkpoint 直接 fine-tune。由于 masked diffusion 本就是 DSL 路径族的端点特例,论文指出可以从一个 pretrained MDLM checkpoint 直接 fine-tune("without retraining",无需从头重训),提升 distributional faithfulness(分布层面的逼真度)。对已经手握 masked diffusion 模型的人来说,不必从零训练就能迁移进这套框架。

10.4 采样步数与评测口径

效率上,DSL 报告可以用少至 T=48 步完成生成("as few as T=48 total steps—without distillation",注意是不经蒸馏的),论文同时测试了 T=128–1024 的更长步数范围。评测指标用的是 MAUVE,衡量在 OpenWebText 上生成分布与真实分布的 distributional faithfulness。

口径提醒(别横比)。MAUVE 衡量的是「生成分布有多像真实分布」,和我们之前讨论 ELF 时用的 Gen-PPL 不是同一个东西,数值不可直接对照。本节也不去声称 DSL「打败」了 ELF 或 Cola——三者指标不同、setting 不同、规模不同,把它们排个名次没有意义。DSL 在这条 blog 里的价值,是提供了处理 token 表征的第三种思路,以及一套「一个 checkpoint 统一三范式」的视角。

10.5 放回三方路线里看

把 ELF、Cola、DSL 放在一起,最清楚的对比维度就是「连续表征从哪来、是不是固定的」:

维度ELFColaDSL
表征坐标系怎么来 外借现成的 frozen T5 contextual embedding 联合训练一个 VAE latent 在 token 的 unit-sphere 几何上做 stochastic localization
几何是否固定 固定(预先给定、冻结) 固定(训练后定下来的 latent) 不固定——随 SNR 动态涌现(coarse-to-fine)
与 masked diffusion 的关系 独立路线,与 masked diffusion 正交 独立路线,与 masked diffusion 正交 masked-diffusion path 是其路径族的端点特例
一句话定位 坐标系「外借」 坐标系「自己学一个固定的」 坐标系「随信噪比长出来」,一个 checkpoint 统一三范式

前两条都在回答「给 continuous diffusion 一个好坐标系」,区别只是借还是学;DSL 把问题换了个问法——表征几何随 SNR 动态涌现,而不是预先外借或训练成一个固定 latent。顺着这个改写,它把 masked diffusion 收成特例,也把 continuous / discrete / AR 三种采样方式统进了同一个 denoiser。如果说 ELF 与 Cola 回答的是「该把 token 放进哪个连续坐标系」,DSL 提供的,是理解 continuous DLM 表征的第三个有用视角。

10.6 DSL-LLaDA:1000 步把一个现成的 8B masked dLLM「改」成连续去噪器

前面 §10 把 DSL(Discrete Stochastic Localization,Wu et al. 2026)讲成了一条统一主线:离散 token 的「mask / unmask」可以被还原成一个连续的「信噪比逐渐升高」过程,于是 discrete diffusion 与 continuous diffusion 在同一套数学框架下握手。但 §10 主体里还留了一句话没展开——「可以从一个现成的 discrete dLLM checkpoint 出发,通过 continue-pretrain 给它赋予 continuity」。这句话听起来很美好,可它真能在 8B 这种规模上兑现,还是只是小模型上的玩具结论?

DSL-LLaDA: Scaling Continuous Denoising to 8B Masked Diffusion LMsLongxuan Yu吴云舒(Yunshu Wu)共同一作,equal contribution;另有 Yu Fu、Siheng Xiong、Rob Brekelmans、Hui Liu、Yue Dong,通讯作者 Greg Ver Steeg;单位 UC Riverside / Georgia Tech / Microsoft;arXiv 2606.01024)给出的答案相当干脆:不用从头重训,只 continue-pretrain 1000 步(在 FineWeb-Edu 上),就能把现成的 LLaDA-8B-Instruct 从「二值 mask 去噪器」改造成「连续 embedding 空间去噪器」。它的定位很清楚:不是一个新框架,而是 §10 DSL 框架的直接 8B scale-up 实证——这正是 §10 主线最想要的那块拼图。

10.6.1 最抓人的一点:只 continue-pretrain 1000 步

传统 masked dLLM(LLaDA、MDLM 这一类)看世界只有两档:一个 token 要么被完全 mask(信息为零),要么是 gold token(信息完整)。DSL-LLaDA 的训练改造很集中——把 0/1 的硬 mask,换成 §10 里 DSL 的连续 per-token 高斯软 mask(推理侧相应改用 SDE sampler,见 §10.6.2)。对每个 clean token x_i 采一个信噪比 γ_i,构造一个带噪 embedding(DSL 把 token 的 noise embedding 约束在 §10 开头提到的单位球面上,高斯噪声就加在这个球面表示上——这也补上了 §10 里 DSL abstract 未展开的具体加噪形式):

对序列里每个位置 i:
    采样 SNR  γ_i  ~ schedule          # γ 是这个 token「已经露出多少信息」的旋钮
    取 clean token 的 embedding  e_{x_i}
    采高斯噪声  ε_i ~ N(0, I)
    构造 noisy embedding:
        z_i = γ_i · e_{x_i} + sqrt(γ_i) · ε_i

    # 三档连续过渡(不再是 0/1 两档):
    #   γ_i → 0    : z_i = 0, 不含 clean 信息 → 等价于「被 mask」 (unknown)
    #   γ_i 适中   : 信息半遮半掩          → 像一个随机 token       (unreliable)
    #   γ_i 很大   : 噪声被淹没            → 基本等于 gold token     (clear)

注意 γ_i = 0z_i = 0(不含任何 clean-token 信息),正好对应原模型「全 mask」那一档;γ_i 增大时信息越来越清晰。也就是说,原来的二值 mask 只是这条连续谱的两个端点,DSL-LLaDA 只是把中间那段连续地填了进来。一个轻量的 softmax converter 负责把这些 noisy embedding 映回 backbone 的输入空间,于是每个位置就沿 SNR 轴自然走过三个阶段:unknown(低 SNR,像 mask)→ unreliable(中 SNR,像随机 token)→ clear(高 SNR,gold token)

有两个设计上的便利,正是这套改造能「1000 步搞定」的根本原因:

对比 §10 主线就很清楚了:DSL 原文提出框架并在较小规模上验证,DSL-LLaDA 做的就是「scale 到 8B」这一步,证明那句「从 discrete checkpoint continue-pretrain 赋予 continuity」不是小模型特例,在 8B 上花 1000 步就能落地。

10.6.2 推理:SDE 连续采样,把「硬选 token」推迟到最后一步

训练给了模型「连续噪声」的概念,推理这边对应的就是一个 SDE-based continuous sampler:沿着一个逐渐升高的 SNR schedule 做 Heun 步,并且把 hard token commitment(真正把某个位置落子成离散 token)一直推迟到最后一步。直观地说,中间过程全程在连续 embedding 空间里「软着陆」,避免了离散采样里那种「早早把一个位置钉死、后面发现错了也改不动」的尴尬。

输入: 8B backbone (continue-pretrain 1000 步), 升序 SNR schedule {γ_1 < γ_2 < ... < γ_T}
初始: 所有位置处于低 SNR (unknown / mask-like) 状态

for t = 1 .. T:                      # SNR 由低到高
    用 softmax converter 把当前 noisy state 映回输入空间
    backbone 预测各位置的 clean-token 分布
    沿 SDE 做一个 Heun step(提升有效 SNR, 逐渐进入 clear 阶段)
    # 注意: 此处不做 hard token commitment, 各位置仍可被后续步骤修正

最后一步:
    才把各位置 commit 成离散 token   ← 延迟 hard commitment
  离散 dLLM (iterative unmask):
     step1   step2   step3 ...        每步硬选若干位置 → 定了就不能改
     [M M x M] → [M y x M] → [t y x M] → ...   ← 早期错误会被"焊死"

  DSL-LLaDA (continuous SDE):
     沿 SNR 升高方向做 Heun 步, 全程在连续 embedding 空间软演化
     γ↑ ───────────────────────────────────▶  仅最后一步 commit 成离散 token
     (整段都可被后续步骤修正, 延迟落子)

「延迟硬性敲定」这个特性,正是下面几个有趣行为的来源。

10.6.3 连续噪声「换」来的三个新行为

这篇工作最有意思的地方在于:仅仅把二值 mask 换成连续噪声,模型就涌现出三种在本文比较的 discrete iterative-unmasking baseline 中都没有观察到的定性行为

10.6.4 一张小表:zero-shot summarization(ROUGE-1)

下表是 zero-shot summarization 的 ROUGE-1 数字(越高越好,低 NFE 设置)。论文报告 DSL-LLaDA-SDE 在 NFE ≤ 16 上四个 benchmark 均最佳;下表并列论文 Table 2 给出的离散 LLaDA 基线(NFE=8/16)。注意在更高 NFE(32/64)上离散 LLaDA 在 PubMed/arXiv 反超——SDE 的优势集中在低 NFE。

表 10.6.1  Zero-shot summarization ROUGE-1(DSL-LLaDA-SDE,含 vs 离散 LLaDA 对照)
Benchmark DSL-LLaDA-SDE
(NFE=8)
DSL-LLaDA-SDE
(NFE=16)
LLaDA 对照
(NFE=8 / 16)
XSum28.430.425.2 / 29.0
CNN-DM28.133.023.1 / 28.2
PubMed29.432.211.9 / 20.2
arXiv27.328.610.6 / 16.2

注:数字取自论文 Table 2(每 benchmark 1000 samples,seed=42)。另有 LLM-as-judge(GPT-5.4)评测,NFE=16 下评审偏好 SDE 输出 57.9%、偏好离散 LLaDA 26.6%。注意 ROUGE-1 与 perplexity 是不同口径的指标;且这是 DSL-LLaDA 自己 summarization setting 下的对照,与 §10 其它工作(如 ELF、Cola)的 GenPPL / MAUVE 不在同一任务上、不宜直接横比,本节不主张「打败」谁。

10.6.5 为什么说是「continuous-noise 驱动」而不是「多训了 1000 步」

一个自然的质疑是:上面这些好处,会不会只是「在 LLaDA 上又 continue-pretrain 了一会儿」的功劳,跟连续噪声没关系?论文专门设了两个 同等 compute 的对照来排除这个解释:

结果是——这两个对照都没有上面那三个行为(低复读、低 NFE 摘要领先、选择性鲁棒)。所以差异不该归因于「额外训练量」,而更合理地归因于continuous-noise exposure 本身:是「把 binary mask 换成连续高斯软 mask」这一件事,让模型学会了在连续空间里延迟 commit、选择性纠错、稳定生成。

把它放回 §10 的统一视角:DSL 给出了「离散 mask = 连续 SNR 谱的端点」这套框架,DSL-LLaDA 则把它 scale 到 8B 并坐实了一个很实用的结论——一个现成的 masked dLLM,不必从头训练,只用 1000 步 continue-pretrain,把二值 mask 换成连续高斯软 mask,就能获得连续 embedding 空间去噪能力,并解锁连续扩散那一侧的好处(延迟 commitment、low-NFE 推理)。配合 MDM-CPT / XDLM 两个 same-compute 对照,作者把这些增益归因到「连续噪声暴露」本身,而非额外算力或随机 token 训练。对整章而言,这是「从 discrete dLLM checkpoint 低成本赋予 continuity」这条主线,目前最直接的一份 8B 真实证据。

延伸阅读:

11 · Q&A

11.A 关于 ELF 自身

Q1: ELF 为什么能少用 ~12× 训练数据?

真正的来源是 frozen T5-small 的 contextual embedding 已经包含了语言的几何先验。 ELF 不学"语言是什么",只学"如何在这个空间里 transport"。等效于一种 transfer learning,T5 的预训练成本 (Google C4 上约 1T tokens)没被算进 ELF 的 45.2B tokens 里。Tab 5 只比较了 ELF vs baseline 的自身训练 token,不计 encoder pretrain。 这是 ELF 数据效率 claim 最主要的 caveat。

合理的回应:可以把 ELF 看作"在 T5 表示上的 transfer-learning DLM"。MDLM/Duo 也间接用了 word-level tokenizer 的语言先验 (虽然程度不同)。但承认这是"有限的 12×",不是免费午餐。

Q2: 如果换 encoder(LLaMA-3-8B hidden state / 更强的开源 encoder)会更强吗?

大概率是的,但 paper 没做这个实验(App C.1 只 sweep 了 T5-small / base / large 三个 size)。 这是 ELF 最大的 follow-up 机会,也是它最脆弱的 claim—— ELF 的天花板可能就是 encoder 的天花板。如果有人用 LLaMA-3-8B hidden state + ELF flow 训一个 105M model,能不能逼近 LLaMA 自身 quality? 这个实验没人做过,是 obvious next step。

Q3: 为什么 final step 才做 CE?中间步加 CE 不更好吗?

Paper App C.1 ablate 了:x-prediction 在所有 embedding 维度都稳定,v-prediction 在高维 degrade,ε-prediction 全面崩溃。 背后假设(paper 引用 Li & He 2511.13720):clean 文本数据在 embedding 空间是低维流形。 中间步加 CE 等价于"把噪声大的 z_t 强行映射回 token",这等于把 quantization wall 偷偷搬回来—— 破坏了 ELF 的"only final step rounding"核心 framing。

Q4: 训练时 CFG(Eq 3)为什么不直接推理时跑两遍?

可以,但训练时把 SC-CFG 烤进 vtarget 后,推理只需要一次 forward,省一半算力。 代价是训练时3 次 forward(2 个 no-grad + 1 个 gradient-tracked)。 推理跑无数次,训练跑一次。划算。 这个 trick 是 image diffusion 圈 Chen et al. "Visual Generation without Guidance" (ICML 2025) 的方法,ELF 直接搬过来。 注意 ELF 仍然保留另一个推理时 CFG(input-cond CFG=2,用于 XSum/WMT),跟 SC-CFG 是两个独立机制。

Q5: 32 步 PPL 24.08,比 dataset 自身 PPL 怎么样?

论文用的就是 frozen GPT-2 Large 当 judge。Fig 7 里的 "Dataset" 虚线就是 OWT 真实文本在 GPT-2 Large 下的 reference PPL。 ELF-B 32 步 SDE 24.08 接近这条 reference;ELF-M 在 64 步 SDE / SC-CFG=3 下 21.69,更接近该 reference。 注意 "Dataset reference" 不是严格的理论下限—— 低于它也可能伴随低 entropy(即重复但流畅)。 而且这只是 GPT-2 Large judge 下的"流畅性"指标,不是通用质量下限: 换更强 judge(GPT-4 / Llama-3)数字会变。

Q6: bottleneck 为什么是 128,不是 256/64?

App C.2 sweep 了 {32, 128, 512}(不含 256/64,论文没测)。 128 是 ODE / SDE 双采样下的 frontier balance 点: 32-d 在 SDE 下 PPL 最低但 entropy < 5(多样性不够),512-d 把 PPL 推高了。 背后假设:clean text 在 embedding space 是低维流形,128-d 就足够"覆盖"这个流形。 对比 image diffusion 也常用类似 bottleneck(DiT、SD 都有)。256 这个中间值 paper 没测,按 frontier 趋势插值应该 fine。

Q7: Muon 为什么比 AdamW 强?尤其在 SDE 下?

App C.5。Muon 对 2D 参数用 Newton-Schulz orthogonalization 把梯度先正交化再 step, 这抑制 ill-conditioned 方向,相当于 implicit second-order。 SDE 采样需要更精确的 v 预测(噪声重新注入会放大 v 的误差), Muon 训出来的模型在 v 上更"光滑",所以 SDE 推理时优势更大。 Paper 只定性说 "more pronounced under SDE",没给具体数字。 工程上 Muon 还有 fp32 NS5 + Nesterov bias correction 等 patches(详 §4.10)。

Q8: ELF 怎么处理长序列?1024 已经是上限吗?

论文实验最长 seq=1024(OWT)和 1088(XSum)。当前 checkpoint 和 config 没有验证 8K 以上长度。 Architecture 上没有 causal LM 那种生成方向限制,但: (a) RoPE buffer 按 max_length 预构建(要扩 8K 需重建或加 RoPE scaling); (b) 全双向 self-attention 是 O(L²) 复杂度。 关于 "能 scale 到 8K context 吗"——architecture 上理论可行但 paper 没测; 可能需要 RoPE scaling + linear/sparse attention 改造,是 obvious extension。

11.B 关于 Cola-DLM / 两条路线对比

Q9: Cola-DLM 那种 VAE 路线不是更"正统"吗?为什么 ELF 看起来更干净?

VAE 路线的代价是 encoder 也得训,要解决 posterior collapse、latent drift、reconstruction trade-off 等额外问题。 ELF 用 frozen T5 把 encoder 部分外包给现成的预训练模型,所以 paper 短、ablation 干净、可以专心 ablate 7 个 ELF 自身的超参。 Cola 必须同时管理一大堆 VAE 超参(latent dim, block size, logSNR, KL ratio, anti-drift KL, multi-stage scheduling, ...)。 两边的代价不同:ELF 把复杂度外包给 Google 的 T5 训练;Cola 把复杂度内化到自己的训练 pipeline。

Q10: Cola 报 LAMBADA 50.80、MMLU 19.30,看起来很低?

对,这是 generative few-shot 协议下的数字(把多选题转成生成)。 按 Cola 自己的 model card 与同协议对照,Cola 的 LAMBADA 接近部分 AR baseline,MMLU 明显偏弱。 但 Cola 论文强调的是 scaling 趋势:曲线还在涨,没饱和。 需要说明的是,Cola 当下的绝对分数确实不高——但它的卖点是架构可行性 + scaling shape,而非当前数字。

Q11: 为什么 ELF 没用 KV cache?

ELF 的 DiT 是全双向 attention,每步 forward 都要重新算所有位置的 attention。 KV cache 需要 causal mask(前面 K/V 缓存供后面 query 用),ELF 没这个结构。 这是 ELF 推理慢于 AR 的根本原因之一: 即使 ELF 32 步 << AR 1024 token,每步的 attention 复杂度还是 O(L²)。 Cola 用 block-causal 解决了这个——这是 Cola 的核心架构卖点。

Q12: ELF 和 Cola 谁更接近 production-ready?

都没有。但 ELF 更接近"clean scientific demonstration",Cola 更接近"engineering system"。 具体看:

12 · 引用 & 资源

资源链接
ELF paper PDFarXiv 2605.10938
ELF GitHub (官方 JAX)lillian039/ELF
ELF PyTorch portpytorch_elf branch
ELF HF checkpointsembedded-language-flows
Cola-DLM paperarXiv 2605.06548
Cola-DLM GitHubByteDance-Seed/Cola-DLM
Cola-DLM bloghongcanguo blog
Cola-DLM HFByteDance-Seed/Cola-DLM

参考文献(arXiv)

本文涉及论文的 arXiv 链接(2026 新工作的编号经 cross-model 核对)。

FM 家族(§3,含主角)

第三条路线 / Cola-DLM(§9–§10)

离散 DLM 基础

Continuous DLM 旧路(embedding / latent)

Factorized-gap 修补 / 流形 / trick


作者 Ruofeng Yang(杨若峰) (Shanghai Jiao Tong University, 2026-05)。文档由 ARIS (Auto Research in Sleep) 的生态 ARIS-in-AI-Offer 工作流产出,由 Claude Opus 4.7 整合 Codex GPT-5.5 xhigh + Gemini auto-gemini-3 跨模型讨论后撰写。本文是关于 ELF 等 continuous DLM 论文的第三方阅读笔记 / 综述,所有论文内容、图表、代码版权归各自原作者所有。 图片均截自 ELF arXiv PDF v1(2026-05-11)。代码片段来自官方 pytorch_elf 分支 @ b29d883