Transfusion 的混合损失函数

分类: 预训练与微调 · 难度: 中级 · 关联讲座: L17

Transfusion(Meta, 2024)提出了在单一 Transformer 中同时执行语言建模和扩散生成的统一架构。核心创新是为不同模态使用不同的损失函数——文本用 next-token prediction,图像用 DDPM 扩散去噪——共享注意力层捕获跨模态交互,同时通过 Mixture of Transformers(MoT)为每种模态分配专用 FFN。


📐 Transfusion 的混合损失函数

核心思想:同一个 Transformer 同时学两种生成任务,不同模态用不同的损失函数。

L=LLM+λLdiffusion\mathcal{L} = \mathcal{L}_{\text{LM}} + \lambda \cdot \mathcal{L}_{\text{diffusion}}

文本部分(离散 token)用标准语言建模损失:

LLM=tlogP(xtx<t)\mathcal{L}_{\text{LM}} = -\sum_{t} \log P(x_t | x_{<t})

图像部分(连续 patch embedding)用扩散去噪损失(DDPM 风格):

Ldiffusion=Et,ϵ[ϵϵθ(zt,t,c)2]\mathcal{L}_{\text{diffusion}} = \mathbb{E}_{t, \epsilon} \left[ \| \epsilon - \epsilon_\theta(z_t, t, c) \|^2 \right]

其中 zt=αˉtz0+1αˉtϵz_t = \sqrt{\bar\alpha_t} z_0 + \sqrt{1-\bar\alpha_t} \epsiloncc 是 Transformer 上下文(可包含文本)。

Mixture of Transformers (MoT) 的计算分离:共享注意力层(捕获跨模态交互),为每种模态分配独立的 FFN(捕获模态特有特征):

hout=Attn(h)+FFNmodal(h)h_{\text{out}} = \text{Attn}(h) + \text{FFN}_{\text{modal}}(h)

参数量:标准 nn 层 Transformer 的 FFN 替换为 mm 种 modal-specific FFN,参数总量 ×m/n\times m/n 倍增加,但激活参数量不变(稀疏激活)。

🔢 早期融合 vs 晚期融合的对比

方法代表模型图像表示文本-图像交互层级优势劣势
晚期融合(Late Fusion)CLIP独立 ViT 编码仅最终 embedding 对齐检索效率高、可解耦难以建模细粒度交互
交叉注意力(Cross-Attn)Flamingo独立 ViTDecoder 每层注入图像信息灵活、可扩展两套参数,部署复杂
早期融合(Early Fusion)ChameleonVQ-VAE token 化图文 token 混排,统一 Transformer深度交互,架构简洁图像 token 化损失信息

Chameleon 使用 8192 码本的 VQ-VAE 将 256×256256 \times 256 图像编码为 32×32=102432 \times 32 = 1024 个 token,与文本 token 直接拼接送入 Transformer。

💡 为什么多模态很难?

模态鸿沟(Modality Gap):文本是离散的符号序列(低维语义),图像是连续的高维信号(像素强度)。两者的统计特性、信息密度、序列长度都完全不同。

早期融合的根本挑战:图像 VQ-VAE token 化会损失精细信息(颜色梯度、纹理),但换来了与文本的统一表示空间。Transfusion 的妥协方案是保留图像的连续表示,用扩散头生成,用注意力层交互——两全其美但实现复杂。

⚠️ 常见误区

  1. 误区:注意力可以自由地在文本和图像 token 之间流动,所以早期融合就是天然跨模态 → 正确:Transformer 对序列长度的 O(n2)O(n^2) 注意力代价在拼入 1024 图像 token 后剧增,实践中通常需要 FlashAttention + 图像块降采样。

  2. 误区:多模态能力等于视觉理解能力 → 正确:现有 VLM 在空间推理、计数、OCR 等任务上仍显著落后人类,这些任务需要图像特有的结构理解,不是语言 token 机制的强项。