MSA: Memory Sparse Attention for Efficient End-to-End Memory Model Scaling to 100M Tokens
论文笔记:MSA: Memory Sparse Attention for Efficient End-to-End Memory Model Scaling to 100M Tokens
元信息
| 项目 | 内容 |
|---|---|
| 机构 | 未明确标注(多机构合作) |
| 日期 | March 2025 |
| 项目主页 | 未提供 |
| 对比基线 | RAG, HippoRAG2, MemoryAgent, KaLMv2 |
| 链接 | arXiv |
一句话总结
提出 Memory Sparse Attention,通过可微的稀疏注意力机制实现端到端的百万级 token 长上下文建模,从 16K 到 100M tokens 仅 8.8% 性能退化
核心贡献
端到端可微的稀疏注意力框架: 通过 Router K Projector + Top-k 选择实现 稀疏注意力,保持线性复杂度的同时支持梯度反传
Memory Interleave 机制: 支持自适应多跳检索推理,模型自主决定检索文档数量和时机,无需外部编排
100M tokens 推理方案: 通过 GPU-CPU 分层存储 + 异步预取,在 2×A800 上实现 1 亿 token 的端到端推理
问题背景
要解决的问题
如何让 LLM 拥有终身记忆(lifetime memory),能在百万甚至亿级 token 的上下文中精确检索和推理
现有方法的局限
参数化记忆(LoRA、CPT):需要重新训练,存在灾难性遗忘,无法支持动态更新
外部存储(RAG):检索器和生成器分离,语义表示不一致,多跳推理能力弱
线性注意力(DeltaNet、RWKV):有损压缩,精度低,存在灾难性遗忘
标准稀疏注意力(DSA):不支持终身级别的上下文规模
本文的动机
将稀疏注意力从”有限窗口”扩展到”终身规模”,同时保持端到端可微性和高精度,避免 RAG 的检索-生成割裂问题
方法详解
模型架构
MSA 基于 Qwen3-4B-Instruct 进行改造:
- 输入: 查询 + 外部记忆文档集合
- Backbone: Qwen3-4B(18 层,8 头,头维度 128)
- 核心模块: 后半部分层(9-18 层)替换为 MSA 层,前半部分保持标准 自注意力
- 输出: 自回归生成答案 / 文档 ID 序列
- 总参数: ~4B
核心模块
模块1: MSA Layer(稀疏注意力层)
设计动机: 将检索和生成统一到单一注意力层,避免 RAG 中检索器与生成器的表示割裂
具体实现:
- 每个文档的隐状态 通过三组投影得到 、(内容 KV)和 (路由 Key)
- 路由 Key 通过 64-token 的 均值池化 压缩为 chunk 级表示
- 查询的路由向量 与所有文档的 计算 余弦相似度,取 Top-16 文档
- 选中文档的完整 KV 与查询 KV 拼接后执行标准注意力
模块2: Parallel & Global RoPE(位置编码策略)
设计动机: 解耦文档位置和全局位置,使模型对记忆规模不敏感
具体实现:
- 每个文档独立使用 旋转位置编码,位置 ID 从 0 开始(Parallel RoPE)
- 活跃上下文使用 Global RoPE,位置偏移量 = 检索文档数 (默认 16)
- 两种 RoPE 在注意力计算时自然融合,无需额外参数
模块3: Memory Interleave(记忆交错机制)
设计动机: 支持多跳推理,让模型自主决定何时检索、检索多少文档
具体实现:
- 模型自回归生成文档 ID 序列,以特殊分隔符结束
- 检索对应文档文本,追加到查询上下文
- 重复上述过程直到模型认为证据充分,转入答案生成
- 文档数量自适应,不需要预设固定值
训练流程
持续预训练(158.95B tokens)
两阶段优化:
- Warm-up 阶段:侧重路由学习,,学习率
- 主训练阶段:侧重语言建模,,学习率
后训练(SFT)
Stage 1: 8K 上下文长度的指令微调
Stage 2: 扩展到 64K 上下文,严格数据清洗
推理架构(三阶段流水线)
- 全局记忆编码(离线):预计算所有文档的 、、,存入结构化缓存
- 路由与上下文组装(在线):查询与路由 Key 匹配,选出 Top-16 文档,从 CPU 异步加载对应 KV
- 稀疏生成(在线):在组装好的稀疏上下文上自回归生成
关键公式
公式1: KV 投影与路由投影
含义: 从文档 的隐状态 生成三组投影——标准内容 Key/Value 和专用路由 Key
符号说明:
- : 第 个文档的隐状态矩阵
- : 第 个注意力头的标准 KV 投影矩阵
- : 第 个注意力头的路由 Key 投影矩阵
公式2: 相关性评分
含义: 计算查询与第 个文档第 个 chunk 的相关性分数,跨头取均值后在 token 维度取最大值
符号说明:
- : 查询的路由向量(第 个头)
- : 第 个文档第 个 chunk 的压缩路由 Key
- : 余弦相似度
公式3: 上下文组装
含义: 将 Top-k 选中文档的 KV 与查询的 KV 拼接,形成稀疏注意力上下文
符号说明:
- : Top-k 选中文档的索引集合
- : 第 个文档的内容 Key/Value
- : 查询的 Key/Value
公式4: 注意力输出
含义: 在组装好的稀疏上下文上执行标准注意力计算
符号说明:
- : 查询的 Query 向量
公式5: 监督对比损失
含义: 监督路由器的层级决策,使正样本文档的相关性分数高于负样本
符号说明:
- : 正样本文档集合
- : 负样本文档集合
- : 第 个正样本的相关性分数
- : 第 个负样本的相关性分数
- : 温度参数
公式6: 训练复杂度
含义: 训练复杂度由独立文档处理主导,为线性
符号说明:
- : 记忆总大小(文档数)
- : 单个文档长度
- : 查询长度
- : 池化大小(64)
- : Top-k 选择数(16)
公式7: 推理复杂度
含义: 每次查询的推理复杂度为线性,离线预处理 可分摊
符号说明:
- : 答案生成长度
公式8: 训练损失函数
Warm-up 阶段:
主训练阶段:
含义: 两阶段课程学习——先重点训练路由器,再重点训练语言建模
符号说明:
- : 标准语言模型损失(交叉熵)
- : 监督对比损失
关键图表
Figure 1: Scalability / 可扩展性展示
{:width 600}
说明: MSA 在 MS MARCO 数据集上从 16K 扩展到 100M tokens 的性能曲线,仅 8.8% 退化(4.023 → 3.669)。展示了 MSA 集成 Top-k 选择与 稀疏注意力 的可微分设计,兼具可扩展性。
Figure 2: MSA Layer Architecture / MSA 层架构
{:width 600}
说明: MSA 层的路由机制与注意力计算流程。文档通过 Router K Projector 生成路由 Key,经 均值池化 压缩后与查询路由向量计算 余弦相似度,选出 Top-k 文档,最终拼接 KV 执行注意力。
Figure 3: Three-Stage Inference with Memory Interleave / 三阶段推理与记忆交错
{:width 600}
说明: 三阶段推理流水线。Stage 1 离线编码所有文档 KV;Stage 2 在线路由选择 Top-16 文档并从 CPU 加载 KV;Stage 3 结合 Memory Interleave 进行多跳自回归生成。GPU 存储路由 Key(~56GB),CPU 存储内容 KV。
Figure 4: NIAH Benchmark Results / 大海捞针评测结果
{:width 600}
说明: 32K-1M token 范围内的 Needle-In-A-Haystack 准确率。MSA 在 1M tokens 时仍保持 94.84%,仅比 32K 基线(98.77%)下降 3.93 个百分点。对比 Qwen3-4B 原始模型在 1M 时暴降至 24.69%,MemoryAgent-14B 下降 5.76 个百分点。
Table 1: 长期记忆方法对比
| 方法类型 | 终身记忆 | 精度 | 兼容主流 LLM | 计算复杂度 | 记忆管理 | 灾难性遗忘 |
|---|---|---|---|---|---|---|
| 参数化(LoRA/CPT) | 否 | 高 | 高 | 训练高/推理低 | 难 | 有 |
| 测试时训练(Titans) | 否 | 中 | 低 | 中 | 中 | 有 |
| RAG | 是 | 中 | 中 | 易 | 低 | |
| MemAgent | 是 | 中 | 中 | 中 | 易 | 中 |
| 稀疏注意力(DSA) | 否 | 高 | 高 | 中 | 易 | 无 |
| 线性注意力(DeltaNet/RWKV) | 否 | 低 | 低 | 易 | 有 | |
| MemGen | 否 | 中 | 高 | 中 | 中 | 无 |
| MSA(Ours) | 是 | 高 | 高 | 易 | 无 |
说明: MSA 是唯一同时具备终身记忆、高精度、高兼容性、线性复杂度且无灾难性遗忘的方法
Table 2: 同骨干网络 RAG 对比(LLM Judge 0-5 分)
| 数据集 | Tokens | Qwen3-4B R@1 | Qwen3-4B R@10 | Qwen3-4B(RR) R@10 | HippoRAG2 R@10 | MSA |
|---|---|---|---|---|---|---|
| MS MARCO v1 | 7.34M | 2.893 | 3.005 | 3.017 | 3.019 | 4.141 |
| Natural Questions | 1.47M | 3.452 | 3.297 | 3.385 | 3.374 | 3.545 |
| DuReader | 277K | 3.726 | 3.594 | 3.607 | 3.415 | 4.155 |
| TriviaQA | 10M | 4.133 | 4.273 | 4.391 | 4.367 | 4.621 |
| NarrativeQA | 538K | 1.611 | 2.860 | 3.536 | 2.655 | 3.395 |
| PopQA | 1.18M | 2.959 | 3.299 | 3.266 | 3.249 | 3.433 |
| 2WikiMultiHopQA | 722K | 1.065 | 3.136 | 3.159 | 3.330 | 4.280 |
| HotpotQA | 1.35M | 2.252 | 3.787 | 4.022 | 3.970 | 4.061 |
| MuSiQue | 1.41M | 0.936 | 1.928 | 1.965 | 2.095 | 2.211 |
| 平均 | — | 2.559 | 3.242 | 3.372 | 3.275 | 3.760 |
说明: MSA 以自适应检索击败所有同骨干 RAG 基线,平均领先 16.0%(vs Qwen3-4B R@1)。多跳数据集(2Wiki、HotpotQA)优势尤为显著。
Table 3: SOTA RAG 系统对比(LLM Judge 0-5 分)
| 数据集 | KaLMv2+Qwen3-235B R@10 | KaLMv2+Qwen3-235B(RR) R@10 | KaLMv2+Llama3.3 R@10 | KaLMv2+Llama3.3(RR) R@10 | MSA |
|---|---|---|---|---|---|
| MS MARCO v1 | 3.027 | 2.995 | 2.919 | 2.952 | 4.141 |
| Natural Questions | 3.694 | 3.645 | 3.662 | 3.647 | 3.545 |
| DuReader | 4.044 | 3.891 | 3.742 | 3.780 | 4.155 |
| TriviaQA | 4.578 | 4.555 | 4.719 | 4.695 | 4.621 |
| NarrativeQA | 2.427 | 3.375 | 2.382 | 3.317 | 3.395 |
| PopQA | 3.396 | 3.376 | 3.305 | 3.362 | 3.433 |
| 2WikiMultiHopQA | 3.582 | 3.583 | 3.445 | 3.541 | 4.280 |
| HotpotQA | 4.225 | 4.194 | 4.127 | 4.203 | 4.061 |
| MuSiQue | 2.647 | 2.605 | 2.258 | 2.614 | 2.211 |
| 平均 | 3.506 | 3.580 | 3.396 | 3.568 | 3.760 |
说明: MSA(4B 参数)与使用 235B 参数骨干的 SOTA RAG 系统竞争力相当,平均分数最高。在 MS MARCO 和 2WikiMultiHopQA 上大幅领先,但在 Natural Questions 和 MuSiQue 上略低。
Table 4: 消融实验(0-5 分)
| 配置 | 平均 | MS MARCO | NQ | DuReader | HotpotQA |
|---|---|---|---|---|---|
| MSA-S2(完整) | 3.976 | 4.141 | 3.545 | 4.155 | 4.061 |
| MSA-S1(无 Stage 2 SFT) | 3.694 | 3.197 | 3.493 | 4.064 | 4.020 |
| w/o Memory Interleave | 3.497 | 3.175 | 3.485 | 4.076 | 3.250 |
| w/o 持续预训练 | 2.537 | 2.267 | 2.448 | 3.144 | 2.289 |
| w/o 原始文本 | 2.325 | 2.625 | 2.190 | 2.186 | 2.297 |
关键发现: 持续预训练是最关键组件(去除后下降 31.3%),其次是原始文本(下降 37.1%)。Memory Interleave 对多跳任务影响最大(HotpotQA 下降 19.2%)。课程学习(S2 vs S1)贡献 7.6% 平均提升。
Table 5: 预训练数据构成
| 数据类别 | 查询数 | Token 数 | 主要来源 |
|---|---|---|---|
| 长上下文与指令微调 | 5.8M | 6.46B | KaLM finetune data |
| 学术科学文献 | 2.0M | 28.74B | S2ORC, SPECTER |
| 通用 QA 与社区知识 | 4.9M | 75.80B | Yahoo Answers, WikiAnswers, MS MARCO, PAQ 等 |
| 新闻与摘要 | 2.1M | 36.43B | AG News, NPR, CNN/DailyMail, XSum |
| 领域特定 | 2.1M | 28.94B | Amazon Reviews, CodeSearchNet, WikiHow |
| 总计 | 17.9M | 158.95B | 跨 17 个数据源 |
说明: 训练数据覆盖科学文献、QA、新闻、代码等多领域,非 KaLM 数据集均下采样至 500K 查询
实验
数据集
| 数据集 | 规模 | 特点 | 用途 |
|---|---|---|---|
| MS MARCO v1 | 7.34M tokens | 大规模信息检索 | QA 评测 |
| Natural Questions | 1.47M tokens | 开放域 QA | QA 评测 |
| DuReader | 277K tokens | 中文阅读理解 | QA 评测 |
| TriviaQA | 10M tokens | 大规模 trivia | QA 评测 |
| NarrativeQA | 538K tokens | 叙事理解 | QA 评测 |
| PopQA | 1.18M tokens | 流行文化 QA | QA 评测 |
| 2WikiMultiHopQA | 722K tokens | 多跳推理 | QA 评测 |
| HotpotQA | 1.35M tokens | 多跳推理 | QA 评测 |
| MuSiQue | 1.41M tokens | 多步推理 | QA 评测 |
| NIAH | 32K-1M tokens | 大海捞针 | 长上下文评测 |
实现细节
Backbone: Qwen3-4B-Instruct-2507
MSA 层: 模型后半部分(9-18 层),共 10 层
注意力头数: 8
头维度: 128
Top-k: 16 文档
Chunk 大小: 64 tokens(Mean Pooling kernel)
预训练: 158.95B tokens,两阶段学习率(1e-4 → 6e-6)
SFT Stage 1: 8K 上下文
SFT Stage 2: 64K 上下文
推理硬件: 2×A800 GPU(160GB aggregate VRAM)
路由 Key 存储: ~56GB GPU VRAM
内容 KV 存储: CPU DRAM,异步预取
NIAH 评测结果
| 模型 | 32K | 1M | 退化 |
|---|---|---|---|
| MSA | 98.77% | 94.84% | -3.93pp |
| MemoryAgent-14B | — | — | -5.76pp |
| Qwen3-4B-Instruct | — | 24.69% | 严重崩溃 |
可扩展性结果
| 上下文长度 | MSA 评分 | 退化率 |
|---|---|---|
| 16K | 4.023 | — |
| 100M | 3.669 | 8.8% |
批判性思考
优点
统一框架: 将检索和生成融合到单一注意力机制,避免了 RAG 的模块割裂问题,端到端可微分训练
极致可扩展性: 100M tokens 仅 8.8% 退化,远超现有方法。线性复杂度使得理论上可无限扩展
工程实用性: GPU-CPU 分层存储方案使得在消费级硬件上也能处理超长上下文
自适应多跳: Memory Interleave 让模型自主决定检索策略,比固定 Top-k 的 RAG 更灵活
局限性
基座模型限制: 仅在 Qwen3-4B 上验证,未在更大模型(7B/14B/70B)上实验,可扩展性到更大模型未知
预训练代价高: 158.95B tokens 的持续预训练成本不低,且消融实验表明去除预训练后性能暴跌 31.3%,说明方法强依赖于大量训练
评测指标单一: 主要依赖 LLM Judge(0-5 分),缺乏 EM/F1 等客观指标的对比
路由 Key 存储: 100M tokens 需要 ~56GB VRAM 仅用于存储路由 Key,对显存要求仍然较高
无代码开源: 截至目前未提供代码和预训练模型
潜在改进方向
多级路由: 当前单级 Top-k 路由可能不够高效,可考虑层次化路由(先粗选再精选)减少计算
路由 Key 压缩: 56GB 路由 Key 存储仍然较大,可尝试量化或更激进的池化策略
跨模型泛化: 验证 MSA 层作为插件在不同 LLM 架构上的即插即用能力
流式记忆更新: 当前是离线编码全部文档,可探索增量更新机制
可复现性评估
- 代码开源(未提供)
- 预训练模型(未提供)
- 训练细节完整(数据、超参数、流程详尽)
- 数据集可获取(主要使用公开数据集)
关联笔记
基于
RAG: MSA 的对比基线,MSA 将检索融入注意力取代了外部检索器
Qwen3: 基座模型 Qwen3-4B-Instruct-2507
对比
HippoRAG2: 同骨干 RAG 基线
KaLMv2: SOTA 检索器,与 235B 骨干配合
MemoryAgent: 基于 Agent 的记忆系统,NIAH 对比
DeltaNet: 线性注意力方法,精度低
RWKV: 线性注意力方法,存在灾难性遗忘
方法相关
Sparse Attention: 核心方法——稀疏注意力
RoPE: 位置编码策略
Supervised Contrastive Loss: 路由器训练的辅助损失
Mean Pooling: chunk 级压缩策略
Cosine Similarity: 路由相关性度量
Curriculum Learning: 两阶段训练策略
硬件/数据相关
A800: 推理硬件
速查卡片
MSA: Memory Sparse Attention for Efficient End-to-End Memory Model Scaling to 100M Tokens
- 核心: 端到端可微的稀疏注意力框架,支持 100M tokens 终身记忆
- 方法: Router K Projector + Top-k 文档选择 + Memory Interleave 多跳推理
- 结果: 16K→100M 仅 8.8% 退化;NIAH 1M tokens 94.84%;超越同骨干 RAG 16.0%
- 代码: 未开源
笔记创建时间: 2026-03-28