2131 words
11 minutes
Mamba 架构解析

Mamba#

优势: 对于长度为 n 的序列,Mamba 的计算复杂度仅为 O(n log n),而 Transformer 则为 O(n^2)。

可以从线性 RNN 的角度去理解状态空间模型(State-Space Model)。

传统架构的局限性#

卷积神经网络 (CNN)#

对于 CNN,如果是较远距离的向量,需要经过多层卷积才能被融合。这使其难以理解输入中的长程依赖关系。

也就是对于图中的第一个向量与第四个向量的信息,需要经过 3 次卷积才能被融合。

CNN 多层卷积

循环神经网络 (RNN)#

对于 RNN,不是将网络应用于两个连续的输入向量,而是应用于一个输入向量和一个网络的上一个输出,所以只需要一层就可以让最终的输出向量包含输入中所有向量的信息。

RNN 单层处理

但是 RNN 有两个问题:

  1. 无法并行训练: 因为依赖于上一步的输出结果。
  2. 极其难以训练: 存在梯度爆炸与梯度消失问题。

线性 RNN 的解决方案#

线性 RNN 对上述两个问题进行了解决。

1. 解决并行训练问题#

通过线性循环算子解决。

线性 RNN 的基础公式为: yi = Wy * yi−1 + Wx * xi

线性 RNN 公式

对于输出 yiWy 乘上第 i-1 个的输出向量,Wx 乘上第 i 个输入向量 xi。其中 Wx * xi 的部分可以提前做并行计算,所以可以简化。

“扫描” (Scan) 操作#

扫描操作允许并行计算。以简单的加法计算前 n 个数之和为例:

串行计算(线性递推): 也就是逐个进行计算。

串行扫描

并行计算: 如果想要并行计算,则使用下面的方式操作。

并行扫描

首先将每两个相邻的数相加,而后对于第 i 个数,其输出就是第 i 个和第 i-1 个数相加的结果。第二次迭代则是对于输出列表的第 i 个数和第 i-2 个数相加,以此类推。

例如,对于第四个数 1,首先计算 15 + 1 = 16,同时并行计算 7 + 3 = 10。到了第二次迭代,就用结果 1610 相加,算出了前四个数的和。

这个迭代过程将会持续到步长与序列长度相等为止。如果并行执行,该过程将在 O(log n) 步后完成。

该算法不仅仅适用于加法,还可以是多种满足结合律的运算符 f(a, b)

结合律运算符

所以可以将以下形式的运算符替代 f(a, b),其中 x2 是矩阵而非单独的数,这个公式等价于线性 RNN 的公式。

矩阵运算符

上述的 scan 方法可以等同于线性递推:f(上一个的输出, 本次的输入)。所以我们既可以使用 scan 并行方法训练,也可以使用线性递推方法进行推理。

Scan 与递推等价

但是,W1 * W2 * ... 这样一连串的矩阵乘法,会导致计算非常慢。在原始的 RNN 中,这是一个 d^2 级别的运算,但现在变成了 d^3 次方的计算。

问题公式

解决方法:矩阵对角化 通过将矩阵 W 分解为 PDP-1,相邻的 PP-1 就会抵消掉。

矩阵对角化

因为对角矩阵的“矩阵乘法”恰好等价于它的“逐元素积”,如此就可以把 d^3 的复杂度降低为 O(d)

2. 解决梯度爆炸与梯度消失#

梯度更新的公式如下:

梯度更新公式

在循环了 n 次后,权重 w 增加 0.001,就与原 w 相差了约 0.001 * w^(n-1)

梯度传播

  • 如果 w > 1,就会导致梯度爆炸,神经网络的权重就会变化过大,学到的东西就会遭到破坏。
  • 如果 w < 1,就会导致梯度消失,导致权重基本不会更新,学不到任何东西。

相比之下,卷积神经网络的权重有很多个,有大于 1 的也有小于 1 的,这会让网络保持一种稳定的状态。

线性 RNN 的解决方案: 因为本身的权重就是复数形式,所以他们以复极坐标形式进行参数化,a 代表幅度,b 代表角度,同时将幅度 a 限制在小于 1。

w = a * e^(ib)

复数参数化

通过初始化,将 e^(-e^a) 的值限制在 [0.999, 1] 之间,对于 b 也是 [0, PI/10] 弧度之间。这使得在初始化时,所有的权重都非常接近 1。

权重初始化

同时也会在输入的时候给每个值乘上一个 delta (Δ)。由于 e^(-e^a) 被限制在 0.999-1 之间,所以 delta 初始化的结果就是接近于 0 的。

Delta 初始化

神奇的是,就是这样初始化模型,就会让它学会记住上下文信息,并持续上万步。

与状态空间模型 (SSM) 的联系#

对于状态空间模型,其权重的初始化也是类似的。其中 delta (Δ) 是可学习的(这个 Δ 和上面的不一样),通常被初始化为 0.0001 到 0.1 之间。当次方的数接近于 0 的时候,最终得到初始化的权重就是接近于 1 的,从而使得训练保持稳定。

同时这个 a + bi 就是 SSM 里的 HiPPO 所初始化的固定矩阵 A。

w = e^(Δ * (a+bi))

SSM 权重公式

SSM 与线性 RNN 的共同点: 由于这个线性 RNN 是根据 SSM 的思想所作,所以很多地方是与 SSM 一样的。

  1. 使用相同的计算模式: 在训练时使用卷积模式(并行训练),而在推理的时候就回到线性递推模式。
  2. 使用了相同的数学工具: 对角化。两者都将核心的权重矩阵 A 简化为了对角矩阵,大幅度降低了计算效率。

不同点:

  • SSM 使用 HiPPO 作为矩阵 A 的连续初始化,得到的就是对应的 a+bi,而后通过 e^(delta * A) 得到离散化的矩阵 A
  • 线性 RNN 则是本身就是求的一个离散化的 A,并且初始化也是随机的,而 SSM 则是固定的 HiPPO 的方法初始化。

HIPPO 的初始化方法: HIPPO 初始化矩阵

Mamba 的核心创新#

但是线性 RNN 也依旧有个问题,就是它无法选择性地遗忘输出向量中的信息。因为循环中只有一个权重 A 的缘故,所以假如 A 趋近于 0,那么就会导致模型什么都记不住;如果趋近于 1,输出向量就会一直积累过去的信息。

而 Mamba 的解决方法是不在每一步使用相同的权重,而是依赖于动态的输入权重。Mamba 对每一个输入向量应用一个线性函数,以此来生成一个独立的权重。

因为有了这个动态的权重,所以对于矩阵 A 就不需要再复杂地使用各种方式初始化了,而是变为了一个固定的对角矩阵。因为现在控制记忆的是 delta (Δ) 了。

Δ 的大小由原来的 D 变成了 (B, L, D),意味着对于一个 batch 里的每个 token (总共有 BxL 个) 都有一个独特的 Δ。且每个位置的 B 矩阵、C 矩阵、Δ 都不相同,这意味着对于每个输入 token,现在都有独特不同的 B 矩阵、C 矩阵,可以解决内容感知问题。通过结合输入的序列长度和批量大小,矩阵 B、C 的维度由 (D, N) 变到 (B, L, N),使得每个 token 都有一个独立的 B、C、Δ,从而可以使每个 token 都有独立的卷积核。

如上述的,就是为每一个 token 准备了一个独立的卷积核。

这样就可以让某些特定的输入被遗忘,某些特定的输入被记住了,同时也有助于缓解梯度爆炸与消失的问题。

同时,Mamba 也将输出向量的维度从原来的 8 扩大为了 16,以此可以保存更多历史信息。而在传递下一层的时候,就会将输出向量重新投射回原始维度,而后传递到下一层。

Mamba 架构解析
https://mizuki.mysqil.com/posts/mamba/
Author
Lain
Published at
2025-10-24
License
Unlicensed

Some information may be outdated

封面
示例歌曲
示例艺术家
封面
示例歌曲
示例艺术家
0:00 / 0:00