Architecture / Sequential Memory

RNN、LSTM 与 GRU:隐藏状态、门控记忆与序列建模

把序列逐步读入隐藏状态:普通 RNN 递归更新,LSTM/GRU 用门控机制控制写入、遗忘和输出。

Mechanism Lab

动画:隐藏状态如何沿时间传播,门控如何控制记忆

动画从普通 RNN 的递归链开始,展示长链梯度风险,然后打开 LSTM 的 forget/input/output gates 和 GRU 的 update/reset gates,最后接预测头。

Step 1 / 5

Recurrence

每个时间步读取 x_t 和上一状态 h_{t-1},写出新的 h_t。

h_t=f(x_t,h_{t-1})

Animation Control

Reduced-motion users receive the same step states without continuous motion.

01 / 直觉

核心直觉

RNN 的核心思想是把历史压缩进 hidden state:当前输入 x_t 与上一时刻 h_{t-1} 共同决定新状态 h_t。

普通 RNN 参数共享且适合任意长度序列,但长距离依赖会遇到梯度消失或爆炸。

LSTM 引入 cell state 和 forget/input/output gates,让模型能保留、写入或暴露长期记忆。

GRU 用 update/reset gates 合并部分 LSTM 结构,参数更少、训练更轻,常在中等规模序列任务中表现稳定。

02 / 数学

从递归状态到门控记忆

01 / 普通 RNN 状态

输入序列逐步进入同一个递归单元。所有时间步共享 W_x、W_h 和 b,因此模型可以处理可变长度序列。

h_t = phi(W_x x_t + W_h h_{t-1} + b)

02 / 输出层

根据任务,可以在每个时间步输出 y_t,也可以只用最后一个隐藏状态做 sequence-level 预测。

p(y_t|x_{<=t}) = softmax(W_y h_t)

03 / BPTT 梯度链

反向传播需要穿过多个时间步,梯度包含一串 Jacobian 乘积。如果谱半径长期小于 1 会消失,大于 1 会爆炸。

dL/dh_t includes prod_s W_h^T diag(phi_s)

04 / LSTM 门控

LSTM 用 forget gate 决定旧记忆保留多少,用 input gate 决定新候选写入多少,用 output gate 决定暴露多少 hidden state。

c_t=f_t*c_{t-1}+i_t*g_t, h_t=o_t*tanh(c_t)

05 / GRU 更新

GRU 用 update gate 在旧状态和候选状态之间插值,用 reset gate 控制候选状态是否读取旧记忆。

h_t=(1-z_t)*h_{t-1}+z_t*h_tilde

06 / 序列表示

对于分类、预测或事件检测,可使用最后状态、所有状态的池化,或注意力加权的状态汇总。

r = pool(h_1,...,h_T) or alpha_t h_t

03 / 代码

NumPy 演示:普通 RNN、LSTM 和 GRU 的一步更新

下面代码把三个序列单元的状态更新写成显式矩阵运算,方便比较普通递归和门控递归的差别。

import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def rnn_step(x_t, h_prev, p):
    return np.tanh(x_t @ p["Wx"] + h_prev @ p["Wh"] + p["b"])

def lstm_step(x_t, h_prev, c_prev, p):
    joined = x_t @ p["Wx"] + h_prev @ p["Wh"] + p["b"]
    i, f, o, g = np.split(joined, 4)
    i = sigmoid(i)
    f = sigmoid(f)
    o = sigmoid(o)
    g = np.tanh(g)
    c_t = f * c_prev + i * g
    h_t = o * np.tanh(c_t)
    return h_t, c_t

def gru_step(x_t, h_prev, p):
    z = sigmoid(x_t @ p["Wxz"] + h_prev @ p["Whz"] + p["bz"])
    r = sigmoid(x_t @ p["Wxr"] + h_prev @ p["Whr"] + p["br"])
    candidate = np.tanh(x_t @ p["Wxh"] + (r * h_prev) @ p["Whh"] + p["bh"])
    return (1 - z) * h_prev + z * candidate

rng = np.random.default_rng(9)
T, d_in, d_h = 6, 3, 4
X = rng.normal(size=(T, d_in))
h = np.zeros(d_h)
c = np.zeros(d_h)

rnn_params = {
    "Wx": rng.normal(size=(d_in, d_h)) / np.sqrt(d_in),
    "Wh": rng.normal(size=(d_h, d_h)) / np.sqrt(d_h),
    "b": np.zeros(d_h),
}
lstm_params = {
    "Wx": rng.normal(size=(d_in, 4 * d_h)) / np.sqrt(d_in),
    "Wh": rng.normal(size=(d_h, 4 * d_h)) / np.sqrt(d_h),
    "b": np.zeros(4 * d_h),
}
gru_params = {
    name: rng.normal(size=shape) / np.sqrt(shape[0])
    for name, shape in {
        "Wxz": (d_in, d_h), "Whz": (d_h, d_h),
        "Wxr": (d_in, d_h), "Whr": (d_h, d_h),
        "Wxh": (d_in, d_h), "Whh": (d_h, d_h),
    }.items()
}
gru_params.update({"bz": np.zeros(d_h), "br": np.zeros(d_h), "bh": np.zeros(d_h)})

for x_t in X:
    h_rnn = rnn_step(x_t, h, rnn_params)
    h_lstm, c = lstm_step(x_t, h, c, lstm_params)
    h_gru = gru_step(x_t, h, gru_params)
    h = h_gru

print("RNN state:", h_rnn.round(3))
print("LSTM state:", h_lstm.round(3))
print("GRU state:", h_gru.round(3))

04 / 案例

案例:研究日志、宏观时间序列和文本序列

  • 在实证研究工作流里,序列不只是一句话。它可以是每日政策新闻、季度宏观指标、用户行为日志、论文段落序列或 Agent 工具调用轨迹。
  • 普通 RNN 适合展示“状态如何递推”的基本机制,例如根据过去几期指标预测下一期风险;但当依赖跨越很长时间,梯度问题会让训练不稳定。
  • LSTM 适合需要长期记忆的任务,例如在一长段审稿意见中记住早先提到的数据限制,再在后文生成一致的修改计划。
  • GRU 适合较轻量的序列特征提取,例如把清洗日志、命令序列或短文本事件流编码成固定长度状态,再交给下游分类或异常检测模型。

05 / 风险

常见误区

把最后一个隐藏状态当作总是足够的序列表示;长序列常需要 pooling、attention 或分段建模。
忘记截断 BPTT 或梯度裁剪,导致训练慢、显存高或梯度爆炸。
在时间序列预测中随机打乱时间顺序,造成未来信息泄漏。
认为 LSTM/GRU 已经解决所有长期依赖;极长文本或复杂检索仍常需要 attention 或外部记忆。
没有区分 many-to-one、many-to-many 和 sequence-to-sequence 任务形态,导致标签对齐错误。
把循环模型状态解释成因果机制;hidden state 是预测表示,不是识别假设。

参考资料