📄 论文阅读 · Nous Research

Lighthouse Attention:让长上下文预训练只在"塔尖"上算注意力

原文 · Long Context Pre-Training with Lighthouse Attention(arXiv:2605.06554) | 作者 Bowen Peng, Subho Ghosh, Jeffrey Quesnelle(Nous Research) | 整理日期 2026-06-13

💡 一句话抓住重点

把 128K–1M 长序列预训练卡住的不是模型质量,是注意力 $\Theta(N^2)$ 的算力墙。Lighthouse 的做法是:在注意力外面套一层"金字塔池化 + 打分选 top-K",把全序列压成一条只有几万 token 的稠密子序列,丢给原封不动的 FlashAttention 去算;训练快结束时撤掉这层壳、用普通 SDPA 续训一小段,就能恢复成一个正常的稠密注意力模型——512K 上前向快 21×,端到端比稠密从头训省 1.4–1.7×,最终 loss 还更低。

21×
512K 上下文
前向加速
17.3×
512K 前向+反向
加速
1.4–1.7×
≥100K 端到端
训练提速
0 个
新增参数 /
辅助 loss
01 它要解决的问题

长上下文训练的瓶颈,从来不是"模型不够好",是"算不动"

前沿模型的上下文已经卷到 128K、1M 甚至更长——agentic 多步推理、长文档理解、交错的多模态输入都在往上推。但在这种尺度下,训练才是真正的硬件瓶颈:标准的缩放点积注意力(SDPA)有 $\Theta(N^2)$ 的算力和显存开销。FlashAttention 把常数项压下去了,但没有改变这个二次方的渐近复杂度——当 $N\geq 10^5$ 时,这一项就开始主宰一切。

一条主流的破局思路是用选择(selection)替代稠密注意力:每个 query 只关注一小撮 key。但论文指出,现有这些"稀疏注意力"方法虽然在推理上有效,却带着两个对长上下文预训练很不友好的设计基因:

⚖️

① 不对称(Asymmetry)

query 保持全分辨率,只把 key/value 池化。于是层次结构只是一块"被压缩的可寻址记忆",而不是真正的多尺度表示。

🔗

② 架构纠缠(Entanglement)

选择逻辑被焊死在注意力核里,导致现代 GPU 上高度优化的稠密注意力核无法复用——每个稀疏方法都得自带一套 kernel。

⚡ 训练专属的"灵魂拷问"

推理期的稀疏方法天生"质量有下限"——因为它只拿稠密前向当裁判来评估。但训练期的稀疏方法要过一道更狠的考:训练结束后,你训出来的这套权重还是不是一个合格的稠密注意力模型? 这正是 Lighthouse 立的核心正确性判据。

02 相关工作(重点展开)

在"打破二次方注意力"这条赛道上,前人都试过哪些路?

这是理解 Lighthouse 定位的关键一节。论文把破解 $\Theta(N^2)$ 的努力分成两大脉络——压缩与剪枝(Compression & Pruning)层次化与训练期正确性(Hierarchies & Training-time correctness)。下面把每条路、代表方法、以及它们的"软肋"都摊开讲。

脉络一:压缩与剪枝——三种"少算一点"的思路

面对二次方注意力,第一类回应是直接放弃 softmax、改用一个有界大小的状态;第二类保留 softmax 但在块(block)粒度上剪枝;第三类则在token 粒度上剪枝。论文把它们整理成下面这张谱系表:

家族代表方法核心做法软肋 / 代价
① 放弃 softmax
(有界状态)
Linear Attention、Mamba/SSM、Gated Linear Attention、RetNet、Log-linear Attention用一个固定大小的循环状态概括全部历史,换来线性甚至对数线性的渐近复杂度整个过去都压进一个状态,限制了长程精确召回——记不住远处的细节
② 块粒度剪枝
(保留 softmax)
训练无关:MInference、FlexPrefill、XAttention、SpargeAttention
端到端:MoBA、NSA
按块决定"保留/丢弃",能干净地映射到 tiled matmul每块只能做一个二选一决定,且只池化 key–value 一侧(不对称)
③ token 粒度剪枝推理期 KV-cache 驱逐:H2O、TOVA、SnapKV、LazyLLM、Quest、SparQ
端到端学习索引器:DSA(DeepSeek Sparse Attention)
逐 token 打分,保留最重要的若干 token;DSA 用一个可学的 indexer 端到端训练一旦选出 token,就被焊进注意力算子成为自定义稀疏 matmul / 逐 query gather——彻底堵死了对现成稠密 kernel 的复用
💡 我的看法:这张表的"题眼"在最后一列

三类方法的渐近复杂度各有胜负,但论文真正盯住的是它们共同的工程债:选择逻辑只要进了 kernel,就意味着你要为每个稀疏方案手写、手调一套 CUDA。而过去十年硬件厂商和社区把 FlashAttention 这套稠密 kernel 优化到了极致——稀疏方法等于主动放弃了这份免费的硬件红利。Lighthouse 的整个设计哲学,就是"我偏不进 kernel"。

脉络二:层次化注意力 & 训练期正确性

多分辨率(multi-resolution)注意力近年又回到了稀疏 LLM 注意力里,主要有两种玩法:

🏔️

"注意力直接读金字塔"

NSA、InfLLM-V2、Twilight、DoubleP 构造层次结构,让注意力本身从压缩分支 / 质心摘要 / 量化代理里去读。

🔌

"插件式替换索引器"

HISA 是 DSA 索引器的训练无关替换:跑一个"块→token"两阶段打分,把选中的 token 原封不动喂给 DSA 已有的 Sparse MLA 算子。

但论文一针见血地指出这两类的共同局限:层次结构只作用在 key 和 value 上,而最终选出来的结果仍然要喂进一个自定义的稀疏注意力核。换句话说,不对称 + 架构纠缠这两个原罪,它们一个都没躲掉。

🔥 联网补充:DSA / NSA / MoBA 是谁?

DSA(DeepSeek Sparse Attention)来自 DeepSeek,用可学索引器逐 token 打分选 top-k 喂进稀疏 MLA;NSA(Native Sparse Attention)端到端训练块级选择;MoBA(Mixture of Block Attention)来自 Moonshot(Kimi),把注意力做成"块的混合专家"。这三个是 2025 年最受关注的端到端稀疏注意力代表——而它们恰恰都是 Lighthouse 在"训练期正确性"这个问题上要正面叫板的对象(详见下方对照)。

Lighthouse 和它们到底差在哪?三个轴

对比轴NSA / HISA / InfLLM-V2 等Lighthouse
池化对象只池化 key/value(不对称)$Q,K,V$ 对称池化成连贯的多分辨率三元组 $(Q^{(\ell)},K^{(\ell)},V^{(\ell)})$
金字塔的用途注意力直接从层次结构里读 → 压缩记忆金字塔只用来排序和选择,选完之后是真正的稠密多尺度表示
注意力 kernel自定义稀疏 kernel,核内带稀疏索引原版 FlashAttention 跑在稠密子序列上,核内零稀疏逻辑
训练方式部分含可学选择器 / 辅助 loss穿过不可微 top-k(无 straight-through 估计),无辅助 loss、无新参数
正确性判据推理期方法继承稠密底座的下限;训练期方法的"留下的权重还能当稠密模型吗"无人正面回答明确以"短暂 dense-SDPA 续训后能否追平从头稠密训练"作为核心判据
💡 我的看法

这一节其实暗含一个很聪明的"举证责任"切换。推理期稀疏方法(含 HISA)有个先天的安全垫——它们从来不碰训练 loop,所以质量下限就是底座稠密模型。而训练期稀疏方法(MoBA、NSA)没有这个垫子,必须自证"我训出来的权重还是个好的稠密模型"。Lighthouse 干脆把这个最苛刻的问题立成了自己的 KPI,再用实验正面答下来——这是它在论文叙事上最有说服力的地方。

03 方法

把注意力层换成一条"四阶段流水线"——但绝不碰注意力核本身

Lighthouse 用一条四阶段流水线替换掉标准注意力层。关键词是"包住但不修改":选择发生在注意力之前,FlashAttention 在被挑出来的子序列上原样运行,注意力之后再把结果散射回原位置。选择由一个无参数的打分函数驱动,作用在该层自己 $Q,K,V$ 的多分辨率金字塔上——所以 Lighthouse 不引入任何新的可学参数

3.1 先回顾:标准注意力为什么贵

设输入 $X\in\mathbb{R}^{N\times d_{\text{model}}}$,投影矩阵 $W_Q,W_K,W_V$,因果掩码 $M$,标准缩放点积注意力为:

$$Q=XW_Q,\;K=XW_K,\;V=XW_V,\qquad \mathrm{Attn}(Q,K,V)=\mathrm{softmax}\!\left(\tfrac{QK^\top}{\sqrt{d}}+M\right)V$$
$QK^\top$$N\times N$ 的注意力矩阵——正是这一项带来 $\Theta(N^2 d)$ 的时间和显存开销
$\Theta(N^2 d)$FlashAttention 只能降常数、不能降阶;$N\geq 10^5$ 时它主宰全部成本

3.2 四阶段流水线总览

🏔️

① Pyramid

把 $Q,K,V$ 对称平均池化成 $L$ 级金字塔(因子 $p$)

🎯

② Score & Top-k

无参数打分,跨所有层级联合选出 top-$K$ 条目

③ Dense Attn

gather 成长度 $S$ 的稠密子序列,喂给原版 FlashAttention

📤

④ Scatter-back

把每个输出散射回它代表的 $p^\ell$ 个基础位置

⚡ 全篇最关键的"梯度路径"设计

top-$k$ 这一步是离散、不可微的——索引不携带梯度,打分函数也不参与训练。梯度只通过 ④散射 → ③FlashAttention → gather 回流到 $W_Q,W_K,W_V$。结果是:投影矩阵学到的是"被选中时要有用",而不是"要善于选择"——绕开了可学选择器那种臭名昭著的优化脆弱性(不需要 Gumbel-softmax,也不需要 straight-through 估计)。

3.3 金字塔构造(Pyramid)

第 $\ell$ 级是上一级的非重叠窗口池化。第 $\ell$ 级第 $i$ 个窗口定义为 $\mathcal{W}^{(\ell)}_i=[\,ip^\ell,\,(i+1)p^\ell-1\,]$,然后对三个投影对称地做均值池化:

$$Q^{(\ell)}_i=\mathrm{Pool}_\mu\{Q_j\mid j\in\mathcal{W}^{(\ell)}_i\},\quad K^{(\ell)}_i=\mathrm{Pool}_\mu\{K_j\},\quad V^{(\ell)}_i=\mathrm{Pool}_\mu\{V_j\}$$
$\ell=0$第 0 级就是原始全分辨率序列(窗口 = 单个 token)
对称池化这是和 NSA/HISA/InfLLM-V2 的本质区别——它们只池化上下文侧;Lighthouse 对 $Q,K,V$ 三者一视同仁

对称带来两个后续阶段会用到的关键性质:池化后的 query $Q^{(\ell)}_i$ 和池化后的 key $K^{(\ell)}_j$ 处在同一表示空间;每个金字塔条目都是一个概括同一段 $p^\ell$ token 跨度的连贯 $(Q,K,V)$ 三元组。金字塔总条目数 $\sum_{\ell=0}^{L-1}N/p^\ell \leq N\cdot p/(p-1)$,所以构造金字塔只花 $\Theta(N)$ 时间和显存。

3.4 打分与选择(Scoring & Selection)

每个金字塔条目拿两个标量分——一个作为 query(QK),一个作为 key(KQ)。第 0 级直接用每头的 $\ell_2$ 范数:

$$s^{\mathrm{QK}}_{0,i}=\|Q_i\|_2,\qquad s^{\mathrm{KQ}}_{0,i}=\|K_i\|_2$$

更粗的层级不重算,而是从第 0 级max-pool 上来——让一个粗跨度继承它内部最强 token 的重要性。然后跨所有层级、把 QK 和 KQ 两条流拼起来联合选 top-$K$:

$$\mathcal{I}=\mathrm{TopK}\!\left(\{s^{\mathrm{QK}}_{\ell,i},\,s^{\mathrm{KQ}}_{\ell,i}:(\ell,i)\in\mathcal{P}\},\,k\right)$$
max-pool粗层不从池化后的投影重算范数,而是继承其覆盖范围内 token 范数的最大值——cheap 且"宁可错保不可漏选"
最粗层全保最粗层级永远全量保留——它便宜,且保证每个基础位置至少有一个贡献者;剩余预算花在更细层级上

3.5 稠密子序列注意力(Gathered-Sequence Attention)

按索引 $\mathcal{I}$ 把选中的三元组拼成一条连续子序列 $\widetilde{Q},\widetilde{K},\widetilde{V}$,长度为:

$$S=\frac{N}{p^{L-1}}+(L-1)\,p\,k$$
$N/p^{L-1}$最粗层贡献的全部条目数
$(L-1)pk$其余每层最多贡献 $pk$(因子 $p$ 是因果边界的记账)

代入一个真实数字感受一下:当 $N=10^6,\,L=4,\,p=4,\,k=4096$ 时,$S\approx 6.5\times10^4 \ll N$——一百万 token 的序列,实际进注意力的只剩六万五。然后直接用原版 SDPA / FlashAttention:

$$\widetilde{O}=\mathrm{Attn}(\widetilde{Q},\widetilde{K},\widetilde{V};\widetilde{M})$$

由于 gather 是按金字塔坐标拓扑排序过的,因果掩码 $\widetilde{M}$ 退化成一个标准的 $S\times S$ 因果掩码——式中没有任何稀疏索引。这就是 Lighthouse 能直接复用 FlashAttention 的根本原因。

3.6 散射回写(Scatter-Back)

最后一步用一个确定性的"整数原子散射"kernel,把每个条目的输出 $\widetilde{O}_m$ 分发回它所代表的 $p^\ell$ 个基础位置,得到完整的 $O_t$。整条流水线里只有② top-K④ scatter-back 是自定义 CUDA/Triton kernel,其余全是 torch.compile 能融合的 PyTorch 原语。

04 复杂度与 kernel 设计

为什么它能做到"近线性",又为什么能直接吃 FlashAttention 的红利

整条流水线唯一超线性于 $N$ 的项,就是那条稠密子序列注意力 $\Theta(S^2 d)$。而 $S=N/p^{L-1}+(L-1)pk$。如果选 $L=\log_p(N/k)$ 来平衡 $S$ 的两项,就得到 $S=\Theta(k\log_p(N/k))$,注意力成本变成:

$$\Theta(k^2\log^2 N\cdot d)\quad\text{——在固定 }k\text{ 下,对 }N\text{ 是多对数(polylog)级}$$

加上线性的打分和 $\Theta(N\log k)$ 的选择,单层总算力对 $N$ 在固定 $k$ 下是近线性(仅差一个 $\log k$ 因子)。

💡 我的看法:真正的工程杀招是"gather 与 attention 解耦"

NSA、DSA、HISA、MoBA 都把选择嵌进自定义稀疏 kernel,而 Lighthouse 交给 FlashAttention 的是一条连续的稠密子序列。这带来三个白捡的好处:(1) 前向/反向和稠密 Transformer 逐 bit 一致;(2) 上下文并行可以让 gather 走标准 ring attention 轮转,不需要任何稀疏感知的集合通信;(3) 直接撑起 1M token / 32 卡 Blackwell 的训练。这是"不进 kernel"哲学最甜的回报。

05 实验

三个问题:能恢复吗?能扩展吗?四个旋钮怎么调?

实验基座:一个 530M 参数的 Llama-3 风格 decoder($d_{\text{model}}{=}1024$,30 层,8 头,head dim 128),在 C4 上以序列长度 98,304 训练,AdamW,总预算固定 16,000 步(约 50B tokens)。第 0、1、28、29 层保留稠密 SDPA,其余 26 层用 Lighthouse。两阶段配方:阶段一用 Lighthouse 训练,阶段二载入 checkpoint 用稠密 SDPA 续训一小段。

5.1 可恢复性(Recoverability)——全篇的"命门实验"

固定 16k 步总预算,变化阶段一的长度(10k / 11k / 12k),剩下的步数切回稠密 SDPA 续训,与一个从头稠密训练、架构/数据/token 全部对齐的基线对比。

配置LH 步 + SDPA 步B200 小时 ↓Tok/s (k) ↑最终 Loss ↓
SDPA 从头训练(基线)— / 16k303.245.60.7237
LH→SDPA12k + 4k214.774.70.7102
LH→SDPA11k + 5k219.675.40.7001
LH→SDPA10k + 6k228.075.00.6980

规律很干净:切回稠密时 loss 会先瞬间尖刺(1.12–1.57,因为模型第一次跑它没训过的注意力),然后约 1–1.5k 步内恢复,并反超基线。三个切换点最终都落在 0.6980–0.7102,全部低于稠密基线的 0.7237——而且 B200 小时还更省

🔥 关键结论

"hierarchical 训练不会掏空模型在推理时使用全注意力的能力"——这是 Lighthouse 那道"灵魂拷问"的正面回答,而且恢复对切换点不敏感(配方不靠精确的时间表吃饭),在不比从头稠密训练多花一个 token 的前提下做到。

5.2 扩展性:和稠密注意力的延迟对比

forward latency
论文 Fig.3(a):前向延迟 vs 上下文长度(单 B200,$L{=}3,p{=}4$,稀疏度约 1:64)。 绿线 SDPA(cuDNN) 按 $\Theta(N^2 d)$ 飙升——512K 时超 400ms;红线 Lighthouse 按 $\Theta(S^2 d)$ 几乎贴地。512K 上 Lighthouse 快 21.0×;换个说法,SDPA 跑到 ~113K 上下文才达到 Lighthouse 在 512K 的延迟。
backward latency
论文 Fig.3(b):前向+反向延迟。 同样的趋势——512K 上 Lighthouse 快 17.3×,相当于 SDPA 在 ~122K 上下文的运行时间。反向比前向加速略低,因为反向要穿过 gather/scatter,但量级优势依旧保持。

全模型训练故事类似但需要小心:530M 架构在单卡 B200 上超过 ~100K 就会因激活/梯度/优化器状态 OOM(与注意力方法无关)。于是论文实现了上下文并行:金字塔池化、打分、top-K 都分片本地跑,gather 出的子序列走原版 ring attention 轮转,无稀疏感知集合通信。CP 引入约 10% 的每卡吞吐开销,但 Lighthouse-vs-SDPA 的乘性加速在同等 CP 几何下完整保留,一路干净地扩展到 1M token / 32 卡。

5.3 设计消融:四个旋钮(scorer / $p$ / $L$ / $k$)

发现具体数据
① 每个配置都打平或超过稠密基线所有 Lighthouse 配置最终 loss 都 ≤ 稠密基线 0.7237——可恢复性不挑超参
② 无参数 norm scorer 几乎免费与 dilated softmax 在 ±0.01 内(无一致赢家),但无参数、且 B200 小时省约 9%(179.6–180.9 vs 197.2–199.7)
③ 更小的 $p,L,k$ 反而略好全网格最低 loss 是 $L{=}3,p{=}2,k{=}1536$(0.6825)。最反直觉的是 $k$:loss 随 $k$ 缩小单调下降(4096→1536:0.6951→0.6825),疑似 hierarchical 选择在该 token 预算下起了正则化作用

吞吐侧:Lighthouse 阶段一稳定在 84–126k tok/s/GPU,对比稠密 SDPA 的 ~46k,约 每步优势。端到端(10k+6k 配方)总运行时间从 22.5h(179.6 B200-h)到 27.0h,对比稠密从头训练的 37.9h(303.2 B200-h)——即 1.4–1.7× 端到端提速。

5.4 长上下文检索(Needle-in-a-Haystack)

niah heatmap
论文 Fig.4:单数字 passkey 检索热力图(深度 0–100%,上下文 4K–96K,随机基线 10%)。 四个 Lighthouse 两阶段 run + 稠密基线(0.72)。绿=高命中、红=低。三个 Lighthouse run 打平或超过稠密基线:$k{=}2048$ dilated 以 0.76 夺冠,$k{=}1536$ dilated 0.73,$k{=}2048$ norm 0.72 持平;只有 $k{=}1536$ norm 掉到 0.65。
⚡ 一个 loss 看不出的反差

检索上 $k$ 越大越好(与 loss 侧"$k$ 越小越好"相反),且 norm scorer 伤检索比伤 loss 更狠(固定 $k$ 时换 dilated→norm,检索掉 0.04–0.08,是全网格最大单轴差距)。结论:默认值该选 dilated-大$k$ 还是 norm-小$k$,取决于下游任务是 loss 驱动还是检索驱动——没有银弹。

06 结论、局限与未来

一个"训练用稀疏、推理用稠密"的干净配方

Lighthouse 把 $Q,K,V$ 对称池化进多分辨率金字塔、把选择放在注意力核外面,让注意力步退化成稠密子序列上的原版 FlashAttention。它无参数、端到端训练、无辅助 loss、无 straight-through 估计,还能原样继承上游 FlashAttention 的每一次改进。一段短暂的稠密 SDPA 续训后,在等 token 预算下追平或超过从头稠密训练(loss 和长上下文检索都成立),≥100K 上对 cuDNN SDPA 端到端快 1.4–1.7×,并干净扩展到多节点 Blackwell 的 1M token。

⚠️

局限一:解码不适配

对称 $Q/K/V$ 池化假设所有 query 在同一次前向里共现——这违背了自回归解码。所以推理就绪的模型靠的是 dense-SDPA 续训,所有下游评测都在续训之后跑,而非直接在 hierarchical 前向上。

📐

局限二:不是严格线性

内层注意力是 $\Theta(S^2 d)$——固定 $k$ 时对 $N$ 次二次方,但不是严格线性。当 $k$ 必须随 $N$ 增长的场景,仍未被刻画。

💡 未来方向

① 把稠密 SDPA 续训换成非对称稀疏目标(DSA/NSA/HISA/MoBA),就能产出原生可服务的 checkpoint;② 逐层 / 逐头自适应 $k$ 可能胜过固定预算;③ 多尺度金字塔天然能延伸到视觉、音频、视频;④ 服务侧集成(连续批处理、投机解码、KV-cache 管理)是把训练加速翻译成部署收益的最后一公里。

07 整理者点评

这篇论文真正聪明的地方在哪?

站在做长上下文训练和推理(KV-Cache、前缀缓存这些)的人的角度,我觉得这篇有三点值得反复琢磨:

1 它把"工程债"变成了"设计原则"

"选择不进 kernel"听起来像个工程偷懒,实则是整篇论文的承重墙——正因为注意力步是逐 bit 的稠密 FlashAttention,CP/ring attention 才能零成本接上,1M token 训练才水到渠成。这是把"复用现成 kernel"从一个 nice-to-have 提升成了架构第一原则。

2 "梯度不穿过选择器"是反共识但务实的

主流做法总想让选择可学(Gumbel-softmax、straight-through)。Lighthouse 反其道:top-k 完全不可微,让投影学"被选中时有用"而非"善于选择"。这等于主动放弃了一部分表达力,换来训练稳定性——而消融显示无参数 norm scorer 几乎免费。这种"少即是多"的取舍在工程上很对胃口。

3 "训练-推理两阶段"是诚实的,也是它最大的软肋

论文没有藏着掖着:对称池化和自回归解码天然冲突,所以它老老实实承认"推理要靠续训恢复成稠密模型"。这既是诚实,也意味着它目前不是一个推理加速方案——加速全在训练侧。对我们这种被线上 KV-Cache 跨机命中折磨的场景,它直接的帮助有限;但"训练期用便宜的层次注意力、推理期切回稠密"这个范式,本身是值得借鉴的思路。

🚫 容易踩的理解误区

别把 Lighthouse 当成"又一个推理稀疏注意力"。它和 DSA/NSA/MoBA 的赛道不完全重叠:那些方法的目标是推理省算力,而 Lighthouse 的核心 KPI 是训练省算力且不损害最终稠密模型质量。读它的实验时,记住所有评测都跑在"续训成稠密之后"——这是理解它一切结论的前提。