MSA: Memory Sparse Attention for Efficient End-to-End Memory Model Scaling to 100M Tokens

作者: Yu Chen, Runkai Chen, Sheng Yi, Xinda Zhao, Xiaohong Li, Jianjin Zhang, Jun Sun, Chuanrui Hu, Yunyun Han, Lidong Bing, Yafeng Deng, Tianqiao Chen 年份: 2025 会议: arXiv 分类: 高效推理与部署

论文笔记:MSA: Memory Sparse Attention for Efficient End-to-End Memory Model Scaling to 100M Tokens

元信息

项目内容
机构未明确标注(多机构合作)
日期March 2025
项目主页未提供
对比基线RAG, HippoRAG2, MemoryAgent, KaLMv2
链接arXiv

一句话总结

提出 Memory Sparse Attention,通过可微的稀疏注意力机制实现端到端的百万级 token 长上下文建模,从 16K 到 100M tokens 仅 8.8% 性能退化

核心贡献

端到端可微的稀疏注意力框架: 通过 Router K Projector + Top-k 选择实现 稀疏注意力,保持线性复杂度的同时支持梯度反传

Memory Interleave 机制: 支持自适应多跳检索推理,模型自主决定检索文档数量和时机,无需外部编排

100M tokens 推理方案: 通过 GPU-CPU 分层存储 + 异步预取,在 2×A800 上实现 1 亿 token 的端到端推理

问题背景

要解决的问题

如何让 LLM 拥有终身记忆(lifetime memory),能在百万甚至亿级 token 的上下文中精确检索和推理

现有方法的局限

参数化记忆LoRA、CPT):需要重新训练,存在灾难性遗忘,无法支持动态更新

外部存储RAG):检索器和生成器分离,语义表示不一致,多跳推理能力弱

线性注意力DeltaNetRWKV):有损压缩,精度低,存在灾难性遗忘

标准稀疏注意力(DSA):不支持终身级别的上下文规模

本文的动机

将稀疏注意力从”有限窗口”扩展到”终身规模”,同时保持端到端可微性和高精度,避免 RAG 的检索-生成割裂问题

方法详解

模型架构

MSA 基于 Qwen3-4B-Instruct 进行改造:

  • 输入: 查询 qq + 外部记忆文档集合 {d1,d2,,dL}\{d_1, d_2, \ldots, d_L\}
  • Backbone: Qwen3-4B(18 层,8 头,头维度 128)
  • 核心模块: 后半部分层(9-18 层)替换为 MSA 层,前半部分保持标准 自注意力
  • 输出: 自回归生成答案 / 文档 ID 序列
  • 总参数: ~4B

核心模块

模块1: MSA Layer(稀疏注意力层)

设计动机: 将检索和生成统一到单一注意力层,避免 RAG 中检索器与生成器的表示割裂

具体实现:

  • 每个文档的隐状态 HiH_i 通过三组投影得到 KiK_iViV_i(内容 KV)和 KiRK_i^R(路由 Key)
  • 路由 Key 通过 64-token 的 均值池化 压缩为 chunk 级表示 KˉijR\bar{K}^R_{ij}
  • 查询的路由向量 QqRQ^R_q 与所有文档的 KˉR\bar{K}^R 计算 余弦相似度,取 Top-16 文档
  • 选中文档的完整 KV 与查询 KV 拼接后执行标准注意力

模块2: Parallel & Global RoPE(位置编码策略)

设计动机: 解耦文档位置和全局位置,使模型对记忆规模不敏感

具体实现:

  • 每个文档独立使用 旋转位置编码,位置 ID 从 0 开始(Parallel RoPE)
  • 活跃上下文使用 Global RoPE,位置偏移量 = 检索文档数 kk(默认 16)
  • 两种 RoPE 在注意力计算时自然融合,无需额外参数

模块3: Memory Interleave(记忆交错机制)

设计动机: 支持多跳推理,让模型自主决定何时检索、检索多少文档

具体实现:

  • 模型自回归生成文档 ID 序列,以特殊分隔符结束
  • 检索对应文档文本,追加到查询上下文
  • 重复上述过程直到模型认为证据充分,转入答案生成
  • 文档数量自适应,不需要预设固定值

训练流程

持续预训练(158.95B tokens)

两阶段优化:

  • Warm-up 阶段:侧重路由学习,L=0.1LLLM+Laux\mathcal{L} = 0.1\mathcal{L}_{\text{LLM}} + \mathcal{L}_{\text{aux}},学习率 1×1041 \times 10^{-4}
  • 主训练阶段:侧重语言建模,L=LLLM+0.1Laux\mathcal{L} = \mathcal{L}_{\text{LLM}} + 0.1\mathcal{L}_{\text{aux}},学习率 6×1066 \times 10^{-6}

后训练(SFT)

Stage 1: 8K 上下文长度的指令微调

Stage 2: 扩展到 64K 上下文,严格数据清洗

推理架构(三阶段流水线)

  1. 全局记忆编码(离线):预计算所有文档的 Kˉ\bar{K}Vˉ\bar{V}KˉR\bar{K}^R,存入结构化缓存
  2. 路由与上下文组装(在线):查询与路由 Key 匹配,选出 Top-16 文档,从 CPU 异步加载对应 KV
  3. 稀疏生成(在线):在组装好的稀疏上下文上自回归生成

关键公式

公式1: KV 投影与路由投影

Ki,h=HiWKh,Vi,h=HiWVh,Ki,hR=HiWKRhK_{i,h} = H_i W^h_K, \quad V_{i,h} = H_i W^h_V, \quad K^R_{i,h} = H_i W^h_{K^R}

含义: 从文档 ii 的隐状态 HiH_i 生成三组投影——标准内容 Key/Value 和专用路由 Key

符号说明:

  • HiH_i: 第 ii 个文档的隐状态矩阵
  • WKh,WVhW^h_K, W^h_V: 第 hh 个注意力头的标准 KV 投影矩阵
  • WKRhW^h_{K^R}: 第 hh 个注意力头的路由 Key 投影矩阵

公式2: 相关性评分

Sij=maxtoken t(meanhead h(cos((Qq,hR)t,Kˉij,hR)))S_{ij} = \max_{\text{token } t} \left( \underset{\text{head } h}{\text{mean}} \left( \cos\left( (Q^R_{q,h})_t, \bar{K}^R_{ij,h} \right) \right) \right)

含义: 计算查询与第 ii 个文档第 jj 个 chunk 的相关性分数,跨头取均值后在 token 维度取最大值

符号说明:

  • Qq,hRQ^R_{q,h}: 查询的路由向量(第 hh 个头)
  • Kˉij,hR\bar{K}^R_{ij,h}: 第 ii 个文档第 jj 个 chunk 的压缩路由 Key
  • cos(,)\cos(\cdot, \cdot): 余弦相似度

公式3: 上下文组装

Kctx=[{Kˉi}iI;Kq],Vctx=[{Vˉi}iI;Vq]K_{\text{ctx}} = [\{\bar{K}_i\}_{i \in \mathcal{I}}; K_q], \quad V_{\text{ctx}} = [\{\bar{V}_i\}_{i \in \mathcal{I}}; V_q]

含义: 将 Top-k 选中文档的 KV 与查询的 KV 拼接,形成稀疏注意力上下文

符号说明:

  • I\mathcal{I}: Top-k 选中文档的索引集合
  • Kˉi,Vˉi\bar{K}_i, \bar{V}_i: 第 ii 个文档的内容 Key/Value
  • Kq,VqK_q, V_q: 查询的 Key/Value

公式4: 注意力输出

Output=Attention(Qq,Kctx,Vctx)\text{Output} = \text{Attention}(Q_q, K_{\text{ctx}}, V_{\text{ctx}})

含义: 在组装好的稀疏上下文上执行标准注意力计算

符号说明:

  • QqQ_q: 查询的 Query 向量

公式5: 监督对比损失

Laux=1Pi=1Plogexp(si+/τ)exp(si+/τ)+j=1Nexp(si,j/τ)\mathcal{L}_{\text{aux}} = -\frac{1}{|\mathcal{P}|} \sum_{i=1}^{|\mathcal{P}|} \log \frac{\exp(s^+_i / \tau)}{\exp(s^+_i / \tau) + \sum_{j=1}^{|\mathcal{N}|} \exp(s^-_{i,j} / \tau)}

含义: 监督路由器的层级决策,使正样本文档的相关性分数高于负样本

符号说明:

  • P\mathcal{P}: 正样本文档集合
  • N\mathcal{N}: 负样本文档集合
  • si+s^+_i: 第 ii 个正样本的相关性分数
  • si,js^-_{i,j}: 第 jj 个负样本的相关性分数
  • τ\tau: 温度参数

公式6: 训练复杂度

Otrain=O(LG)+O(MLP)+O((M+kGP)2)=O(LG)\mathcal{O}_{\text{train}} = \mathcal{O}(LG) + \mathcal{O}\left(\frac{ML}{P}\right) + \mathcal{O}\left(\left(M + \frac{kG}{P}\right)^2\right) = \mathcal{O}(LG)

含义: 训练复杂度由独立文档处理主导,为线性

符号说明:

  • LL: 记忆总大小(文档数)
  • GG: 单个文档长度
  • MM: 查询长度
  • PP: 池化大小(64)
  • kk: Top-k 选择数(16)

公式7: 推理复杂度

Oinference=O(MLP)+O(T(M+kGP)2)=O(L)\mathcal{O}_{\text{inference}} = \mathcal{O}\left(\frac{ML}{P}\right) + \mathcal{O}\left(T \cdot \left(M + \frac{kG}{P}\right)^2\right) = \mathcal{O}(L)

含义: 每次查询的推理复杂度为线性,离线预处理 O(LG)\mathcal{O}(LG) 可分摊

符号说明:

  • TT: 答案生成长度

公式8: 训练损失函数

Warm-up 阶段:

L=0.1LLLM+Laux\mathcal{L} = 0.1 \mathcal{L}_{\text{LLM}} + \mathcal{L}_{\text{aux}}

主训练阶段:

L=LLLM+0.1Laux\mathcal{L} = \mathcal{L}_{\text{LLM}} + 0.1 \mathcal{L}_{\text{aux}}

含义: 两阶段课程学习——先重点训练路由器,再重点训练语言建模

符号说明:

  • LLLM\mathcal{L}_{\text{LLM}}: 标准语言模型损失(交叉熵)
  • Laux\mathcal{L}_{\text{aux}}: 监督对比损失

关键图表

Figure 1: Scalability / 可扩展性展示

Figure 1: Scalability{:width 600}

说明: MSA 在 MS MARCO 数据集上从 16K 扩展到 100M tokens 的性能曲线,仅 8.8% 退化(4.023 → 3.669)。展示了 MSA 集成 Top-k 选择与 稀疏注意力 的可微分设计,兼具可扩展性。

Figure 2: MSA Layer Architecture / MSA 层架构

Figure 2: MSA Layer{:width 600}

说明: MSA 层的路由机制与注意力计算流程。文档通过 Router K Projector 生成路由 Key,经 均值池化 压缩后与查询路由向量计算 余弦相似度,选出 Top-k 文档,最终拼接 KV 执行注意力。

Figure 3: Three-Stage Inference with Memory Interleave / 三阶段推理与记忆交错

Figure 3: Inference Process{:width 600}

说明: 三阶段推理流水线。Stage 1 离线编码所有文档 KV;Stage 2 在线路由选择 Top-16 文档并从 CPU 加载 KV;Stage 3 结合 Memory Interleave 进行多跳自回归生成。GPU 存储路由 Key(~56GB),CPU 存储内容 KV。

Figure 4: NIAH Benchmark Results / 大海捞针评测结果

Figure 4: NIAH Results{:width 600}

说明: 32K-1M token 范围内的 Needle-In-A-Haystack 准确率。MSA 在 1M tokens 时仍保持 94.84%,仅比 32K 基线(98.77%)下降 3.93 个百分点。对比 Qwen3-4B 原始模型在 1M 时暴降至 24.69%,MemoryAgent-14B 下降 5.76 个百分点。

Table 1: 长期记忆方法对比

方法类型终身记忆精度兼容主流 LLM计算复杂度记忆管理灾难性遗忘
参数化(LoRA/CPT)训练高/推理低
测试时训练(Titans)
RAGO(L)\mathcal{O}(L)
MemAgent
稀疏注意力(DSA)
线性注意力(DeltaNet/RWKV)O(L)\mathcal{O}(L)
MemGen
MSA(Ours)O(L)\mathcal{O}(L)

说明: MSA 是唯一同时具备终身记忆、高精度、高兼容性、线性复杂度且无灾难性遗忘的方法

Table 2: 同骨干网络 RAG 对比(LLM Judge 0-5 分)

数据集TokensQwen3-4B R@1Qwen3-4B R@10Qwen3-4B(RR) R@10HippoRAG2 R@10MSA
MS MARCO v17.34M2.8933.0053.0173.0194.141
Natural Questions1.47M3.4523.2973.3853.3743.545
DuReader277K3.7263.5943.6073.4154.155
TriviaQA10M4.1334.2734.3914.3674.621
NarrativeQA538K1.6112.8603.5362.6553.395
PopQA1.18M2.9593.2993.2663.2493.433
2WikiMultiHopQA722K1.0653.1363.1593.3304.280
HotpotQA1.35M2.2523.7874.0223.9704.061
MuSiQue1.41M0.9361.9281.9652.0952.211
平均2.5593.2423.3723.2753.760

说明: MSA 以自适应检索击败所有同骨干 RAG 基线,平均领先 16.0%(vs Qwen3-4B R@1)。多跳数据集(2Wiki、HotpotQA)优势尤为显著。

Table 3: SOTA RAG 系统对比(LLM Judge 0-5 分)

数据集KaLMv2+Qwen3-235B R@10KaLMv2+Qwen3-235B(RR) R@10KaLMv2+Llama3.3 R@10KaLMv2+Llama3.3(RR) R@10MSA
MS MARCO v13.0272.9952.9192.9524.141
Natural Questions3.6943.6453.6623.6473.545
DuReader4.0443.8913.7423.7804.155
TriviaQA4.5784.5554.7194.6954.621
NarrativeQA2.4273.3752.3823.3173.395
PopQA3.3963.3763.3053.3623.433
2WikiMultiHopQA3.5823.5833.4453.5414.280
HotpotQA4.2254.1944.1274.2034.061
MuSiQue2.6472.6052.2582.6142.211
平均3.5063.5803.3963.5683.760

说明: MSA(4B 参数)与使用 235B 参数骨干的 SOTA RAG 系统竞争力相当,平均分数最高。在 MS MARCO 和 2WikiMultiHopQA 上大幅领先,但在 Natural Questions 和 MuSiQue 上略低。

Table 4: 消融实验(0-5 分)

配置平均MS MARCONQDuReaderHotpotQA
MSA-S2(完整)3.9764.1413.5454.1554.061
MSA-S1(无 Stage 2 SFT)3.6943.1973.4934.0644.020
w/o Memory Interleave3.4973.1753.4854.0763.250
w/o 持续预训练2.5372.2672.4483.1442.289
w/o 原始文本2.3252.6252.1902.1862.297

关键发现: 持续预训练是最关键组件(去除后下降 31.3%),其次是原始文本(下降 37.1%)。Memory Interleave 对多跳任务影响最大(HotpotQA 下降 19.2%)。课程学习(S2 vs S1)贡献 7.6% 平均提升。

Table 5: 预训练数据构成

数据类别查询数Token 数主要来源
长上下文与指令微调5.8M6.46BKaLM finetune data
学术科学文献2.0M28.74BS2ORC, SPECTER
通用 QA 与社区知识4.9M75.80BYahoo Answers, WikiAnswers, MS MARCO, PAQ 等
新闻与摘要2.1M36.43BAG News, NPR, CNN/DailyMail, XSum
领域特定2.1M28.94BAmazon Reviews, CodeSearchNet, WikiHow
总计17.9M158.95B跨 17 个数据源

说明: 训练数据覆盖科学文献、QA、新闻、代码等多领域,非 KaLM 数据集均下采样至 500K 查询

实验

数据集

数据集规模特点用途
MS MARCO v17.34M tokens大规模信息检索QA 评测
Natural Questions1.47M tokens开放域 QAQA 评测
DuReader277K tokens中文阅读理解QA 评测
TriviaQA10M tokens大规模 triviaQA 评测
NarrativeQA538K tokens叙事理解QA 评测
PopQA1.18M tokens流行文化 QAQA 评测
2WikiMultiHopQA722K tokens多跳推理QA 评测
HotpotQA1.35M tokens多跳推理QA 评测
MuSiQue1.41M tokens多步推理QA 评测
NIAH32K-1M tokens大海捞针长上下文评测

实现细节

Backbone: Qwen3-4B-Instruct-2507

MSA 层: 模型后半部分(9-18 层),共 10 层

注意力头数: 8

头维度: 128

Top-k: 16 文档

Chunk 大小: 64 tokens(Mean Pooling kernel)

预训练: 158.95B tokens,两阶段学习率(1e-4 → 6e-6)

SFT Stage 1: 8K 上下文

SFT Stage 2: 64K 上下文

推理硬件: 2×A800 GPU(160GB aggregate VRAM)

路由 Key 存储: ~56GB GPU VRAM

内容 KV 存储: CPU DRAM,异步预取

NIAH 评测结果

模型32K1M退化
MSA98.77%94.84%-3.93pp
MemoryAgent-14B-5.76pp
Qwen3-4B-Instruct24.69%严重崩溃

可扩展性结果

上下文长度MSA 评分退化率
16K4.023
100M3.6698.8%

批判性思考

优点

统一框架: 将检索和生成融合到单一注意力机制,避免了 RAG 的模块割裂问题,端到端可微分训练

极致可扩展性: 100M tokens 仅 8.8% 退化,远超现有方法。线性复杂度使得理论上可无限扩展

工程实用性: GPU-CPU 分层存储方案使得在消费级硬件上也能处理超长上下文

自适应多跳: Memory Interleave 让模型自主决定检索策略,比固定 Top-k 的 RAG 更灵活

局限性

基座模型限制: 仅在 Qwen3-4B 上验证,未在更大模型(7B/14B/70B)上实验,可扩展性到更大模型未知

预训练代价高: 158.95B tokens 的持续预训练成本不低,且消融实验表明去除预训练后性能暴跌 31.3%,说明方法强依赖于大量训练

评测指标单一: 主要依赖 LLM Judge(0-5 分),缺乏 EM/F1 等客观指标的对比

路由 Key 存储: 100M tokens 需要 ~56GB VRAM 仅用于存储路由 Key,对显存要求仍然较高

无代码开源: 截至目前未提供代码和预训练模型

潜在改进方向

多级路由: 当前单级 Top-k 路由可能不够高效,可考虑层次化路由(先粗选再精选)减少计算

路由 Key 压缩: 56GB 路由 Key 存储仍然较大,可尝试量化或更激进的池化策略

跨模型泛化: 验证 MSA 层作为插件在不同 LLM 架构上的即插即用能力

流式记忆更新: 当前是离线编码全部文档,可探索增量更新机制

可复现性评估

  • 代码开源(未提供)
  • 预训练模型(未提供)
  • 训练细节完整(数据、超参数、流程详尽)
  • 数据集可获取(主要使用公开数据集)

关联笔记

基于

RAG: MSA 的对比基线,MSA 将检索融入注意力取代了外部检索器

Qwen3: 基座模型 Qwen3-4B-Instruct-2507

对比

HippoRAG2: 同骨干 RAG 基线

KaLMv2: SOTA 检索器,与 235B 骨干配合

MemoryAgent: 基于 Agent 的记忆系统,NIAH 对比

DeltaNet: 线性注意力方法,精度低

RWKV: 线性注意力方法,存在灾难性遗忘

方法相关

Sparse Attention: 核心方法——稀疏注意力

RoPE: 位置编码策略

Supervised Contrastive Loss: 路由器训练的辅助损失

Mean Pooling: chunk 级压缩策略

Cosine Similarity: 路由相关性度量

Curriculum Learning: 两阶段训练策略

硬件/数据相关

A800: 推理硬件

速查卡片

MSA: Memory Sparse Attention for Efficient End-to-End Memory Model Scaling to 100M Tokens

  • 核心: 端到端可微的稀疏注意力框架,支持 100M tokens 终身记忆
  • 方法: Router K Projector + Top-k 文档选择 + Memory Interleave 多跳推理
  • 结果: 16K→100M 仅 8.8% 退化;NIAH 1M tokens 94.84%;超越同骨干 RAG 16.0%
  • 代码: 未开源

笔记创建时间: 2026-03-28