ELF(Embedded Language Flows, MIT; Hu/Qiu et al., senior 作者 Kim/Andreas/He, arXiv 2605.10938, 2026-05-11) = 在冻结的 T5-small contextual embedding 空间里做连续时间 Flow Matching,只在最后一步离散化; denoiser 和 final-step decoder 共享同一个 Transformer 权重。 一句话定位:"continuous DLM 的瓶颈可能不在'连续'本身——而在过去工作把 encoder 放进训练联合学。ELF 把它冻结,diffusion 只学如何在已有几何中 transport,在 ablation 中显著优于 learnable encoder 选项。"
核心五个数(全部来自 paper):
Gen-PPL = 模型自己采样 → GPT-2 Large 给打分的 perplexity,低 = 生成像自然语言;通常配 entropy 看防 mode collapse。
三条可带走的 takeaway(读完这篇 paper 应该被 update 的认知):
历史定位:2026 上半年 continuous-DLM 出现 6+ 篇集中工作。FM 家族按发布时间排: CFM (Feb)、FLM/FMLM (Feb)、DFM (Apr)、LangFlow (Apr)、ELF (May); 外加 ByteDance Cola-DLM (May 末) 的 latent-VAE 路线。 ELF 在 FM 家族里是 已知 size 最小(105M)、设计最干净(无 distillation / 无 latent VAE)、 且无蒸馏 32 步即接近 dataset reference PPL;Cola 是最大尺度、最复杂、reasoning-focused。 详见 §3 (FM 家族 5-way) 和 §9 (Cola 单独详细对比)。
| 年份 | 工作 | 主要贡献 | 评测尺度 |
|---|---|---|---|
| 2021 | D3PM (Austin et al. NeurIPS) | 定义离散扩散框架:mask / uniform / embedding transition | text8, LM1B 小尺度 |
| 2024 | SEDD (Lou et al. ICML Best Paper) | Score entropy loss 给离散空间一个 clean 损失 | OWT scale |
| 2024 | MDLM (Sahoo et al. NeurIPS) | Masked diffusion ELBO ≡ weighted CE — 极大简化训练 | OWT scale |
| 2025 | LLaDA (Nie et al.) | 第一个 8B-scale 离散 DLM,证明 scaling 可行 | 8B params |
| 2025 | Dream 7B (Ye et al.) | 大规模 diffusion LM(具体机制细节见原文) | 7B params |
| 2026 | 外部 landscape | discrete:BD3-LM (semi-AR block) / ReMDM / PRISM(test-time remasking);continuous-latent:VADD / LADD / HDLM / CADD;inference-time coupling:CoDD-style PC layer | 各种 scale |
这套路线都把 D-LLM 定义为:mask/uniform 离散腐蚀 → 逐位置独立 softmax reverse。 随着工作越做越多,发现它撞上四堵墙:
Token 在嵌入前是孤立范畴点,几何邻近性需要从零学。
对比连续扩散在 image / video 上,模型可以利用像素几何("红"和"暗红"邻近),
但离散 token 空间"猫"和"狗"是几何上无关的两个 one-hot vector。
模型必须把所有词之间的语义距离从训练数据慢慢"学"出来。
结果:参数效率低 — 一个 7B 离散 DLM 学完 token 几何后剩下的容量才用来学语言。
标准 D-LLM 的 reverse 是每个 token 独立的 softmax(可类比为图模型的 fully factorized reverse head,treewidth-0)。 真实联合后验 P(X0|Xt) 的所有 token 间依赖被一刀切掉。具体症状(来自论文 App C 描述的 poor generation 区域):
Categorical score / transition 在数学上没有干净的 x / v / ε 对应物。 Image diffusion 圈两年发展出来的工具——self-conditioning, classifier-free guidance, flow matching, rectified-flow, EDM-style noise schedule——没有一个能直接搬到离散空间。 每个都要重新设计,速度被拖慢。
每步都要 unmask / 取 argmax / round;离散误差不可微地累积。 经验上需要很多步才能 recover:ELF Fig 7(a) 对 MDLM / Duo 用了 1024-step baseline 才达到与 ELF-B 32-step 持平的 PPL。 即使做 distillation(MDLM+SDTT, Duo+DCD),few-step variant 也只能到 32-64 步且 PPL 退化。 作为对比,image diffusion 已经发展出 consistency model / mean-flow / flow-distillation 等 1-2 步直采路线。
| 方法 | 在哪个接口"修" | 问题 |
|---|---|---|
| CoDD | 用 tractable probabilistic / circuit layer 替换或增广 factorized output | 训练时还可能是 factorized;inference-time coupling 不改根本 |
| CoDAR | Continuous latent diffusion + fixed encoder + separate / contextual AR decoder (ELF Tab 2 把它列在 latent diffusion 类) | 引入 AR decoder,部分丢了 DLM 全并行性 |
| E2D2 / 类 block diffusion | Semi-autoregressive:block 内 joint denoising,block 间 AR | block-AR 牺牲并行性 |
| VADD / LADD / HDLM / CADD (外部) | 加 latent / hierarchical 结构 | 架构越来越复杂,没有 clean theoretical story |
| ReMDM / PRISM-style | test-time remasking / search | 训练阶段不变,只动推理 |
这些都是 修——保持"离散 + factorized"这个根基不变,在边缘 patching。 ELF 的回答更激进:跳出离散空间。如果 token 嵌入已经被 T5 学好了, 为什么 diffusion 还要在离散符号上跑?为什么不直接在 T5 那个连续 + 有几何结构的空间里跑 flow?
ELF 在两个意义上是 continuous(paper §2 明确说 "continuous in two senses"):
但 ELF 还有一个关键的"冻结":encoder 不学。这是它和过去同类工作的最大区别。
很容易混淆"每个 token 位置对应一个 vector"是不是指"每个词有一个 vector"。不是。先理清三个数字:
| 数字 | 含义 |
|---|---|
| 1024 | 序列长度——位置数量("句子里有 1024 个 token 位") |
| 512 | T5-small contextual embedding 维度——每个位置的 vector 有多少维 |
| 32128 | 词表大小——T5 SentencePiece 总共 32128 种不同的 token |
ELF 的 tensor 流(OWT, L=1024, D=512):
input_ids: [B, 1024] ← 1024 个 token ID(每个是 0~32127 之间的整数)
↓ T5-small frozen encoder
T5 contextual emb: [B, 1024, 512] ← 每个"位置"得到一个 512-d vector
↓ + noise
z_t (noisy): [B, 1024, 512] ← ELF denoise 就在这个 tensor 上
↓ 32 步 Flow Matching denoising
clean embedding: [B, 1024, 512] ← 终点 ≈ 干净的 contextual embedding
↓ decoder head(512 → 32128 logits)
vocab logits: [B, 1024, 32128] ← 只在 t=1 这一步才出现
↓ per-position argmax
output token IDs: [B, 1024]
关键澄清:
[B, 1024, 512] 这个连续 tensor。对比离散 DLM(MDLM / LLaDA):
| 离散 DLM (MDLM) | ELF | |
|---|---|---|
| Denoise 状态 | token IDs [B, 1024](离散) | contextual embedding [B, 1024, 512](连续) |
| 每步输出空间 | 32128 vocab 上的 softmax 分布 | 512-d 连续 velocity |
| 词表 32128 出现频率 | 1024 步每步都做 vocab classification | 只在 t=1 一次 |
这就是 §2.4 / §9 里说的"ELF 跳到 continuous embedding,但仍然每个位置一个 vector,最后做 per-position CE"—— ELF 的连续性已经改了(中间不再 round 到 vocab),但 factorization 在 final CE 那一步仍然 per-position 独立(每个位置独立 softmax,没有 joint posterior 跨位置耦合)。
过去 5 年的 continuous DLM 路线(paper Tab 2 p.15 给了完整 landscape)分两条线:
共同问题:
ELF 在 Tab 2 里是唯一一行 "fix enc / no train per-step discr / no infer per-step discr / no separate decoder" 同时为否的——架构上最干净的设计点。
ELF 的 framing:把"语义几何"和"transport 动力学"两件事在架构上分离:
过去的方法:拍电影 + 现场标定摄像机 + 同时调灯光 — 三件事一起做,一砸就一片乱。
ELF:用已标定好的摄像机 (T5),专心拍电影 (diffusion transport)。
"语言什么样"的问题已经被 Google 用 T5 预训练(~1T tokens)解决了,
ELF 不再重复学这个,把所有训练算力都集中在 transport 上。
45B vs 524B 训练 token 的差距就是这个 framing 的直接结果。
ELF 上面这套"冻结 encoder 才稳"的论证,在 2 周后字节的 Cola-DLM(2026-05 末)那里被正面反驳。 Cola 选择joint training——VAE encoder 和 DiT prior 一起从 Stage 2 开始联合优化,不冻结 encoder。 他们的 RQ2/RQ3 ablation 关键发现:
所以 frozen-vs-joint 这个设计选择没有一个干净的答案。两边都对,只是 regime 不同:
| 规模 | 谁的 ablation 占优 | 解读 |
|---|---|---|
| ~100M-650M 参数 + ~45B tokens | ELF frozen | encoder 容量已够,再 unfreeze 反而拖慢 transport 学习 |
| ~2B 参数 + ~2000 EFLOPs | Cola joint | 大算力下 encoder 也有 headroom,joint co-adapt 释放更多潜力 |
注意两边都没测对方的 regime——ELF 没在 2B+ 规模 sweep joint,Cola 也没在 100M 规模测 frozen。所以这个 tension 严格说是open question。详细分析见 §9 Cola-DLM 对比。
| 方面 | T5-small (ELF 用) | 其它候选 |
|---|---|---|
| 大小 | 35M (encoder-only) | BERT-base 110M、RoBERTa 125M、Sentence-BERT 110M |
| 训练任务 | Span-corruption denoising 预训练;后续 text-to-text multi-task transfer | MLM (BERT) / 对比 (Sentence-BERT) |
| 表示性质 | contextual(同一个 token 在不同句子里不同向量) | BERT 也 contextual;static embedding (word2vec/GloVe) 则不 |
| 几何性质 | span corruption 训练让模型学到"什么 span 合理",几何比较平滑 | BERT MLM 也类似 |
| Vocab | SentencePiece 32128 | BERT WordPiece 30K |
| 是否 generative | encoder-only 用法 | — |
论文没有 sweep encoder 选择(只 ablate T5-small / base / large 三个 size,App C.1)。 这是 ELF 最容易被 reviewer 攻击的地方:能不能换成 BERT? Sentence-BERT? CLIP text? 一个 LLaMA-7B hidden state? 猜测:T5 的 span corruption 让 hidden state 几何特别平滑,恰好适合 flow matching。 但这是猜测,不是 paper claim。
一句话总结:ELF 没有自己学语言的几何,它直接搭便车坐在 T5 已经画好的"文本流形"上做 transport。
最早的词向量做法(word2vec / GloVe / Diffusion-LM 用的 learned embedding matrix):
E,形状 [32128, 512](vocab × dim)E[4321]所以无论是 "I deposited money in the bank"(金融机构)还是 "We walked along the river bank"(河岸),`bank` 都被映射到同一个固定向量。模型自己没法从 embedding 那一层看出"这个 bank 是哪个意思"——只能靠后面的 transformer 自己消歧。
预训练 encoder(BERT / T5 encoder 等)的做法:跑一遍 transformer encoder over the whole sequence,每个 token 位置的输出向量都受整句话影响。
input_ids: [B, 1024] 一个句子,1024 个位置
↓ T5 encoder (6 层 self-attention + MLP)
output: [B, 1024, 512] 每个位置一个 512-d 向量,
且这个向量是 attention over 所有位置算出来的
所以同样是 "bank":
| Static embedding | Contextual embedding | |
|---|---|---|
| 存储 | lookup table [V, D] | encoder forward [L, D] per sequence |
| "bank" 在不同句子里 | 同一个向量 | 不同向量 |
| 信息量 | V 个固定点(vocab=32128) | 几乎无穷多个点(每句话每位置都不同) |
| 几何结构 | 学到的"词类别"聚类 | 语义+句法消歧后的"上下文状态空间" |
| 来源 | 训练时学一张表 | 跑一遍 frozen encoder |
T5 预训练(~1T tokens)已经把文本数据组织成了一个有结构的 512-d 流形:语义相近的句子 / 位置在这个空间里几何上也相近,上下文消歧、词性、句法关系都被编码进了向量的位置和方向。
ELF 把 T5 encoder 冻住当"坐标系"——diffusion 的工作变成了"在这张已经画好的语义地图上做 transport",而不是"先画地图再 transport"。
对比之下:
因为 ELF 实际能用的"语言知识"远不止 45B,而是 45B + T5 那 1T tokens 的预训练迁移。Codex GPT-5.5 xhigh 在 cross-model review 里也直接点了这个:
"45B 'tokens' exclude pretrained T5-small prior; frozen encoder doing real work. What at 2B from scratch?"
—— traces/research-review/2026-05-26_run01/codex_elf_vs_cola.md
代价:上限被 T5-small 锁死。换 7B encoder 还有效吗?scale 到 LLaMA-70B 行不行?这是 ELF 的命门,也是 Gemini auto-gemini-3 review 标的"T5 的生成式插件而非通用架构"那个点(详见 §10 与 ELF 的命门讨论)。
整个 flow 从 ε(标准高斯 × noise_scale=2.0)→ clean embedding 都在连续空间。 但 t=1 时模型必须输出离散 token。ELF 的做法:
decoder_step_active=True,这次 transformer 输出经过 factored decoder head(768 → 512 GELU → 32128)映射到 vocab logits关键:denoise 主干 + 离散化 decoder 是同一份 transformer 权重,靠 4 个 mode token 切换。 论文 App C.4 显示这种 in-context conditioning 比 adaLN-Zero 略好且省 43M 参数(详见 §5)。
2026 年 2-5 月,一共出现了 5 篇基于 Flow Matching 的语言模型工作, 其中 ELF 是 paper 自己 Tab 2 显示的最后一行。这一节把这 5 篇放在一起对比,明确 ELF 在 FM 家族里的独特定位。
CFM — Categorical Flow Maps (Roos et al., UvA/Oxford, arXiv 2602.12233, Feb 2026)。 核心决策:把 categorical generation 写成"endpoint-prediction flow map", 模型预测 simplex 上的 endpoint 分布 πs,t,再用 self-distillation 做少步生成。 评测 Text8 NLL + LM1B Gen-PPL,单步 LM1B Gen-PPL 274.87。
FLM / FMLM — Flow Map Language Models (Lee et al., KAIST + CMU, arXiv 2602.16813, Feb 2026)。 核心决策:在 token-wise one-hot 连续空间建模,simplex-valued denoiser,CE 后验预测; FMLM 是 FLM 蒸馏后的少步版本。LM1B + OWT 评测, FLM 1024-step LM1B Gen-PPL 96.91、OWT 62.23;FMLM 单步 LM1B 119.34、OWT 168.30。
DFM — Discrete Flow Maps (Potaptchik et al., Harvard/Oxford/MIT/NYU, arXiv 2604.09784, Apr 2026)。 核心决策:把 flow map 从 average velocity 重参数化为 mean denoiser ψs,t, 让 off-diagonal flow map 自然落在 probability simplex 上。直接批判欧氏 L2 与概率几何的不匹配。 LM1B 评测,DFM-ESD 单步 Gen-PPL 68.11(entropy 3.79);DFM-PSD 单步 94.08(entropy 4.06)。
LangFlow — Chen et al., UIUC, arXiv 2604.11748, Apr 2026。 核心决策:回到 learned embedding-space,用 Bregman-divergence 把 token CE 解释为 Flow-Matching posterior matching; 推出 ODE-based NLL upper bound,第一次给 embedding-space DLM 一个可信的 likelihood。 LM1B PPL 30.0 / OWT PPL 24.6(注意:是 held-out NLL upper bound,不是 Gen-PPL); 对应 Gen-PPL 是 LM1B 92.2 / OWT 36.5。
ELF — Embedded Language Flows (Hu/Qiu et al., MIT, arXiv 2605.10938, May 2026, 主角)。 核心决策:在冻结的 T5-small contextual embedding里做 FM,只在 t=1 用共享 transformer 切到 decode mode。 80% MSE + 20% CE,无蒸馏。OWT 32-step SDE Gen-PPL 24.08 ± 0.16。
| 维度 | CFM | FLM/FMLM | DFM | LangFlow | ELF |
|---|---|---|---|---|---|
| State space | Probability simplex ΔK(endpoint predictor πs,t) | one-hot 连续,simplex-valued denoiser | Mean denoiser ψs,t: ℝK → ΔK-1 | Learned token embedding (V×D matrix) | Frozen T5-small contextual embedding (512-d) |
| Interpolant 几何 | Straight-line stochastic interpolant | Gaussian/one-hot interpolant + simplex repar | Linear interpolant (β-reparameterized) | VP γ-path, deterministic ODE | Rectified linear: zt = t·x + (1−t)·ε |
| Training loss | CE diagonal + endpoint consistency (CSD/ECLD) | FLM: CE posterior; FMLM: KL/CE flow-map distillation | CE diagonal + PSD/LSD/ESD KL consistency | CE-as-Bregman (FM posterior matching) | 80% MSE on velocity + 20% CE on decode |
| 默认步数 | 1-step 主推(NFE 1-64 sweep) | FLM: 1024-step / FMLM: 1-step | 1-2-4-8 step | 128-step (LM1B) / 1024-step (OWT) | 32-step SDE (headline) / 64-step (scaling) |
| 是否蒸馏 | ✅ Self-distillation (CSD/ECLD) | ✅ FMLM 从 FLM teacher 蒸馏 | ✅ Diagonal 1M + off-diagonal 200k/100k | ❌ 无 distillation,明说留给未来 | ❌ 无 distillation |
| Time grid | Logit-normal w/ 0.75 diagonal fraction | Decoding-error-rate reparameterization | argmax-linearized + convex mix β̃(t) | Learnable Gumbel scheduler | Logit-normal (P_mean=−1.5) |
| Endpoint 处理 | Simplex endpoint π; argmax 或采样 | Simplex posterior; argmax | Simplex mean denoiser; softmax | Argmax over token probability | 共享 transformer t=1 decode mode + argmax |
| Headline eval | LM1B 1-step Gen-PPL 274.87 | FLM OWT 1024-step 62.23; FMLM 1-step 168.30 | LM1B 1-step Gen-PPL 68.11 (ESD) | OWT Gen-PPL 36.5 @ 1024-step (PPL upper-bound 24.6) | OWT 32-step Gen-PPL 24.08 ± 0.16 |
§3.2 表第一行的 State space 是 5 篇 paper 的核心分歧点。CFM / FLM / DFM 一类说自己"在 simplex 上走 flow", LangFlow 说自己"在 learned embedding 上走",ELF 说自己"在 contextual embedding 上走"——这些到底什么意思?把它讲透。
一个 token 表示成 vocab-size 长的向量,只有一位是 1,其他全是 0:
vocab V = 32128
token "bank" (id=4321) → one_hot 长度 V 的向量
= [0, 0, ..., 0, 1, 0, ..., 0]
↑ 第 4321 位
整个句子(L=64 个 token)→ 形状 [64, 32128],每行是一个 one-hot。
V 维空间里所有"概率分布"组成的集合:长度 V 的非负向量,且各分量之和 = 1。直观理解:
V=3 时的 simplex(直观图):
(1,0,0) ← token A 的 one-hot
/\
/ \
/ . \ ← 内部任意点 = 概率混合
/ .. . \
/________\
(0,1,0) (0,0,1)
token B token C
Flow matching 的核心是 zt = t·x + (1−t)·ε——从噪声 ε 走到数据 x。 关键问题是:x 长什么样、ε 长什么样、在哪个空间里走。
| 路线 | x 长什么样 | 走 flow 的空间 | 哪几篇 |
|---|---|---|---|
| One-hot / Simplex | V 维 one-hot(或 simplex 内的概率向量) | L × V(V ≈ 32k) | CFM, FLM/FMLM, DFM |
| Learned embedding | D 维 static 向量(embedding lookup) | L × D(D ≈ 128) | LangFlow, Diffusion-LM |
| Contextual embedding | D 维 context-aware 向量(encoder 输出) | L × D(D = 512) | ELF |
想象 V=32128 维空间里数据点 x 长什么样:
| 路线 | 数据 manifold 形状 |
|---|---|
| ELF(contextual) | T5 已组织好的"语义流形",相近词在附近 |
| LangFlow / Diffusion-LM(static) | 学一张 lookup 表,V 个固定点的离散点云 |
| CFM / FLM / DFM(one-hot/simplex) | V 个相互垂直的"角"(汉明距离 = 2,谁离谁都一样远) |
最关键的一点:one-hot 空间里,"bank" 和 "dog" 的欧氏距离 = "bank" 和 "shore" 的欧氏距离 = √2。没有任何语义几何——token 之间的相似性必须由 Transformer 自己从头学出来。
假设词表只有 4 个 token:river / bank / money / shore
river: [1, 0, 0, 0]
bank : [0, 1, 0, 0]
money: [0, 0, 1, 0]
shore: [0, 0, 0, 1]
所有 token 互相垂直,"bank-river" 跟 "bank-money" 完全等距。Transformer 必须从一堆垂直 one-hot 输入中,靠 attention 自己摸索出 "bank" 在 "river" 旁边时该往哪边走回到 §2.4 那个比喻:
这就解释了为什么 CFM / FLM / DFM 都需要更多 trick:
对照之下 ELF 用 MSE 是合法的——因为 contextual embedding 不是概率分布,是带几何结构的语义向量,欧氏距离反映语义距离。
ELF 是唯一不学 embedding 也不约束在 simplex 上的方案。这是它最独特的设计点。
ELF 选 "32-64 步直接采样" 这条路而不蒸馏,是反潮流的—— 当时所有 LM-FM 工作都在卷"少步"。
ELF 的 MSE 合法性建立在"embedding 不是分布"这个根本前提。这也是它和其它 4 篇的分水岭。
注意:LangFlow LM1B 30.0 / OWT 24.6 是 ODE NLL upper bound,不是 Gen-PPL。 和 ELF 的 24.08(Gen-PPL)不可直接横比。可比的是 LangFlow OWT Gen-PPL 36.5(1024 步)vs ELF 24.08(32 步)。
不对称的原因:CFM/FMLM/DFM 的状态在 simplex 或 one-hot 上,flow map 有 clean 的几何对象(mean denoiser)可学;ELF 的状态是 T5 embedding,要做 flow map distillation 需要先证明 T5 embedding 上的 flow map 也有 clean 形式——paper 没做。
ELF vs CFM:两者都把离散文本放到连续 FM 里,但CFM 的连续对象是 simplex-valued endpoint,ELF 的是 T5 contextual embedding。 CFM 押注"少步 self-distillation",单步 LM1B 274.87;ELF 押注"无蒸馏 32 步",OWT 24.08。 两条完全互补的路线——蒸馏路 vs 采样路。(点击此标题 → 滚到 §3.2 表 + 高亮 CFM 列)
ELF vs FLM/FMLM:两者都做 continuous flow LM,也都在 OWT 上用 Gen-PPL; 但 FLM 用 V 维 one-hot(vocab 越大状态越大),ELF 用 512 维 T5 embedding + 128 bottleneck。 FLM 的卖点是天然可蒸馏到 FMLM;ELF 的卖点是无蒸馏即少步。 直接对比:FLM 1024 步 OWT Gen-PPL 62.23 vs ELF 32 步 24.08 —— ELF 完胜,但代价是放弃了 vocab-level 状态空间的"天然性"。
ELF vs DFM:两者都认真做 ablation,但变量完全不同—— DFM 在变 PSD/ESD consistency 把 flow map 压到 few-step;ELF 在变 ODE/SDE/γ/SC-CFG 把 base flow 调到 32 步。 DFM 的强点是 1-4 NFE;ELF 的强点是无蒸馏 32 NFE。再次互补。
ELF vs LangFlow:两者都在 embedding-space 做 FM,也都报 OWT 数字;但 LangFlow 的 24.6 是 held-out NLL upper bound,ELF 的 24.08 是 Gen-PPL,不能直接比。 可比项:LangFlow OWT Gen-PPL 36.5 (1024 步) vs ELF 24.08 (32 步) —— ELF 在生成质量上明显优。 反过来,LangFlow 有 likelihood story(可信 NLL bound + 4/7 zero-shot benchmark 超过 AR),ELF 目前没有对应 likelihood claim。 两边互补。
这 5 篇本质在回答同一个问题:语言 FM 应该把"连续性"放在哪里?
Simplex(CFM/DFM)、one-hot(FLM)、learned embedding(LangFlow),还是 frozen contextual embedding(ELF)?
ELF 押注最后一种,并用无蒸馏 32 步 OWT Gen-PPL 24.08 证明这条路在 sample quality 上最有说服力。
本节先把 ELF 网络当成一个抽象函数:
x̂ = netθ(z, t, c, ω, mode)
U(·) 拿到 vocab logits(decode mode 时用)这一节不依赖具体架构——只要 netθ 是个能接受 (z, t, c, ω, mode) 的可学函数即可。具体 Transformer 实现(T5 encoder、DiT block、RoPE、bottleneck、factored decoder head 等)下一节 §5 展开。
ELF 用的是 x-prediction parameterization(App C.1 ablation 选出来的—— x-pred 在高维比 v-pred / ε-pred 稳定,因为"clean text 在低维流形上"):
# sampling_utils.py:115-127
def net_out_to_v_x(net_out, z, t, t_eps=5e-2):
x = net_out # ← 网络输出直接就是 x̂(clean embedding pred)
denom = torch.clamp(1.0 - t, min=t_eps)
v = (x - z) / denom # ← velocity 是后处理算出来的
return v, x
但这不代表跳过逐步去噪。推理仍然 32 步 Euler 迭代:
for step i in range(32):
x̂ = net_θ(z_t, t, c, ω, "denoise") # 每步预测 clean embedding
v = (x̂ - z_t) / (1 - t) # 由 x̂ 换算瞬时 velocity
z_t = z_t + dt · v # Euler 走一小步
# 最后一步:x̂ = net_θ(z_t≈x, t=1, c, ω, "decode") → unembed → token
三件事要分清:
| 真值 | |
|---|---|
| 网络直接输出什么 | clean embedding x̂(一直在预测最终目标 x0) |
| Flow Matching 框架下的 transport 量 | velocity v |
| 推理过程 | 仍然 32 步 Euler 迭代;每步用网络的 x̂ 算 v 再走一小步 |
为什么这么设计:
‖v_pred − v_target‖²,但 v_pred / v_target 都从 x_pred / x0 换算,梯度其实在监督 x_pred → x0和 consistency model / mean flow 的区别:那些工作目标是真的 1 步从 noise 直接跳到 x(学一个"时间无关的 x-prediction")。ELF 没走那条路——仍是 32 步,但每步的局部预测对象是 x 而非 v。FMLM 就是 FLM 的 consistency-distilled 1-步版本;ELF 的 future work 也提了这个方向。
论文 App B.1 把 Alg 3 / 4 写得像两条独立 pipeline。PyTorch port 实际是这样执行(按 train_step.py):
input_ids → x₀ ∈ [B, L, 512],(x₀−μ)/0.2t,加噪 z = t·x + (1−t)·ε·2.0p,z̃ = p·x + (1−p)·ε·5.0z_mixed[row] = decoder_z if decoder_step_active[row]==1 else denoiser_z.detach()关键认识:CFG target 的第二个 no-grad forward 在主 gradient forward 之后。 数学目标不变(用 v_target 监督 v_pred),但实现顺序不是"两个 no-grad 然后一个 grad"。
ELF 用两条不同的 loss 学两件不同的事。这是 ELF 整篇 paper 最容易被误解的点:
| Loss | 学什么 | 作用 | 训练频率 |
|---|---|---|---|
| LMSE(denoiser) | 在 连续 embedding 空间里 transport 噪声到 clean embedding | 学"动力学"——给定 noisy zt 和时间 t,怎么把 zt 沿 flow 推到 x。这是 Flow Matching 的核心。 | 80%(per-example Bernoulli(0.2) 抽 decoder 模式,剩下都是 denoiser) |
| LCE(decoder) | 在 离散 token 空间里把 clean embedding 投影回 vocab | 学"离散化"——把 flow 终点 (clean embedding ≈ x) 映射到 32128 个 token id。这是 ELF 唯一触及离散 token 的训练步骤。 | 20% |
为什么必须有 CE? 你既然懂 MSE,问题就清楚了: MSE 只能让模型预测 clean embedding,但下游任务要的是生成 token。 如果只有 MSE,推理时拿到 clean embedding 后没办法回到 token(找最近的 T5 embedding 找最近邻?精度差且不可微)。 所以必须额外训练一个 decoder head,把 embedding 映射回 vocab logits —— 这个 head 必须用 CE 训练(因为 vocab 是离散的)。 ELF 的优雅之处是:decoder head 和 denoiser 共享同一份 transformer 权重,只用 4 个 mode token 切换 mode。
所以 ELF 的训练目标本质是:
Ltotal = 𝔼(s, c) [ (1−pdecode) · LMSE(transport) + pdecode · LCE(round-to-token) ]
其中 pdecode = 0.2 是 decoder 分支抽样概率(论文 Tab 4 默认)。
Self-conditioning(来自 Chen, Zhang, Hinton, "Analog Bits", ICLR 2023,ELF 引用 [9])是 diffusion / flow matching 的一个推理时迭代精修技巧:
普通 diffusion 每一步 forward 只看 zt:x̂t = net(zt, t)。
Self-cond 让网络额外接收上一步的预测作为输入:x̂t = net(zt, t, x̂t−1)。
推理时这相当于"我已经有一个 partial 估计,refine 它",比从零预测更稳。
但训练时模型并没有"上一步预测" — 因此训练用以下 trick 模拟:
stopgrad(x̂no_sc),对这次反传梯度。
这样模型学到"给 x̂prev 作 input 时怎么 refine"。为什么要 stopgrad?因为不希望梯度通过第一次 forward 回流——
那会让训练目标变成"让 x̂no_sc 也参与优化",破坏 mathematical setup。
ELF 的 self-cond 完全沿用 Chen et al. 这个标准做法。
实现上 self-cond 把网络输入维度从 D 扩到 2D(concatenate [z, x̂prev]),
然后用一个 self_cond_proj: 2D → D 线性层压回 D。
所以你在 Alg 3 看到的 self_cond_proj(concat([z, ...], dim=-1)) 就是这个压回操作;不是 ELF 独创。
有 5 个理由,按重要性排序:
(1−1/ω)·(v_cond − v_uncond) 里:
stopgrad(x̂_no_sc)没 self-cond 就没 SC-CFG,没 SC-CFG 就没有"推理只跑 1 次 forward"的训练时 CFG 优化(见 §4.4)。 所以整套 Eq 3 训练时 CFG trick 都建立在 self-cond 之上。
Alg 3 里 self-cond 看起来占大半篇幅是因为记账复杂,不是概念复杂。剥掉 plumbing 后本质就是 "plain FM + Chen 2023 的 self-cond + Eq 3 的训练时 CFG"。下面这 5 个 step 都是 implementation 细节,不是 5 个独立 ELF 创新:
| Alg 3 里的步骤 | 它在干嘛 |
|---|---|
x_no_sc = net(concat([z, 0])) | no-cond reference,self-cond 输入填 0 |
stopgrad(x_no_sc) | 不让梯度回流第一次 forward(否则训练目标污染) |
x_sc = net(concat([z, stopgrad(x_no_sc)])) | self-cond reference forward |
(1−1/ω)·(v_sc − v_no_sc) | SC-CFG guidance 烤进 v_target(Eq 3) |
where(self_cond_mask, ..., ...) | 50% 概率二选一(让模型也学"没 prior" 的情况,覆盖推理第 1 步) |
一句话:self-cond 是 ELF 站在 Chen 2023 + DiT-style 训练时 CFG 这两个肩膀上的"工程接缝"。 拿掉 self-cond 也能跑 plain FM,但你会失去 (a) 32 步达到同质量的能力 (b) Eq 3 训练时 CFG 这两个 ELF 系统级卖点。
论文 App B 写成两个独立算法。PyTorch port 改成 per-example mix,数学上等价(见 §4.6)。两个算法的关键差异:
| Denoiser (Alg 3) | Decoder (Alg 4) | |
|---|---|---|
| Mode token gate | 0(无 mode 信号) | 1(mode token 激活) |
| 时间 t | per-sample t = σ(N(Pm=−1.5, Ps²=0.8²)) | t = 1(始终终点) |
| Corruption ratio | per-sample t | per-token p = σ(N(0.8, 0.8²))(独立!) |
| Noise scale | 2.0 | 5.0 (OWT) / 1.0 (XSum, WMT) |
| Self-cond input | 50% stopgrad(x'); 50% zeros | 始终 zeros(不学 self-cond) |
| Loss | MSE on velocity(base + CFG-augmented target) | CE per token |
| Output head | FinalLayer (768→512 flow output) | Factored decoder (768→512→32128) |
Algorithm 3 伪代码(denoiser,paper App B.1,去 LaTeX):
# Algorithm 3: ELF denoiser training with conditioning and guidance
# net(z, t, c, w, mode): ELF network with in-context conditioning
# self_cond_proj(z): concat-to-original-dim projection
# self_cond_prob: 0.5
# s: discrete token sequence
# c: condition (optional, only for XSum/WMT)
x = encode(s) # T5 frozen forward
t = sample_t() # logit-normal scalar per sample
w = sample_sc_cfg_scale() # ω ∈ [0.5, 5], paper: power-biased
# PyTorch port: shifted log-uniform
e = randn_like(x) # standard Gaussian(与 paper 一致)
z = t * x + (1 - t) * e # rectified-flow interpolant
v = x - e # base velocity target
# (1) 不带 self-cond 的 forward (no_grad)
z_no_sc = self_cond_proj(concat([z, zeros_like(z)], dim=-1))
x_no_sc = net(z_no_sc, t, c, w, mode="denoise")
v_no_sc = (x_no_sc - z) / (1 - t)
# (2) 带 self-cond 的 forward (no_grad, stopgrad on x_no_sc)
z_sc = self_cond_proj(concat([z, stopgrad(x_no_sc)], dim=-1))
x_sc = net(z_sc, t, c, w, mode="denoise")
v_sc = (x_sc - z) / (1 - t)
# (3) CFG target: post-combination quantity
v_target = v + (1 - 1/w) * (v_sc - v_no_sc)
# Per-example self-cond mask
self_cond_mask = uniform(B) < self_cond_prob
v_pred = where(self_cond_mask, v_sc, v_no_sc)
v_target = where(self_cond_mask, v_target, v)
v_target = stopgrad(v_target)
loss_denoise = mse_loss(v_pred, v_target)
看 Alg 3 最容易迷的就是 (3) 之后 where 和 stopgrad 那 5 行。它们做了 4 件事:
# (3) CFG target: post-combination quantity
v_target = v + (1 - 1/w) * (v_sc - v_no_sc) # 计算"带 CFG"的 target
# Per-example self-cond mask
self_cond_mask = uniform(B) < self_cond_prob # 每个 example 独立抽 Bernoulli(0.5)
v_pred = where(self_cond_mask, v_sc, v_no_sc) # gradient forward 走哪条
v_target = where(self_cond_mask, v_target, v) # target 选哪个
v_target = stopgrad(v_target) # 截梯度
关键 insight:这是同一个 batch 同时训练两种模式。 mask 把 batch 切两半:50% 走 self-cond + CFG 路径,50% 走 plain FM 路径。 这两条路径必须用同一个 mask挑 v_pred 和 v_target,否则配对错乱。
| mask 值 | 输入有 self-cond? | v_pred (gradient) | v_target | 网络学到什么 |
|---|---|---|---|---|
| True (50%) | 有,self-cond 输入 = stopgrad(x̂_no_sc) | v_sc(带 self-cond 的 forward) | (x − ε) + (1 − 1/ω)·(v_sc − v_no_sc) | "给 prior 估计 + ω,输出 CFG-amplified velocity" |
| False (50%) | 没有,self-cond 输入 = 0 | v_no_sc(不带 self-cond 的 forward) | v = (x − ε)(base velocity) | "没 prior,给 plain FM velocity"(推理第 1 步用得到) |
no_grad 上下文里算出来的(已经无梯度),v = (x − ε) 来自 leaf tensor x 和 ε 也无梯度。
所以理论上 v_target 本身就没有 gradient flow 进网络。
stopgrad 是防御性的——(a) 文档上明确"v_target 是 supervision,不可微";(b) 防止有人 fork 代码后去掉 no_grad 时还能保住语义。
这是好的工程习惯。就是 mse_loss(v_pred, v_target) 算 L2 距离:
mse_loss 展开是什么?就是 velocity 上的 L2 距离 + per-token mean,标准 Flow Matching 损失。具体计算:
# v_pred ∈ [B, L, 512],来自 ELF 主干 gradient-tracked forward 后转 velocity:
# x_pred = ELF_transformer(z_self_cond, t, c, ω, decode_mode=False)
# v_pred = (x_pred - z) / clamp(1 - t, t_eps) # paper Eq 4
# v_target ∈ [B, L, 512],由 Eq 3 训练时 CFG 构造(前面 4.4 节详):
# v_base = (x_0 - z) / clamp(1 - t, t_eps) = (x - ε·noise_scale)
# v_target = v_base + (1 - 1/ω) · (v_cond - v_uncond)
# v_target = v_target.detach() # 梯度只走 v_pred
l2_per_token = ((v_pred - v_target) ** 2).mean(dim=-1) # [B, L] channel-wise mean
l2_per_token *= loss_mask # 排除 padding + cond positions
数学上 ELF 的 MSE 等价于:
LMSE = 𝔼(x, c) 𝔼t, ε [ Σi ∈ valid ‖ vθ(zt, x'prev, t, c, ω)i − vtarget,i ‖22 / D ]
其中:
| 符号 | 含义 |
|---|---|
| x | frozen T5-small encode 后的 contextual embedding,归一化后(× 1/0.2 = ×5) |
| zt | noisy embedding:zt = t·x + (1−t)·ε·noise_scale,其中 noise_scale = 2.0 |
| t | per-sequence corruption 时间,t ~ σ(N(−1.5, 0.64)),logit-normal |
| ε | 标准高斯 ∈ ℝD,D = 512(T5-small encoder dim) |
| vθ | ELF 主干预测的速度场(实际预测 x_pred,再转 v = (x_pred − z) / (1−t)) |
| vtarget | 训练时 CFG 烤进的 target velocity(Eq 3,.detach() 截梯度) |
| x'prev | self-conditioning 输入(50% 概率 = stopgrad(uncond x_pred);50% = zeros) |
| ‖·‖22 / D | channel 维(512)取均值,等价于 per-channel MSE |
| valid 位置 | 排除 padding + 条件生成的 cond positions |
Algorithm 4 伪代码(decoder,paper App B.1):
# Algorithm 4: ELF decoder training with conditioning and guidance
x = encode(s)
p = sample_per_token_p() # logit-normal PER TOKEN (different from denoiser)
w = sample_sc_cfg_scale()
e = randn_like(x) # standard Gaussian(与 paper 一致)
z = p * x + (1 - p) * e # per-token corruption ratio
# decoder always uses zero self-cond input
z = self_cond_proj(concat([z, zeros_like(z)], dim=-1))
h = net(z, t=1, c, w, mode="decode") # mode token gate = 1
s_pred = unembed(h) # factored decoder head
loss_decode = ce_loss(s_pred, s)
ce_loss 展开是什么?就是标准的 per-token cross-entropy,没有任何 ELF-specific 改造。具体计算(来自 src/train_step.py):
# decoder_logits ∈ [B, L, 32128],由 ELF 主干 forward + factored decoder head 给出:
# hidden_768 = ELF_transformer(z̃, t=1, c, ω, decode_mode=True)[B, L, 768]
# hidden_512 = GELU_tanh(hidden_768 @ proj_kernel + proj_bias) # 768 → 512
# decoder_logits = hidden_512 @ unembed_kernel + unembed_bias # 512 → 32128
log_probs = F.log_softmax(decoder_logits.float(), dim=-1) # [B, L, 32128]
ce_per_token = -log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
# ⇔ -log p_θ(s_i | z̃_i, t=1, c, ω, mode=decode) # [B, L]
# 然后 mask 掉 padding 位置 + 条件生成的 cond 位置,再聚合
ce_per_token *= loss_mask # loss_mask = (1 - cond_seq_mask) * attention_mask
数学上 ELF 的 CE 等价于:
LCE = − 𝔼(s, c) 𝔼pi, εi [ Σi ∈ valid log pθ(si | z̃i, t=1, c, ω, mode=decode) ]
关键变量含义:
| 符号 | 含义 |
|---|---|
| si | 第 i 个位置的真实 token id(ground truth) |
| z̃i | 该位置的 per-token corrupted clean embedding: z̃i = pi·xi + (1−pi)·εi·noise_scale, 其中 pi ~ σ(N(0.8, 0.64)) 每个 token 独立抽,noise_scale = 5.0 (OWT) / 1.0 (cond) |
| t=1 | decoder 分支永远在终点;时间 token 编码 t=1 |
| c | 条件输入(XSum / WMT 才有;OWT 无条件 c=∅) |
| ω | SC-CFG scale ∈ [0.5, 5](虽然 decoder 分支不学 SC-CFG guidance,但 ω token 仍 prepend 作为输入) |
| mode=decode | 4 个 mode_token gate=1(denoiser 分支时 gate=0,token 被乘 0) |
| pθ(·) | factored decoder head 输出的 softmax 分布(768 → 512 GELU → 32128 → softmax) |
| valid 位置 | 排除 padding token + 条件生成的 cond positions(cond 不要求模型预测) |
两条分支用不同的时间/腐蚀分布。Paper App B.1 + 代码 src/configs/training_configs/train_owt_ELF-B.yml 默认:
| 分支 | 分布 | P_mean | P_std | Noise scale | 说明 |
|---|---|---|---|---|---|
| denoiser | per-sequence logit-normal | −1.5 | 0.8 | 2.0 | t = σ(N(−1.5, 0.64)) → 偏向小 t(噪声多) |
| decoder (OWT) | per-token logit-normal | 0.8 | 0.8 | 5.0 | p = σ(N(0.8, 0.64)) → 偏向大 p(接近 clean);每个 token 独立 |
| decoder (XSum/WMT) | per-token logit-normal | 0.8 | 0.8 | 1.0 | 条件生成用更小 noise |
核心思路:denoiser 训练时多接触噪声大的样本(学习如何 transport), decoder 训练时多接触接近 clean 的样本(学习如何 round 回 token)。 而且 decoder 是每个 token 独立抽 corruption ratio,模拟推理时不同位置的 reconstruction 质量。
ELF 用 Lipman et al. 2023 rectified-flow 标准插值公式:
z = p · x + (1 − p) · ε · noise_scale
| p 值 | z 长什么样 | 翻译 |
|---|---|---|
| p = 0 | z = ε · 5.0 = 全是噪声 | 完全 noisy |
| p = 0.5 | z = 0.5·x + 0.5·ε·5.0 | 一半信号一半噪声 |
| p = 1 | z = x = 完全 clean embedding | 完全 clean |
所以 p 是"信号比例",1−p 是"噪声比例"。p 越大,z 里 clean x 成分越多。
denoiser 的 t 同理:z_t = t·x + (1−t)·ε·2.0,t=1 是 clean 端点,t=0 是噪声端点。
关键原则:训练分布必须匹配推理分布。两条分支推理时见到的样子完全不同:
| denoiser | decoder | |
|---|---|---|
| 推理时跑在哪 | 整条 32 步 ODE 轨迹,所有 token 位置同步沿 t 推进 | 32 步完后跑一次 forward,把 zt=1 映射到 token |
| 推理时各 token 位置的 corruption 状态 | 同一时刻 t,所有位置共享同一个 corruption level | 不同位置 ODE 落地质量不均——有的 95% 干净,有的 70% 干净(attention 把噪声去除得不均衡) |
| 训练 sample 单位 | per-sequence t ~ logit-normal(−1.5) | per-token pi ~ logit-normal(+0.8) |
| shape | t: (B,) → broadcast 到 (B, 1, 1) |
p: (B, L, 1) 每个位置独立 |
| 动机 | 所有位置同 t 才能让 attention 学到正确的 transport 动力学 | per-token p 训练 decoder 在per-position 残余噪声不均时仍能 round 回 token |
Denoiser 不能 per-token:如果训练时同一序列内不同位置取不同 t,attention 看到"有的位置很 noisy 有的很 clean"——但推理时根本不会出现这种状态。模型会学到错误的 transport 模式。
Decoder 必须 per-token:推理时 32 步 ODE 落地的 zt=1 不是完美 clean,不同位置残余噪声幅度不同(attention 在 32 步里对各位置的去噪进度不同步)。 Per-token pi 训练 decoder看上下文判断该位置真正的 token——paper App B.1 原话: "encourages the shared-weight decoder mode to recover corrupted embeddings from their surrounding context, making final-step discretization more robust to imperfect embeddings produced by the denoiser at inference time."
per-token pi 只通过 z̃i 的"信号/噪声混合比例"隐式传递给网络——网络不直接知道某个位置的 pi 是多少。网络的时间输入是固定的 t=1 scalar,decoder 通过 attention 看周围上下文自动判断该位置可信度。
| denoiser | decoder | |
|---|---|---|
| P_mean | −1.5(偏小 t) | +0.8(偏大 p) |
| σ(P_mean) 中位 | ~0.18(多 noisy) | ~0.69(多 clean) |
| 大部分样本落在 | [0.05, 0.4] noisy 区 | [0.5, 0.95] clean 区 |
| Granularity | per-sequence(match transport 同步性) | per-token(match 终点 per-position 不均) |
| Noise scale | 2.0 | OWT 5.0 / 条件 1.0 |
# sampling_utils.py::add_noise (denoiser)
def add_noise(x0, noise, t, config, cond_seq_mask=None):
t_expanded = t.reshape(-1, 1, 1) # [B, 1, 1]
z = t_expanded * x0 + (1 - t_expanded) * noise * config.denoiser_noise_scale
if cond_seq_mask is not None: # 条件生成时保留 cond token 不腐蚀
z = cond_seq_mask * x0 + (1 - cond_seq_mask) * z
return z
# train_step.py — decoder branch (per-token p sampling)
decoder_z_vals = (
torch.randn((B * L,), dtype=dtype, device=device)
* config.decoder_p_std + config.decoder_p_mean # = N(0.8, 0.8²)
)
decoder_lambda_t = torch.sigmoid(decoder_z_vals).reshape(B, L, 1) # [B, L, 1] per-token
decoder_noise = torch.randn(x0.shape, dtype=dtype, device=device) * config.decoder_noise_scale
decoder_z = decoder_lambda_t * x0 + (1 - decoder_lambda_t) * decoder_noise
Classifier-Free Guidance (CFG)(Ho & Salimans, "Classifier-Free Diffusion Guidance", 2022) 是 image diffusion 的一个条件信号放大器:
训练时让一个网络同时学有条件 vθ(z, t, c) 和无条件 vθ(z, t, ∅)(10% 概率把 c 置空)。 推理时按系数 ω > 1 把两者外推:
vfinal = vuncond + ω · (vcond − vuncond)
直觉:朝"和无条件的方向"反方向多走一步,等价于更强地遵从条件 c。ω=1 是原始条件 forward;ω=3 时条件信号被显著放大(细节更锐利、和 prompt 更贴)。 代价:推理每步要跑 2 次 forward(cond 一次 + uncond 一次),算力 ×2。 这是 DALL·E / Stable Diffusion / Imagen 等图像 diffusion 的标配。
ELF OWT 默认 32 步 × 1 forward = 32 forward。如果套标准 CFG → 64 forward,推理算力翻倍。 更要命的是:ELF 把 sampling step 从 baseline 的 1024 步压缩到 32 步是它的核心卖点,再 ×2 就失去了步数效率优势。
这个 trick 是 Chen 等人 2025 年 image diffusion 的工作(ELF App C 引用)。核心想法:
不在推理时做 CFG 组合,而是在训练时就把 CFG 组合烤进 v_target。
让模型直接学 post-combination quantity v_θcfg,推理只需一次 forward 就拿到等价于 CFG 后的 velocity。
| 标准推理时 CFG | 训练时 CFG (Chen 2025 / ELF) | |
|---|---|---|
| 训练 forward 数 | 1(只学 v_cond / v_uncond 之一) | 3(uncond + cond + grad-tracked) |
| 推理 forward 数 / 步 | 2(cond + uncond) | 1 |
| 32 步总推理 forward | 64 | 32 |
| 训练 1.5× ↔ 推理 2× 谁划算 | 训练 5 epoch 1 次,推理跑无数次 — 训练时 CFG 完胜 | |
注意 ELF 的 c 在 Eq 3 里不是文本 condition,而是self-conditioning 输入(见前面 self-cond 那个 callout):
所以 ELF 实际上把两个独立的 CFG 机制分开处理:
| 机制 | 训练时还是推理时 | 放大什么 | OWT 用 | XSum/WMT 用 |
|---|---|---|---|---|
| SC-CFG(Eq 3) | 训练时(baked-in) | self-conditioning 信号 | ω=3 | ω=1(已 baked,不额外推理) |
| Input-cond CFG(标准) | 推理时 | 文本 condition 信号 | —(无条件) | ω=2(推理时 ×2 forward) |
为什么不把 input-cond CFG 也烤进训练? 因为 input-cond CFG 的 ω 通常需要在推理时 sweep 不同值(找最佳质量),烤进训练就锁死了。 SC-CFG 的 ω 在 ELF 里是固定行为(不用 sweep),所以烤进训练划算。 这是一个 nuanced 的工程权衡。
Eq 3 = Chen 2025 训练时 CFG 这个 image diffusion trick,移植到 ELF 的 self-conditioning 维度, 让 32 步无条件采样的算力优势在加 CFG 后仍然保住。剥掉这个 trick, ELF 要么没 SC-CFG(质量低),要么推理 32 步 ×2 forward(失去步数效率卖点)。
Image diffusion 的 classifier-free guidance 通常是推理时跑两遍 forward 然后线性组合:
推理时 CFG: vfinal = vuncond + ω · (vcond − vuncond)
ELF 把它烤进训练——模型直接学 post-combination quantity,推理只需一次 forward。Paper Eq 3:
vtarget = (x − ε) + (1 − 1/ω) · (vθcfg(zt | t, c, ω) − vθcfg(zt | t, ∅, ω))
这个系数不是 ad hoc 写的,是从标准 inference-time CFG 5 步代数推导出来的。Chen et al. ICML 2025 "Visual Generation without Guidance" 的核心贡献就是这个 reparameterization。
Step 1 · 标准 CFG 公式(Ho & Salimans 2022 推理时 CFG):
vfinal(ω) = vu + ω·(vc − vu) = vc + (ω − 1)(vc − vu)
vc, vu 是"不加 CFG"时网络对 cond / uncond 的 logical base 输出。
Step 2 · 训练时 CFG 的网络已经学 post-combination quantity
ELF 网络输出 vθcfg(c, ω) 直接就是 vfinal(推理不再组合)。所以:
两者之差已经被 ω 预放大:
vcfgc − vcfgu = [vu + ω(vc−vu)] − vu = ω·(vc − vu)
Step 3 · 反推 logical base 差
vc − vu = (1/ω)·(vcfgc − vcfgu)
Step 4 · 构造 training target
我们要让网络的 vcfgc 收敛到 vfinal。代入 Step 3:
vtarget = vc + (ω − 1)·(vc − vu)
= vc + (ω − 1)·(1/ω)·(vcfgc − vcfgu)
= vc + (1 − 1/ω)·(vcfgc − vcfgu)
所以 (1 − 1/ω) 等于 (ω − 1) × (1/ω) 化简的结果——前一项是标准 CFG"超越 cond 的放大量",后一项是把"已被 ω 预放大的差值"还原回 logical base。
Step 5 · FM target 替换 vc
base FM target 就是 vc = x − ε(rectified-flow 标准 target),代入得 Eq 3:
vtarget = (x − ε) + (1 − 1/ω)·(vcfgc − vcfgu)
| ω | (1 − 1/ω) | vtarget | 含义 |
|---|---|---|---|
| 1 | 0 | x − ε | 无 CFG,退化为 plain FM ✓ |
| 2 | 0.5 | (x−ε) + 0.5·(vcfgc − vcfgu) | 中等放大 |
| 3 | 0.667 | (x−ε) + 0.667·(...) | ELF 默认 SC-CFG=3 |
| 5 | 0.8 | (x−ε) + 0.8·(...) | SC-CFG 上限 |
| ∞ | 1 | (x−ε) + (vcfgc − vcfgu) | 极限放大 |
ω=1 退化 ✓、ω→∞ 系数趋于 1 ✓、ω=3 对应 ELF OWT 默认 ✓。
这里的 trick 是 self-cond 的"condition"不是输入文本 c,而是 self-cond 输入 x'。 所以"uncond"就是 x'=0,"cond"就是 x'=stopgrad(net1)。CFG scale ω∈[0.5, 5] 是 ELF 自己也学的输入参数(4 个 SC-CFG token 编码)。
实现需要 3 次 forward。代码实际执行顺序是:
梯度只过第 2 次。最终 L2 loss = ‖v_pred − stopgrad(v + (1−1/ω)(v_cond − v_uncond))‖²。
# src/train_step.py — Eq 3 的自条件 CFG target 构造(简化版)
def compute_shared_uncond(z, t_input, x_tokens):
# forward #1: self-cond input = zeros
z_uncond = restore_cond(torch.zeros_like(z), x_tokens, cond_seq_mask)
z_input_uncond = torch.cat([z, z_uncond], dim=-1)
with torch.no_grad(), autocast(bf16):
net_out_uncond = model(z_input_uncond, t_input,
self_cond_cfg_scale=self_cond_cfg_scale)
return net_out_uncond
def get_sc_cond_and_uncond(z, t_input, cond_mask, x_tokens, shared_net_out_uncond):
v_uncond, x_uncond = net_out_to_v_x(shared_net_out_uncond, z, t_input, t_eps)
x_uncond = restore_cond(x_uncond, x_tokens, cond_mask)
# forward #2: self-cond input = stopgrad(x_uncond)
z_input_cond = torch.cat([z, x_uncond], dim=-1) # 注意 stop-grad on x_uncond
with torch.no_grad(), autocast(bf16):
net_out_cond = model(z_input_cond, t_input,
self_cond_cfg_scale=self_cond_cfg_scale)
v_cond, _ = net_out_to_v_x(net_out_cond, z, t_input, t_eps)
return v_cond, v_uncond
def get_sc_guided_v(z, t_input, base_v_target, x_tokens, shared_net_out_uncond):
v_cond, v_uncond = get_sc_cond_and_uncond(...)
sc_w = self_cond_cfg_scale.reshape(B, 1, 1)
sc_guidance = (1 - 1 / sc_w) * (v_cond - v_uncond) # ← Eq 3 第二项
sc_guidance = torch.where(use_self_cond_mask.bool(),
sc_guidance,
torch.zeros_like(sc_guidance))
return (base_v_target + sc_guidance).detach() # ← .detach() 是关键
# forward #3 (gradient-tracked) 在主 batch forward 里:
net_out, decoder_logits = model(model_input, t_mixed,
self_cond_cfg_scale=self_cond_cfg_scale,
decoder_step_active=decoder_step_active)
v_pred, _ = net_out_to_v_x(net_out, denoiser_z, t, t_eps=0.05) # gradient-tracked
v_final_target = get_sc_guided_v(denoiser_z, t, base_v_target=v_target, ...)
l2_per_token = ((v_pred - v_final_target) ** 2).mean(dim=-1) # MSE
推理时 CFG 要每步跑两遍 forward(cond + uncond),32 步采样 = 64 次 forward。 ELF 把 SC-CFG 烤进训练后,推理只跑 1 次 forward / step,效率 ×2。代价:训练时 3 次 forward。但训练只跑一次(5 epochs),推理跑无数次。划算。
注意 ELF 还另外保留了输入条件的推理时 CFG(label drop 训练的)。两个机制独立:
SC-CFG 用 Eq 3 的 self_cond_cfg_scale(推理时通常 = 1,因为已烤入);
input-cond CFG 是标准推理时 cond+uncond 组合(XSum/WMT 默认 = 2)。
label drop 的训练信号只上游改变条件 embedding 状态,不进入 Eq 3 的 SC-CFG 公式。
Paper Algs 3+4 写成两个独立 training step,按 0.8 / 0.2 概率轮换。PyTorch port 改成:
# src/train_step.py — 关键混合 forward
# 每行独立 Bernoulli(0.2) → decoder mode;否则 → denoiser mode
decoder_step_active = torch.bernoulli(
torch.full((B,), config.decoder_prob, dtype=torch.float32),
generator=gen,
).to(device=device, dtype=dtype) # (B,) 1.0=decode 0.0=denoise
decoder_mask_B11 = decoder_step_active.view(-1, 1, 1)
decoder_mask_B1 = decoder_step_active.view(-1, 1)
# t、z 都按 per-example 混合
denoiser_t = t # logit-normal per-sample
decoder_t = torch.ones_like(t) # 永远 1
t_mixed = decoder_step_active * decoder_t + (1.0 - decoder_step_active) * t
z_mixed = decoder_mask_B11 * decoder_z + (1.0 - decoder_mask_B11) * denoiser_z
# 单次 forward — mode token gate 也是 per-example
net_out, decoder_logits = model(
model_input, t_mixed,
self_cond_cfg_scale=self_cond_cfg_scale,
decoder_step_active=decoder_step_active, # (B,) per-row gate
)
# CE / L2 各自用 mask 分流
loss_mask_f = loss_mask.to(ce_per_token.dtype)
ce_mask = loss_mask_f * decoder_mask_B1
l2_mask = loss_mask_f * (1.0 - decoder_mask_B1)
# 关键:单一分母归一化(不是两个分母)
total_sum = (ce_per_token * ce_mask).sum() + (l2_per_token * l2_mask).sum()
loss = total_sum / torch.clamp(loss_mask_f.sum(), min=1.0)
# loss_mask_f = ce_mask + l2_mask,所以这等价于 sum/(ce_mask.sum()+l2_mask.sum())
# 注意:pad_token=="pad" 时 loss_mask 屏蔽 padding;pad_token=="eos" 时全 1
论文 Algs 3+4 是两个独立 step,按 (1−p):p 比例轮换(p = decoder_prob = 0.2)。
PyTorch port 是同一个 step 内 per-example 抽 mode。先注意:
每个 example 内所有 token 共享同一个 mode 抽样结果(不是 token-wise i.i.d.)。
记 b ∈ (0, 1) 为 example 的 mode 指示(1 = decode),B0 = denoiser 行数,B1 = decoder 行数。
固定 batch、固定 loss denominator M = loss_mask.sum(),PyTorch port 的 loss:
Lcode = [ Σb=1 行 Σtoken CE + Σb=0 行 Σtoken L2 ] / M
对 mode 抽样取期望(每行 Bernoulli(p)):
𝔼[Lcode] = (p · 𝔼row[Σ CE | decode] + (1−p) · 𝔼row[Σ L2 | denoise]) · (B / M)
= 𝔼[Lpaper] · (per-example token 数加权和) — 等价于 paper Algs 3+4 的固定 batch-size 加权期望。
注意:
条件生成时还需要另一个 classifier-free guidance,针对 input condition(不是 self-cond)。 10% 概率把 cond sequence 直接 drop(zero out),让模型学到 p(x | ∅) 分布:
# src/train_step.py — label drop for input-condition CFG
if config.label_drop_prob > 0:
drop = label_drop_mask.to(dtype=torch.float32).reshape(-1, 1, 1) # (B, 1, 1) 0/1
cond_mask = cond_seq_mask # (B, S)
# block_mask: 1 仅在 (non-cond row, cond col)
# 目的:让 target token 看不到 cond token
block_mask = (1 - cond_mask).unsqueeze(-1) * cond_mask.unsqueeze(1)
encoder_attention_mask = encoder_attention_mask * (1 - drop * block_mask)
label drop 实际两阶段:
encoder_attention_mask,让 target token 看不到 cond token
(block_mask 只在 non-cond row × cond col 上为 1) — 这样 T5 encode 出来的 x₀ 本身就不含条件信息denoiser_z 和 x₀ 在 cond 位置清零
(torch.where(drop & cond_seq_mask, zeros, denoiser_z))— 匹配 paper "zeroing condition embeddings"| 类别 | 参数 | 默认值 |
|---|---|---|
| Optimizer & Schedule | Optimizer | Muon(2D 参数走 Newton-Schulz)+ Nesterov-AdamW(其余) |
| LR (peak) | 0.002(公式:blr=0.001 × global_batch / 256) | |
| LR schedule | constant after warmup | |
| Warmup | 0.5 epoch(5 epochs 总 ~95K steps,对应 ~9.5K warmup steps) | |
| Weight decay | 0(关闭 — Muon 自带 shape scaling) | |
| Grad clip | 1.0 (norm) | |
| Batching | Global batch size | 512 |
| Sequence length | 1024 | |
| Grad accumulation | 1 (硬件够大不用) | |
| Diffusion | Denoiser P_mean / P_std / noise scale | −1.5 / 0.8 / 2.0 |
| Decoder P_mean / P_std / noise scale | 0.8 / 0.8 / 5.0 (OWT) | |
| Decoder prob (per example) | 0.2 (denoiser 0.8) | |
| Self-cond prob | 0.5(denoiser only;decoder 永远 0) | |
| CFG | SC-CFG range | [0.5, 5],power-bias sample 偏小值 |
| SC-CFG tokens | 4 | |
| Label drop prob (cond only) | 0.1(XSum/WMT),0(OWT) | |
| Numerics | Precision | bf16 autocast;输出头强制 fp32 |
| EMA decay | 0.9999 | |
| Random seed | 42(per-rank seed + rank offset,让噪声 desync) | |
| 训练量 | OWT epochs | ELF-B: 5 ELF-M: 4 ELF-L: 3(大模型收敛快) |
| OWT 总 tokens | ≈ 45.2B(OWT 数据集约 9.04B × 5 ep) | |
| Hardware | TPU v5p × 64,1.5h/epoch(ELF-B) |
| Method | Base training | Distillation training | Effective tokens | Ratio vs ELF |
|---|---|---|---|---|
| MDLM | 512 × 1M × 1024 | — | 524.3B | 11.6× |
| Duo | 512 × 1M × 1024 | — | 524.3B | 11.6× |
| MDLM + SDTT | 512 × 1M × 1024 | 512 × 10K × 5 × 1024 | 550.5B | 12.2× |
| Duo + DCD | 512 × 1M × 1024 | 512 × 10K × 5 × 1024 | 550.5B | 12.2× |
| FLM | 512 × 1M × 1024 | — | 524.3B | 11.6× |
| FMLM | 512 × 1M × 1024 | 512 × 100K × 1024 | 576.7B | 12.8× |
| LangFlow | 512 × 1M × 1024 | — | 524.3B | 11.6× |
| ELF (ours) | 5 × 9.04B | — | 45.2B | 1.0× |
Baseline 估算公式:batch_size × n_steps × seq_length。ELF 用 OWT 总 token × epochs。
精确 ratio 是 11.6× / 12.2× / 12.8×,paper 文字简称"约 12×"。
注意:ELF 的"45.2B"不包括 T5-small 预训练(Google 用了 1T+ tokens 训 T5)—— 这是 paper 最容易被 attack 的地方。
Muon = "Momentum + Newton-Schulz orthogonalization",2024 Keller Jordan 推。核心思想:
nn.Linear.weight、bare nn.Parameter 矩阵),
先算 Nesterov-momentum 平均的梯度,再用 5 步 Newton-Schulz 把它正交化(近似 SVD ⟨U V⊤⟩)sqrt(max(1, fan_out/fan_in)) 形状缩放ELF 的 PyTorch port 用的是 PyPI muon-optimizer 包,加几层 wrapper / patches:
(a) Nesterov-bias-corrected Adam update(替换上游 adam_update);
(b) NS5 强制 fp32 + eps 1e-8(替换上游 bf16 + 1e-7);
(c) Muon update 重写 + shape scaling layout 修正(区分 nn.Linear 和 bare Parameter);
(d) _SafeMuonAuxAdam subclass:zero-fill missing grads + distributed all_gather padding 修复。
全部为了匹配 JAX optax.contrib.muon。
论文 App C.5 ablation:Muon vs AdamW,SDE 采样下差距尤其大(paper 只定性写"more pronounced under SDE",没给数字)。
latent_std=0.2 硬编码,不在 paper Tab 4 里,等于把 T5 raw embedding 放大 5× 后再加噪声。换数据集时需要重新估计,否则 SNR 漂移。γ=2.0, SC-CFG=3seed=42;paper Tab 6 是 6-seed 平均。当前 CLI 没有 --seeds 参数,需要多次运行用 --config_override seed=Nmuon-optimizer>=0.1.0 未 pin,多处 wrapper/patches,上游一改就风险静默漂移。建议 pin 死版本get_sc_guided_v 用了 3 forward 但只对第 2 forward 反传,其余 no_grad。如果 fork 改代码忘了 .detach(),梯度会污染 target 路径| Model | Depth | Hidden | Heads | head_dim | MLP ratio | Bottleneck | Params (DiT) | OWT Epochs |
|---|---|---|---|---|---|---|---|---|
| ELF-B | 12 | 768 | 12 | 64 | 4× | 128 | 105M | 5 |
| ELF-M | 24 | 1056 | 16 | 66 | 4× | 128 | 342M | 4 |
| ELF-L | 32 | 1280 | 16 | 80 | 4× | 128 | 652M | 3 |
三个模型共享同一份 frozen T5-small encoder (~35.3M 参数,encoder-only,不参与训练)。条件生成(XSum / WMT14)时 T5-small encoder 也用来编码 source context — 因此条件 ELF-B 报作 "105M+35M"。 Decoder 没有独立训练 — denoise 和 decode 用同一份 Transformer 权重,靠 model-mode token 切换。
t5-small)| 项 | 值 |
|---|---|
| Vocab size | 32128(SentencePiece vocabulary) |
| Layers | 6 (encoder only) |
| Hidden d_model | 512 |
| Attention heads | 8 (head_dim 64) |
| FFN d_ff | 2048 |
| Activation | ReLU (非 gated) |
| Params | ~35M |
| 训练时 | 整个 encoder 都 requires_grad_(False),作 frozen embedder |
训练 ELF 时永远只 forward 这个 encoder 一次(per batch),bf16 autocast 跑,last_hidden_state
shape [B, L, 512]。然后做归一化:(x − latent_mean) / latent_std,
默认 latent_std = 0.2 等于把 raw T5 embedding 放大 5×。
不在 vocabulary 里,是可学的 nn.Parameter + 输入相关 embedding 的加和:
| 类别 | 个数 | 编码 | 作用 |
|---|---|---|---|
t_emb_tokens | 4 | learnable_tokens[1,4,d] + TimestepEmbedder(t)[B,d].unsqueeze(1) | 把当前扩散时间 t∈[0,1] 注入 |
self_cond_cfg_tokens | 4 | learnable_tokens[1,4,d] + TimestepEmbedder(ω)[B,d].unsqueeze(1) | 把 self-cond CFG scale ω∈[0.5, 5] 注入(训练时随机抽,推理时固定) |
mode_tokens | 4 | learnable_tokens[1,4,d] × active_gate | denoise mode → gate=0(token 被乘 0);decode mode → gate=1(token 激活) |
三组合起来共 12 个 prefix tokens。最终序列顺序是
[time(4) + sc_cfg(4)] + [mode(4)] + [main_x(L)] ——
代码先 cat([mode, main]),再 prepend [time + sc_cfg](见 5.12 forward 第 4-5 步)。
重要细节:RoPE 在所有 prefix(12 个) 上 cos=1, sin=0(不旋转),
主序列从位置 0 开始正常 RoPE 编码。这样添加 prefix 不会破坏 main token 的相对位置。
T5 embedding 512-d → bottleneck → DiT hidden。直觉:clean 文本数据其实在 低维流形上。 论文 App C.2 sweep 了 bottleneck ∈ (32, 128, 512):
class BottleneckTextProj(nn.Module):
def __init__(self, text_encoder_dim, hidden_size, bottleneck_dim):
super().__init__()
self.proj1 = nn.Linear(text_encoder_dim, bottleneck_dim, bias=False) # 512 → 128
self.proj2 = nn.Linear(bottleneck_dim, hidden_size, bias=True) # 128 → 768
def forward(self, x): # [B, L, 512]
return self.proj2(self.proj1(x)) # [B, L, 768]
class TextRotaryEmbeddingFast(nn.Module):
def __init__(self, dim, pt_seq_len=512, ft_seq_len=None,
theta=10000.0, num_empty_token=0):
super().__init__()
ft_seq_len = ft_seq_len or pt_seq_len
# 标准 RoPE 频率
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[:dim//2].float() / dim))
# 位置缩放支持 fine-tune 时 seq_len 与 pretrain 不同
pos = torch.arange(ft_seq_len).float() / ft_seq_len * pt_seq_len
freqs_main = torch.einsum('..., f -> ... f', pos, freqs)
freqs_main = repeat(freqs_main, '... n -> ... (n r)', r=2)
# prefix 的 cos=1, sin=0 → 不旋转
if num_empty_token > 0:
cos_prefix = torch.ones((num_empty_token, freqs_main.shape[-1]))
sin_prefix = torch.zeros_like(cos_prefix)
freqs_cos = torch.cat([cos_prefix, torch.cos(freqs_main)], dim=0)
freqs_sin = torch.cat([sin_prefix, torch.sin(freqs_main)], dim=0)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
def forward(self, t): # [B, n_heads, L_total, head_dim]
cos = self.freqs_cos.to(t.dtype) # 显式 dtype cast 防 bf16 精度漂移
sin = self.freqs_sin.to(t.dtype)
return t * cos + rotate_half(t) * sin
class TimestepEmbedder(nn.Module):
# t (scalar in [0,1]) -> hidden vector (e.g., 768)
# Init: MLP weights normal(0.02), biases zero
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp_0 = nn.Linear(frequency_embedding_size, hidden_size, bias=True)
self.mlp_2 = nn.Linear(hidden_size, hidden_size, bias=True)
@staticmethod
def timestep_embedding(t, dim=256, max_period=10000):
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(half).float() / half)
args = t[:, None].float() * freqs[None]
return torch.cat([torch.cos(args), torch.sin(args)], dim=-1) # [B, 256]
def forward(self, t): # [B]
emb = self.mlp_0(self.timestep_embedding(t, 256)) # [B, 768]
return self.mlp_2(F.silu(emb)) # [B, 768]
class Attention(nn.Module):
def __init__(self, dim=768, num_heads=12, qkv_bias=True, qk_norm=True,
attn_drop=0.0, proj_drop=0.0):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads # 64 for ELF-B
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = RMSNorm(head_dim) if qk_norm else nn.Identity()
self.k_norm = RMSNorm(head_dim) if qk_norm else nn.Identity()
self.proj = nn.Linear(dim, dim, bias=True)
def forward(self, x, rope_fn, attention_mask=None):
B, N, C = x.shape # N = 1036 在 ELF-B 训练时
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
qkv = qkv.permute(2, 0, 3, 1, 4) # [3, B, n_heads, N, head_dim]
q, k, v = qkv[0], qkv[1], qkv[2]
q = self.q_norm(q) # qk-norm(论文与官方实现都开)
k = self.k_norm(k)
if rope_fn is not None: # 应用 RoPE(含 prefix-no-rotation)
q = rope_fn(q)
k = rope_fn(k)
# 实际 layers.py 里包了一层 wrapper:int/float mask -> bool mask 再传 SDPA
x = scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
return self.proj(x.permute(0, 2, 1, 3).reshape(B, N, C))
注意三件事:(1) qk_norm=True 是 ELF 默认开(提升 bf16 训练稳定性,从 Henry et al. EMNLP'20);
(2) 注意 F.scaled_dot_product_attention 内部用 PyTorch 2.x flash kernel;
源码 wrapper 在 mask 是 2D/3D 时 reshape 加 head 维并 cast 为 bool;
(3) ELF 训练/采样调用模型时都不传 attention_mask(cond+target 全双向 attention),
T5 encoder 那一侧才有 cond/target 不对称 mask。
class SwiGLUFFN(nn.Module):
def __init__(self, dim, hidden_dim, drop=0.0):
super().__init__()
# SwiGLU 标准做法:把 hidden 缩到 2/3 保持 param count
hidden_dim_eff = int(hidden_dim * 2 / 3) # 768*4 = 3072 → 2048
self.w12 = nn.Linear(dim, 2 * hidden_dim_eff, bias=True) # 768 → 4096
self.w3 = nn.Linear(hidden_dim_eff, dim, bias=True) # 2048 → 768
def forward(self, x): # [B, N, 768]
x1, x2 = self.w12(x).chunk(2, dim=-1) # 各 [B, N, 2048]
return self.w3(F.silu(x1) * x2) # [B, N, 768]
class FinalLayer(nn.Module):
# Last layer that maps hidden 768 -> embedding output 512.
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = RMSNorm(hidden_size)
# 关键: kernel + bias 都用 0 初始化(DiT 标配)
# → 开局时模型预测的 clean x_pred ≡ 0;
# velocity 由后处理 v=(x_pred - z)/(1 - t) 计算
self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels, bias=True)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, x): # [B, L, 768]
return self.linear(self.norm_final(x)) # [B, L, 512]
class ELFBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = Attention(hidden_size, num_heads, qkv_bias=True, qk_norm=True)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
self.mlp = SwiGLUFFN(hidden_size, int(hidden_size * mlp_ratio))
def forward(self, x, rope_fn, attention_mask=None):
x = x + self.attn(self.norm1(x), rope_fn, attention_mask)
x = x + self.mlp(self.norm2(x))
return x
注意是 Pre-Norm + 残差(RMSNorm 在 attn/mlp 之前)。这是 LLaMA / DiT / GPT-J 等长 bf16 训练的常见稳定选择。
class ELF(nn.Module):
# 源码默认 num_model_mode_tokens=0, vocab_size=0
# train.py 实际从 config / tokenizer 把 4 / 32128 传进来
def __init__(self, text_encoder_dim=512, max_length=1024,
hidden_size=768, depth=12, num_heads=12,
mlp_ratio=4.0, bottleneck_dim=128,
num_time_tokens=4, num_self_cond_cfg_tokens=4,
num_model_mode_tokens=4, vocab_size=32128):
super().__init__()
# Self-conditioning projection: [z; x_pred] (2×512) -> 512
self.self_cond_proj = nn.Linear(2 * text_encoder_dim, text_encoder_dim, bias=True)
# Bottleneck text projection (512 -> 128 -> 768)
self.text_proj = BottleneckTextProj(text_encoder_dim, hidden_size, bottleneck_dim)
# Prefix tokens + their (input-dependent) embedders
self.t_embedder = TimestepEmbedder(hidden_size)
self.t_emb_tokens = nn.Parameter(torch.empty(1, num_time_tokens, hidden_size))
self.self_cond_cfg_embedder = TimestepEmbedder(hidden_size)
self.self_cond_cfg_tokens = nn.Parameter(
torch.empty(1, num_self_cond_cfg_tokens, hidden_size))
self.mode_tokens = nn.Parameter(torch.empty(1, num_model_mode_tokens, hidden_size))
# RoPE with prefix no-rotation
head_dim = hidden_size // num_heads
prefix_total = num_time_tokens + num_self_cond_cfg_tokens + num_model_mode_tokens
self.feat_rope = TextRotaryEmbeddingFast(
dim=head_dim, pt_seq_len=max_length, num_empty_token=prefix_total)
# 12 ELFBlocks
self.blocks = nn.ModuleList([
ELFBlock(hidden_size, num_heads, mlp_ratio) for _ in range(depth)])
# Flow-matching output head (zero-init)
self.final_layer = FinalLayer(hidden_size, patch_size=1, out_channels=text_encoder_dim)
# Factored decoder unembedding: 768 -> 512 (GELU) -> vocab
self.proj_kernel = nn.Parameter(torch.empty(hidden_size, text_encoder_dim))
self.proj_bias = nn.Parameter(torch.empty(text_encoder_dim))
self.unembed_kernel = nn.Parameter(torch.empty(text_encoder_dim, vocab_size))
self.unembed_bias = nn.Parameter(torch.empty(vocab_size))
def forward(self, x, t, self_cond_cfg_scale=None, decoder_step_active=None):
# x: [B, L, 512] if no self-cond, OR [B, L, 1024] if self-cond cat([z, x_pred])
# t: [B] 扩散时间 ∈ [0,1]
# self_cond_cfg_scale: [B] or None ω 用 SC-CFG token 编码
# decoder_step_active: None | True/False | Tensor[B] 控制 model-mode token gate
B = x.shape[0]
# ====== 1. self-cond projection: 2C -> C ======
if x.shape[-1] == 2 * self.text_encoder_dim: # 训练 & 推理都走这里
x = self.self_cond_proj(x.float()) # [B, L, 1024] -> [B, L, 512]
# ====== 2. bottleneck text projection ======
x = self.text_proj(x.float()) # [B, L, 512] -> [B, L, 768]
# ====== 3. build prefix context tokens ======
time_emb = self.t_embedder(t) # [B, 768]
prefix = self.t_emb_tokens.expand(B, -1, -1) + time_emb.unsqueeze(1) # [B, 4, 768]
if self_cond_cfg_scale is not None:
sc_emb = self.self_cond_cfg_embedder(self_cond_cfg_scale) # [B, 768]
prefix2 = self.self_cond_cfg_tokens.expand(B, -1, -1) + sc_emb.unsqueeze(1)
prefix = torch.cat([prefix, prefix2], dim=1) # [B, 8, 768]
# ====== 4. model-mode tokens with per-example gating ======
if decoder_step_active is None:
gate = 0.0 # 默认 denoise: tokens 被乘 0
elif isinstance(decoder_step_active, torch.Tensor):
gate = decoder_step_active.view(-1, 1, 1) # [B, 1, 1] per-example
else:
gate = float(decoder_step_active) # 1.0 in final decode
mode_tokens = self.mode_tokens.expand(B, -1, -1) * gate # [B, 4, 768]
x = torch.cat([mode_tokens, x], dim=1) # [B, L+4, 768]
model_mode_offset = self.num_model_mode_tokens # = 4
# ====== 5. prepend (time + sc-cfg) tokens ======
# 最终顺序: [time(4), sc-cfg(4), mode(4), main(L)]
x = torch.cat([prefix, x], dim=1) # [B, L+12, 768]
prefix_len = prefix.shape[1] # = 8 (time + sc-cfg)
# ====== 6. 12 ELFBlocks with RoPE ======
for block in self.blocks:
x = block(x, rope_fn=self.feat_rope, attention_mask=None) # full bidi
# ====== 7. strip prefix (动态 = prefix_len + model_mode_offset) ======
x = x[:, prefix_len + model_mode_offset:] # [B, L, 768]
# ====== 8. flow-matching head (always computed) ======
flow_output = self.final_layer(x.float()) # [B, L, 512]
# ====== 9. decoder head (only if decoder_step_active) ======
decoder_logits = None
if decoder_step_active is not None:
x_f32 = x.float()
hidden = F.gelu(x_f32 @ self.proj_kernel + self.proj_bias, approximate="tanh")
decoder_logits = hidden @ self.unembed_kernel + self.unembed_bias # [B, L, 32128]
return flow_output, decoder_logits
| 步骤 | 张量 | shape | 说明 |
|---|---|---|---|
| 0 | input_ids | [512, 1024] | tokenized text,int64 |
| 1 | T5 encoder output | [512, 1024, 512] | frozen contextual embedding |
| 2 | normalized x₀ | [512, 1024, 512] | 除以 latent_std=0.2(相当于 ×5) |
| 3 | noisy z_t | [512, 1024, 512] | z = t·x₀ + (1−t)·ε·2.0;t = sigmoid(N(−1.5, 0.8²)) logit-normal |
| 4 | self-cond input | [512, 1024, 1024] | cat([z_t, x_pred]),channel-wise concat,2×512 |
| 5 | self_cond_proj output | [512, 1024, 512] | 线性投影回 512 |
| 6 | bottleneck proj1 | [512, 1024, 128] | 512 → 128(无 bias,强约束) |
| 7 | bottleneck proj2 | [512, 1024, 768] | 128 → 768 |
| 8 | time emb | [512, 768] | sinusoidal(256) → MLP(768) |
| 9 | time prefix tokens | [512, 4, 768] | learnable 加上 time_emb 广播 |
| 10 | sc-cfg prefix tokens | [512, 4, 768] | learnable 加上 ω embedding |
| 11 | mode tokens (gated) | [512, 4, 768] | per-example gate 0 (denoise) 或 1 (decode) |
| 12 | concat: mode + main | [512, 1028, 768] | 中间态:mode 暂时在最前 |
| 13 | concat: prefix + above | [512, 1036, 768] | 最终顺序 [time(4) + sc-cfg(4) + mode(4) + main(1024)] |
| 14 | RoPE 应用于 q,k | [512, 12, 1036, 64] | n_heads=12, head_dim=64;prefix 位置 cos=1, sin=0 |
| 15 | each ELFBlock output | [512, 1036, 768] | 共 12 个 block,shape 不变 |
| 16 | strip prefix | [512, 1024, 768] | 取后 1024 个 token |
| 17 | FinalLayer (flow head) | [512, 1024, 512] | RMSNorm + Linear 768→512,预测 clean x₀ |
| 18a | (decode only) proj | [512, 1024, 512] | x_f32 @ proj_kernel + proj_bias,再 GELU(tanh) |
| 18b | (decode only) logits | [512, 1024, 32128] | hidden @ unembed_kernel + unembed_bias |
| 组件 | 计算 | 参数 |
|---|---|---|
| self_cond_proj | 1024×512 + 512 | 524,800 |
| BottleneckTextProj.proj1 | 512×128 (no bias) | 65,536 |
| BottleneckTextProj.proj2 | 128×768 + 768 | 99,072 |
| TimestepEmbedder (×2: time, sc-cfg) | (256×768+768) + (768×768+768) ≈ 788K | ×2 = 1,575,936 |
| t_emb_tokens + sc_cfg_tokens + mode_tokens | 3 × (4×768) | 9,216 |
| 每个 ELFBlock | ~7.09M(见下) | — |
| RMSNorm × 2 | 2 × 768 | 1,536 |
| Attention.qkv | 768 × 2304 + 2304 | 1,771,776 |
| Attention.q_norm + k_norm | 2 × 64 | 128 |
| Attention.proj | 768 × 768 + 768 | 590,592 |
| SwiGLU.w12 | 768 × 4096 + 4096 | 3,149,824 |
| SwiGLU.w3 | 2048 × 768 + 768 | 1,573,632 |
| 子总(一个 block) | 7,087,488 | |
| 12 个 ELFBlock | 12 × 7.09M | 85,049,856 |
| FinalLayer (norm + linear) | 768 + 768×512 + 512 | 394,496 |
| proj_kernel + proj_bias (decoder) | 768×512 + 512 | 393,728 |
| unembed_kernel + unembed_bias | 512×32128 + 32128 | 16,481,664 |
| 合计(trainable) | 各行精确相加 | 104,594,304 (~105M, 假设 vocab=32128) |
| + T5-small encoder (frozen) | ~35M |
关键观察:decoder unembedding 占 16.5M(~16%),这是 32128 词表的主要开销。 而 12 个 transformer block 占 85M(~81%)—— 真正的核心。 self_cond_proj 只占 0.5%,但它是 self-conditioning trick 能 work 的关键 plumbing。
给定 (B, L) 和采样配置(method, n_steps, cfg, sc_cfg, γ):
# src/utils/generation_utils.py — _generate_samples_single_batch (简化)
@torch.no_grad()
def _generate_samples_single_batch(model, generator, z, t_steps,
cond_seq, cond_seq_mask,
config, sampling_config,
cfg_scale, self_cond_cfg_scale):
method = sampling_config.sampling_method # 'ode' or 'sde'
B, L, d = z.shape # d = 512
if cond_seq is None: # OWT 无条件生成
cond_seq = torch.zeros((B, L, d), dtype=z.dtype, device=z.device)
cond_seq_mask = torch.zeros((B, L), dtype=z.dtype, device=z.device)
# 在条件位置上把 z 设回 clean cond_seq(cond 不去噪)
z = restore_cond(z, cond_seq, cond_seq_mask)
x_pred = restore_cond(torch.zeros_like(z), cond_seq, cond_seq_mask)
n = t_steps.shape[0] # = n_steps + 1
sde_gamma = getattr(sampling_config, "sde_gamma", 0.0)
use_bf16 = config.use_bf16 and z.is_cuda
with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=use_bf16):
# n_steps - 2 个中间步用 ODE/SDE
for i in range(n - 2):
t = t_steps[i].item()
t_next = t_steps[i + 1].item()
if method == "sde":
z, x_pred = _sde_step(z, t, t_next, x_pred, gamma=sde_gamma, ...)
else: # 'ode'
z, x_pred = _ode_step(z, t, t_next, x_pred, ...)
# 最后一步强制 ODE — t 接近 1 不再注入新噪声
z, x_pred = _ode_step(z, t_steps[-2], t_steps[-1], x_pred, ...)
return z # [B, L, 512]
@torch.no_grad()
def _dlm_decode_batch(z, model, t_final_val, config, self_cond_cfg_scale):
# 最终一步把 latent z (≈ clean embedding) 映射回 token IDs
B = z.shape[0]
t_final = torch.full((B,), float(t_final_val), dtype=z.dtype, device=z.device)
sc_batch = (torch.full((B,), float(self_cond_cfg_scale), dtype=z.dtype, device=z.device)
if config.num_self_cond_cfg_tokens > 0 else None)
# 推理时 self-cond 输入永远 = zeros(与训练 decoder 分支一致)
z_input = torch.cat([z, torch.zeros_like(z)], dim=-1) if config.self_cond_prob > 0 else z
with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=config.use_bf16):
_, decoder_logits = model(
z_input, t_final,
self_cond_cfg_scale=sc_batch,
decoder_step_active=True, # ← mode token gate = 1
)
return decoder_logits.argmax(dim=-1) # [B, L]
关键点:主循环 + 最终 decode 是两次独立 forward。前者把噪声 ε transport 到接近 clean x; 后者把 clean embedding 投影到 vocab。
def _ode_step(model, z, t, t_next, x_pred_prev,
config, cfg_scale, self_cond_cfg_scale,
cond_seq, cond_seq_mask):
t_batch = torch.full((z.shape[0],), float(t), dtype=z.dtype, device=z.device)
v_pred, x_pred = _forward_sample(
model=model, z=z, t_batch=t_batch, x_pred_prev=x_pred_prev,
config=config, cfg_scale=cfg_scale, self_cond_cfg_scale=self_cond_cfg_scale,
cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
)
return z + (t_next - t) * v_pred, x_pred # Euler: z_{i+1} = z_i + dt · v
def _sde_step(model, z, t, t_next, x_pred_prev,
config, cfg_scale, self_cond_cfg_scale,
cond_seq, cond_seq_mask, gamma, generator):
h = float(t_next - t)
alpha = max(0.0, min(1.0, 1.0 - gamma * h)) # 信号保留比例 ∈ [0,1]
t_back = alpha * float(t) # 时间往回拉到 α·t
eps = torch.randn(z.shape, dtype=z.dtype, device=z.device) * config.denoiser_noise_scale
z_back = restore_cond(alpha * z + (1.0 - alpha) * eps, cond_seq, cond_seq_mask)
t_batch = torch.full((z.shape[0],), t_back, dtype=z.dtype, device=z.device)
v_pred, x_pred = _forward_sample(
model=model, z=z_back, t_batch=t_batch, x_pred_prev=x_pred_prev,
config=config, cfg_scale=cfg_scale, self_cond_cfg_scale=self_cond_cfg_scale,
cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
)
# 从 backtracked state 用 Euler 一步推到 t_next
return z_back + (t_next - t_back) * v_pred, x_pred
三个直觉:
γ = 0: alpha = 1, t_back = t, z_back = z — 退化为 ODE Eulerγ > 0: alpha < 1,z 被 α 缩小并注入新噪声 (1-α)·ε — 相当于把"已经去噪一些"的状态拉回到更早时刻,再 denoise 一次# 函数签名默认参数是 P_mean=-0.8(旧默认)
# 但 caller 在 generation.py 中传入 config.denoiser_p_mean = -1.5(OWT 训练值)
def get_sampling_steps(n_steps, time_schedule="logit_normal",
P_mean=-0.8, P_std=0.8, device=None, dtype=torch.float32):
if time_schedule == "uniform":
return torch.linspace(0.0, 1.0, n_steps + 1, dtype=dtype, device=device)
# logit-normal:从训练同分布抽 n_steps - 1 个点
z = torch.randn((n_steps - 1,), dtype=dtype, device=device) * P_std + P_mean
steps = torch.sigmoid(z)
steps = torch.sort(steps).values # 升序排列
# 强制端点 0 / 1
lo = torch.zeros((1,), dtype=dtype, device=steps.device)
hi = torch.ones((1,), dtype=dtype, device=steps.device)
return torch.cat([lo, steps, hi], dim=0) # [n_steps + 1]
论文 App B.2:"smaller intervals when t is close to 0 and larger intervals as t approaches 1"。 P_mean=−1.5 使 sigmoid(N(−1.5, 0.64)) 的密度向 t 较小的 noisy 区集中, 所以中间 sample 点偏小 → 排序后noisy 区间隔细密、clean 区间隔粗,与训练 t 同分布匹配。
真实的每步 forward要处理两层 CFG:
self_cond_cfg_scale token 编码 ω# sampling_utils.py — 双层 CFG 嵌套
def _forward_sample(model, z, t_batch, x_pred_prev, config,
cfg_scale, self_cond_cfg_scale, cond_seq, cond_seq_mask):
# 内层:带 cond 的 forward(条件 token + SC-CFG)
v_cond, x_cond = _forward_sample_self_cond(
model, z, t_batch, x_pred_prev, config,
self_cond_cfg_scale=self_cond_cfg_scale,
cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
)
if cfg_scale == 1.0:
return v_cond, x_cond # 无 input-cond CFG,直接返回
# 外层:input-cond CFG → 再跑一次 uncond
z_uncond = restore_cond(z, torch.zeros_like(z), cond_seq_mask)
x_pred_prev_uncond = (None if x_pred_prev is None else
restore_cond(x_pred_prev, torch.zeros_like(x_pred_prev), cond_seq_mask))
v_uncond, x_uncond = _forward_sample_self_cond(
model, z_uncond, t_batch, x_pred_prev_uncond, config,
self_cond_cfg_scale=self_cond_cfg_scale,
cond_seq=torch.zeros_like(cond_seq), cond_seq_mask=cond_seq_mask,
)
# 标准 CFG 线性组合
v_out = v_uncond + cfg_scale * (v_cond - v_uncond)
x_out = x_uncond + cfg_scale * (x_cond - x_uncond)
return restore_vx(v_out, x_out, cond_seq, cond_seq_mask)
每步代价(worst case,cond 生成):2 次模型 forward(cond + uncond)。
OWT 无条件生成 cfg_scale=1,每步只1 次 forward。SC-CFG 被烤进训练所以不用 ×2。
| 方法 | 每步 forward 数 | 32 步总 forward | 说明 |
|---|---|---|---|
| 传统 inference-time CFG 方法(启用时) | 2 (cond + uncond) | — | 常用于 MDLM/LLaDA/Duo 等启用 CFG 的配置 |
| ELF OWT 无条件 (SC-CFG baked-in, cfg=1) | 1 | 32 + 1 decode | SC-CFG 单次 forward 已含 ω 信息 |
| ELF 条件 (input-cond CFG=2) | 2 (cond + uncond) | 64 + 1 decode | SC-CFG=1 不额外 ×2,input-cond CFG 才 ×2 |
论文 page 26 Table 6(6 个 evaluation seed 平均 ± SE):
| Steps | SC-CFG | γ | Gen-PPL ↓ | Entropy ↑ |
|---|---|---|---|---|
| 8 | 3 | 2.0 | 67.32 ± 2.25 | 5.14 ± 0.085 |
| 16 | 3 | 2.0 | 33.66 ± 1.09 | 5.16 ± 0.026 |
| 32 | 3 | 1.5 | 24.08 ± 0.16 | 5.15 ± 0.002 |
三个观察:
论文 page 26 Table 7。三个 size × 两种采样器 × CFG sweep。SDE 全方位优于 ODE。 表内灰色项是 entropy < 5.0(多样性不足,不算 valid):
| Sampler | SC-CFG | ELF-B 105M (PPL/Ent) | ELF-M 342M (PPL/Ent) | ELF-L 652M (PPL/Ent) |
|---|---|---|---|---|
| SDE (γ=1.0) | 0.5 | 36.77 / 5.28 | 39.21 / 5.35 | 37.50 / 5.41 |
| 1.0 | 29.50 / 5.23 | 33.45 / 5.30 | 31.82 / 5.37 | |
| 1.5 | 25.25 / 5.18 | 28.42 / 5.26 | 28.72 / 5.35 | |
| 2.0 | 22.53 / 5.14 | 25.34 / 5.23 | 26.47 / 5.32 | |
| 3.0 | 19.72 / 5.10 | 21.69 / 5.18 | 23.31 / 5.28 | |
| 3.5 | 37.56 / 5.30 ⁱ | 36.48 / 5.34 ⁱ | 22.28 / 5.27 | |
| 4.0 | 36.50 / 5.29 ⁱ | 34.93 / 5.33 ⁱ | 21.37 / 5.26 | |
| ODE | 0.5 | 104.29 / 5.51 | 88.51 / 5.51 | 68.27 / 5.52 |
| 1.0 | 65.30 / 5.40 | 62.47 / 5.44 | 49.72 / 5.45 | |
| 1.5 | 44.85 / 5.31 | 46.71 / 5.37 | 39.97 / 5.40 | |
| 2.0 | 34.65 / 5.23 | 37.66 / 5.32 | 33.72 / 5.36 | |
| 3.0 | 26.62 / 5.15 | 28.80 / 5.24 | 26.57 / 5.29 |
ⁱ = 论文表中标灰的 cell:CFG > 3 后 ELF-B / ELF-M 出现 PPL 反转/上升, 不是 entropy < 5.0;它们 entropy 仍 ≥5.29。论文 App C 同时用 entropy < 5 或 PPL > 300 作"poor generation"红区。
关键观察(数字均来自 Tab 7 valid 区间内):
src/configs/sampling_configs/cond_sampling_configs.yml:
- sampling_method: ode
num_sampling_steps: [64]
cfgs: [2] # input-cond CFG = 2
self_cond_cfg_scales: [1] # SC-CFG = 1(已烤进训练)
time_schedule: logit_normal
条件生成默认使用 ODE(paper/config 默认;论文未明确解释,可理解为条件任务在标准 CFG=2 下已经稳定,不需要 SDE 额外随机性)。 SC-CFG=1 因为推理时不需要额外 ω 调整;input-cond CFG=2 是标准 image diffusion 推理时 CFG。
README 承诺:PyTorch port 在 8× L40S / H200 上跑应与 paper TPU v5p-64 数字漂移 ≲ 1 Gen-PPL 或 < 0.5 BLEU/ROUGE。漂移来源:
optax.contrib.muonREADME 强调用 use_bf16=true(匹配训练 precision)和 use_compile=true(torch.compile ~3-4× speedup)作为推荐 eval flag。
论文 Alg 6 写 α = 1 − γ·dt。但 dt 不是 1/N(uniform 间隔),是 logit-normal 抽出来的——
不同步的 dt 跨度差很大(t 靠近 0 时 dt 可能 0.01,靠近 1 时可能 0.4)。
所以同一个 γ 在不同 step 里实际"重置强度"完全不同。代码用 clip(α, 0, 1),仅在 γ·dt ≥ 1 时才把 α clip 到 0:
比如 32-step 默认 γ=1.5,最大 dt 通常 ~0.4,γ·dt = 0.6 不会 clip;
但 8-step γ=2.0 时如果 dt ≥ 0.5 就会 clip。
这是 ELF 实现的关键细节,建议念 paper 时跟代码对照(src/utils/sampling_utils.py:226-251)。
Tab 7 各 size 内自身最低 valid PPL(valid = 落在 entropy 合理区): ELF-B SDE CFG=3 → 19.72;ELF-M SDE CFG=3 → 21.69;ELF-L SDE CFG=4 → 21.37。 绝对最低是 ELF-B 的 19.72;ELF-L 21.37 比 ELF-M 21.69 更低。 但 paper 强调 frontier 而非单点最低:scaling 整体向更优方向推进(同 entropy 下更低 PPL,同 PPL 下更高 entropy)。
ELF-B 32 步达到 Gen-PPL ≈ 24。从 Fig 7(a) 视觉读数:ELF-B 32-step 已接近或优于若干 baseline 在高 step(如 1024)下的水平, 推理时间 substantially less than prior methods(paper §4.2 原文)。
三种蒸馏过的 few-step variant:
这些都需要额外蒸馏阶段(10K-100K extra steps),但 32 步 PPL 还是不如未蒸馏的 ELF-B。 即在这套系统配置下,ELF 不加额外 distillation 仍然超过这些 distilled baselines("架构层面的优势"是我对这个现象的解读,非论文逐字 claim)。
详见 4.9 节 Tab 5:ELF 45.2B 总 token,所有 baseline 都在 524-577B 区间, 其中蒸馏 variant 因为还要加蒸馏 epoch,token 更多。精确 ratio 11.6× / 12.2× / 12.8×(paper 简称 ~12×)。 Tab 5 只统计 ELF/baseline 自身训练与蒸馏 token,不包含外部 encoder 预训练成本。
| Model | Size | De-En BLEU ↑ | XSum R-1 ↑ | R-2 ↑ | R-L ↑ |
|---|---|---|---|---|---|
| AR (Transformer) | 99M | 25.2 | 30.5 ± 0.13 | 10.2 ± 0.11 | 24.4 ± 0.12 |
| MDLM | 99M | 18.4 | 33.4 ± 0.11 | 11.6 ± 0.10 | 25.8 ± 0.10 |
| Duo | 170M+35M | 21.3 ‡ | 31.4 ± 0.12 | 10.1 ± 0.10 | 25.0 ± 0.12 |
| E2D2 | 99M | 24.8 | 28.4 ± 0.11 | 8.3 ± 0.09 | 22.0 ± 0.10 |
| SeqDiffuSeq | — | 21.3 | 19.3 † | 1.7 † | 14.1 † |
| CDCD | — | 24.9 | — | — | — |
| ELF-B (ours) | 105M+35M | 26.4 | 36.0 ± 0.13 | 12.2 ± 0.11 | 27.8 ± 0.12 |
† = 直接取自该方法原 paper(De-En 数据的默认来源); ‡ = ELF 团队用公开 codebase 重跑(XSum 数据的默认来源);Duo De-En 在 ELF 团队的对比里也是 ‡(重跑)。
重要观察:
论文 App C (pages 20-23) 系统 ablate 7 个设计选择。这是"为什么 ELF 能 work"的实证支撑:
| Ablation | Sweep | 结论 / 默认 | 差距 |
|---|---|---|---|
| C.1 Prediction target | x-pred / v-pred / ε-pred | x-pred 全 dim 稳定 | ε-pred 全 dim 都 collapse(512/768/1024);v-pred 在 512 dim ok,越高越差 |
| C.2 Bottleneck dim | 32 / 128 / 512 | 128 最佳 frontier | 32 偏低 entropy;512 偏高 PPL;128 balance |
| C.3 Denoiser mode prob | 0.2 / 0.5 / 0.8 | 0.8 (denoise) / 0.2 (decode) | 0.5 / 0.2 (denoise) PPL/entropy frontier 都明显劣化 |
| C.4 Conditioning style | in-context tokens / adaLN-Zero | in-context 略优 + 省 43M 参数 | 性能 ≈ adaLN,但 ELF-B 148M → 105M |
| C.5 Optimizer | Muon / AdamW | Muon 全面优于 AdamW | SDE 下差距最显著(paper 定性 "more pronounced under SDE") |
| C.6 Sampler + time grid | ODE / SDE × uniform / logit-normal | SDE + logit-normal | SDE 通常降低 PPL(幅度随 model/CFG 变化,CFG=2 时 21-35%);logit-normal 在各 step 都更优,few-step 时尤其 |
| C.7 Cond CFG scale | 1 / 2 / 3 / 4 | CFG=2 最佳 | 1→2 substantially improves;3、4 逐步下降,过强 guidance 反而 degrade |
三种 prediction 是数学等价的,但训练 signal 完全不同。论文用三个 encoder size (T5-small/base/large = 512/768/1024 dim) sweep:
解释:clean 文本数据在 embedding 空间是低维流形。x-pred 预测的就是这个流形上的点; ε-pred 预测的是高维 Gaussian,模型必须学一个全维度等熵分布——更难。 这条 finding 支持"continuous DLM 的关键不是连续,是 x-prediction" 的 framing。
很反直觉:如果 decoder 占比上升到 0.5,按理说 decoder 应该学得更好——但实际整体 frontier 都退化。 解释:decoder 共享 transformer 主干。如果训练时频繁切换 mode,主干被两个目标拉扯;只占 20% 时 decoder 学到的是 "在已经 transport 到 clean 的 embedding 上做最后一步映射"——比例小但效果反而好。
Logit-normal time grid 让 noisy 区间更密——8-step 时这极其重要。
Uniform 8 步 PPL 远高于 logit-normal 8 步。32-step 之后两者差距收窄。
γ sweep(paper 选默认值):8/16 步默认 γ=2.0;32 步 γ=1.5;64 步 γ=1.0。
论文文字说 γ 控制 PPL/entropy trade-off,paper 默认 γ=1.0 作为各 step budget 的 balance;
8/16 步用更大 γ=2.0 是因为粗步长需要更多 stochasticity 修正噪声累积。
| 表 | 覆盖 | 最佳 valid Gen-PPL |
|---|---|---|
| Tab 6 (system-level, ELF-B, 6 seeds) | 8/16/32 step | 32-step SDE γ=1.5 SC-CFG=3: 24.08 ± 0.16 |
| Tab 7 (scaling, 64-step) | B/M/L × ODE/SDE × CFG 0.5-4 | 各 size 内最低 valid PPL:ELF-B 19.72 (CFG=3) / ELF-M 21.69 (CFG=3) / ELF-L 21.37 (CFG=4)。CFG>3 部分 cell 标灰 |
use_bf16 + use_compile这张图回答了"continuous DLM 到底在做什么"——它在 embedding 空间里描出一条平滑轨迹, 最后一步才把轨迹终点投影到 token 词表。和离散 DLM 每步都做 vocab argmax 完全是两套范式。
Cola-DLM(ByteDance Seed, arXiv 2605.06548, 2026 年 5 月) 是和 ELF 几乎同时(2 周内)冒出来的同类工作。两边都是 continuous DLM,但设计哲学几乎相反: ELF 求简,Cola 求强。下面是我从 Cola 的公开 blog 和 arXiv 摘要整理的对比。
ELF = "把 encoder 冻结,diffusion 只学 transport"——最小架构、最高数据效率,刷 OWT Gen-PPL。
Cola-DLM = "diffusion 不应该恢复 noisy token observation,应该建模 semantic latent prior"—— 两阶段训练(VAE pre-train → joint VAE+DiT)+ block-wise 推理,~2B 参数,刷 reasoning task average。
| 维度 | ELF (MIT) | Cola-DLM (ByteDance) |
|---|---|---|
| 核心对象 | Contextual embedding 上 Flow Matching | Text VAE latent 上 block-causal FM |
| Encoder | Frozen T5-small (35M) | Learnable Text VAE (~500M) |
| Latent space | Token-aligned, 512-d (bottleneck 128) | Explicit z, d=16 (默认) |
| Diffusion 目标 | 恢复 clean contextual embedding | 把噪声 transport 到 learned latent prior |
| Decoder | 共享 Transformer (final-step 切换 mode) | 独立 VAE decoder + KV cache |
| Backbone | DiT,全双向 attention | Block-causal DiT (intra-block 双向、inter-block 因果) |
| 训练 | 单阶段 80/20 mix | 两阶段训练 (VAE pretrain → joint VAE+DiT) + block-wise 推理 |
| 损失 | 80% MSE + 20% CE | Stage 2: λVAE·LVAE + λFM·LFM + λref·KL(q‖qref) |
| 参数 | 105M / 342M / 652M | ~2B 总 (1.8B DiT + 500M VAE) |
| 采样 | 32-64 步 ODE/SDE Euler | 8-16 步 / block, block-causal + KV cache |
| 评测 | Gen-PPL, BLEU, ROUGE | Task Avg (LAMBADA/MMLU/SIQA/RACE/...) |
| 对标 | MDLM/Duo/FLM/LangFlow | AR + LLaDA at 2B |
| 维度 | ELF | Cola |
|---|---|---|
| Diffusion 建模对象 | Token-aligned contextual embedding(T5 encoder output 上每个 position 一个 512-d vector) | 压缩后的 semantic latent z(VAE 编码出来的较低维向量) |
是否有显式 p(x)=∫p(x|z)p(z)dz | 无(直接对 contextual embedding 做 FM,最后一步 decode 到 token) | 有(VAE 提供 p(x|z),diffusion 学 p(z)) |
| Decoder | 共享 Transformer 主干,mode token 切换 + factored linear head | 独立 VAE decoder |
| 训练时中间步是否做 token-space loss | 不做(中间全是 MSE on embedding;只在最后步混 20% CE) | 不做(diffusion 在 latent 空间,CE 由 VAE 在两阶段训练里分担) |
| Attention 模式 | 非因果 / 全局 attention | Block-causal(block 内非因果,block 间因果,可 KV-cache) |
| 需要调的"杠杆" | 1 套 logit-normal schedule + denoise/decode 比例 + bottleneck dim | VAE 质量 + latent dim + logSNR + block size + anti-drift KL + CFG + 评测协议 |
p(x)=∫p(x|z)p(z)dz,diffusion 建 pψ(z)、VAE decoder 建 pθ(x|z)Cola 不像 ELF 那样单阶段端到端训练。两个训练阶段 + 独立的推理流程:
| 阶段 | 训练对象 | 损失 | 目标 |
|---|---|---|---|
| Stage 1 — VAE pretraining | Text VAE encoder + decoder | LVAE = −𝔼[log pθ(x|z0)] + β · KL(qφ‖pbase) + λmask·Lmask (带 BERT-style masking loss) |
学一个稳定的 text↔latent 映射,避免 semantic collapse / decoder shortcutting |
| Stage 2 — Joint VAE + DiT | VAE + block-causal DiT | Lstage2 = λVAE·LVAE-like + λfm·LFM + λref·KL(qφ‖qφ_ref) (reference KL 抑制 latent drift) |
DiT 在已稳定的 latent 上学 flow matching prior;reference KL 不让 VAE 漂移 |
| Inference (非训练阶段) | — | — | Block-wise prior transport(DiT 生成 latent block)+ VAE decoder(latent → tokens, KV-cached) |
关键超参(来自 blog RQ4 sweet spot 表,released checkpoint 配置):
Cola DiT 在序列维度上做块因果(block-causal)分解:
这种设计的好处:
Block size sweep(来自他们 RQ2/RQ3):
Cola 不报 Gen-PPL。他们用"generative few-shot"协议把多选题转成生成任务, 和 AR + LLaDA 在 ~2B 同 scale 下对比。来自 ByteDance-Seed/Cola-DLM GitHub model card 的 released 数字:
| Task | Cola @ 2000 EFLOPs | 说明 |
|---|---|---|
| LAMBADA | 50.80 | 段落补全 |
| MMLU | 19.30 | 57 类多选 — 数字仍低于 AR 同 scale |
| SIQA | 28.90 | 社会场景推理 |
| RACE | 19.60 | 阅读理解 |
| Story Cloze | 30.77 | 故事结尾选择 |
| OBQA | 23.00 | 开放式问答 |
| HellaSwag | 10.70 | 常识 NLI |
| SQuAD | 30.90 | 抽取式问答 |
| Task Avg | 26.75 | 8 任务平均 |
注意:blog 内 RQ2/RQ3 ablation 表给出更小的训练 budget 下数字(LAMBADA 31.1-34.6 / MMLU 5.4-10.1 / SIQA 11.1-23.6 等), 但这些是消融区间,不是 headline。上面才是 released checkpoint 数字。
他们的 scaling 实验跑到 ~2000 EFLOPs。Blog 说 "Task Avg 在 ~2000 EFLOPs 还在 rising"—— 官方称仍有 headroom,没看到饱和。
| 消融 | 结论 |
|---|---|
| Fixed-VAE vs Joint-VAE training | Joint 在大算力下 win。冻结 VAE 训练 (像 ELF 那样) 在小算力 ok,但 scaling 时无法继续涨 |
| All-Scratch baseline | 从零训 VAE + DiT 始终不如先 Stage 1 pretraining |
| Interval freezing (Stage 1.5) | VAE 在中间阶段冻结一段再放开 — 比一直 joint 差 |
| Sampling steps | 少步数明显不足;~8-10 步基本恢复;16-32 步饱和 |
| Patch size 2 (压缩 latent) | 整体比 patch 1 差;但当 prompt length 对齐 patch 边界时反而略好(18.12 vs 17.31)— "Token-level segmentation 不一定最优" |
这里有个有意思的对比:Cola 通过 ablation 证明 joint VAE+DiT 训练在大算力下更优;ELF 通过 ablation 证明 frozen encoder 在小数据小算力下更优。两个结论不矛盾—— 它们对应不同的训练规模、不同的"哪个组件更值得优化容量"的判断。
真正的来源是 frozen T5-small 的 contextual embedding 已经包含了语言的几何先验。 ELF 不学"语言是什么",只学"如何在这个空间里 transport"。等效于一种 transfer learning,T5 的预训练成本 (Google C4 上约 1T tokens)没被算进 ELF 的 45.2B tokens 里。Tab 5 只比较了 ELF vs baseline 的自身训练 token,不计 encoder pretrain。 这是这篇论文最容易被 attack 的点。
怎么辩护:把 ELF 看作"在 T5 表示上的 transfer-learning DLM"。MDLM/Duo 也间接用了 word-level tokenizer 的语言先验 (虽然程度不同)。但承认这是"有限的 12×",不是免费午餐。
大概率是的,但 paper 没做这个实验(App C.1 只 sweep 了 T5-small / base / large 三个 size)。 这是 ELF 最大的 follow-up 机会,也是它最脆弱的 claim—— ELF 的天花板可能就是 encoder 的天花板。如果有人用 LLaMA-3-8B hidden state + ELF flow 训一个 105M model,能不能逼近 LLaMA 自身 quality? 这个实验没人做过,是 obvious next step。
Paper App C.1 ablate 了:x-prediction 在所有 embedding 维度都稳定,v-prediction 在高维 degrade,ε-prediction 全面崩溃。 背后假设(paper 引用 Li & He 2511.13720):clean 文本数据在 embedding 空间是低维流形。 中间步加 CE 等价于"把噪声大的 z_t 强行映射回 token",这等于把 quantization wall 偷偷搬回来—— 破坏了 ELF 的"only final step rounding"核心 framing。
可以,但训练时把 SC-CFG 烤进 vtarget 后,推理只需要一次 forward,省一半算力。 代价是训练时3 次 forward(2 个 no-grad + 1 个 gradient-tracked)。 推理跑无数次,训练跑一次。划算。 这个 trick 是 image diffusion 圈 Chen et al. "Visual Generation without Guidance" (ICML 2025) 的方法,ELF 直接搬过来。 注意 ELF 仍然保留另一个推理时 CFG(input-cond CFG=2,用于 XSum/WMT),跟 SC-CFG 是两个独立机制。
论文用的就是 frozen GPT-2 Large 当 judge。Fig 7 里的 "Dataset" 虚线就是 OWT 真实文本在 GPT-2 Large 下的 reference PPL。 ELF-B 32 步 SDE 24.08 接近这条 reference;ELF-M 在 64 步 SDE / SC-CFG=3 下 21.69,更接近该 reference。 注意 "Dataset reference" 不是严格的理论下限—— 低于它也可能伴随低 entropy(即重复但流畅)。 而且这只是 GPT-2 Large judge 下的"流畅性"指标,不是通用质量下限: 换更强 judge(GPT-4 / Llama-3)数字会变。
App C.2 sweep 了 {32, 128, 512}(不含 256/64,论文没测)。 128 是 ODE / SDE 双采样下的 frontier balance 点: 32-d 在 SDE 下 PPL 最低但 entropy < 5(多样性不够),512-d 把 PPL 推高了。 背后假设:clean text 在 embedding space 是低维流形,128-d 就足够"覆盖"这个流形。 对比 image diffusion 也常用类似 bottleneck(DiT、SD 都有)。256 这个中间值 paper 没测,按 frontier 趋势插值应该 fine。
App C.5。Muon 对 2D 参数用 Newton-Schulz orthogonalization 把梯度先正交化再 step, 这抑制 ill-conditioned 方向,相当于 implicit second-order。 SDE 采样需要更精确的 v 预测(噪声重新注入会放大 v 的误差), Muon 训出来的模型在 v 上更"光滑",所以 SDE 推理时优势更大。 Paper 只定性说 "more pronounced under SDE",没给具体数字。 工程上 Muon 还有 fp32 NS5 + Nesterov bias correction 等 patches(详 §4.10)。
论文实验最长 seq=1024(OWT)和 1088(XSum)。当前 checkpoint 和 config 没有验证 8K 以上长度。
Architecture 上没有 causal LM 那种生成方向限制,但:
(a) RoPE buffer 按 max_length 预构建(要扩 8K 需重建或加 RoPE scaling);
(b) 全双向 self-attention 是 O(L²) 复杂度。
关于 "能 scale 到 8K context 吗"——architecture 上理论可行但 paper 没测;
可能需要 RoPE scaling + linear/sparse attention 改造,是 obvious extension。
VAE 路线的代价是 encoder 也得训,要解决 posterior collapse、latent drift、reconstruction trade-off 等额外问题。 ELF 用 frozen T5 把 encoder 部分外包给现成的预训练模型,所以 paper 短、ablation 干净、可以专心 ablate 7 个 ELF 自身的超参。 Cola 必须同时管理一大堆 VAE 超参(latent dim, block size, logSNR, KL ratio, anti-drift KL, multi-stage scheduling, ...)。 两边的代价不同:ELF 把复杂度外包给 Google 的 T5 训练;Cola 把复杂度内化到自己的训练 pipeline。
对,这是 generative few-shot 协议下的数字(把多选题转成生成)。 按 Cola 自己的 model card 和同协议对照(我没单独验证 AR 同 scale 的具体数字), Cola 的 LAMBADA 接近部分 AR baseline,MMLU 明显偏弱。 但 Cola 论文强调的是 scaling 趋势:曲线还在涨,没饱和。 如果 reviewer 说 "Cola 数字一般"——对,它的卖点是 architecture 可行性 + scaling shape,不是当下数字。 这类对比不要直接列 "AR 47-55" 这种具体数字,没 cite 不严谨。
ELF 的 DiT 是全双向 attention,每步 forward 都要重新算所有位置的 attention。 KV cache 需要 causal mask(前面 K/V 缓存供后面 query 用),ELF 没这个结构。 这是 ELF 推理慢于 AR 的根本原因之一: 即使 ELF 32 步 << AR 1024 token,每步的 attention 复杂度还是 O(L²)。 Cola 用 block-causal 解决了这个——这是 Cola 的核心架构卖点。
都没有。但 ELF 更接近"clean scientific demonstration",Cola 更接近"engineering system"。 具体看:
| 资源 | 链接 |
|---|---|
| ELF paper PDF | arXiv 2605.10938 |
| ELF GitHub (官方 JAX) | lillian039/ELF |
| ELF PyTorch port | pytorch_elf branch |
| ELF HF checkpoints | embedded-language-flows |
| Cola-DLM paper | arXiv 2605.06548 |
| Cola-DLM GitHub | ByteDance-Seed/Cola-DLM |
| Cola-DLM blog | hongcanguo blog |
| Cola-DLM HF | ByteDance-Seed/Cola-DLM |