Architecture / Sequential Memory
RNN、LSTM 与 GRU:隐藏状态、门控记忆与序列建模
把序列逐步读入隐藏状态:普通 RNN 递归更新,LSTM/GRU 用门控机制控制写入、遗忘和输出。
Mechanism Lab
动画:隐藏状态如何沿时间传播,门控如何控制记忆
动画从普通 RNN 的递归链开始,展示长链梯度风险,然后打开 LSTM 的 forget/input/output gates 和 GRU 的 update/reset gates,最后接预测头。
Step 1 / 5
Recurrence
每个时间步读取 x_t 和上一状态 h_{t-1},写出新的 h_t。
h_t=f(x_t,h_{t-1})Animation Control
Reduced-motion users receive the same step states without continuous motion.
01 / 直觉
核心直觉
RNN 的核心思想是把历史压缩进 hidden state:当前输入 x_t 与上一时刻 h_{t-1} 共同决定新状态 h_t。
普通 RNN 参数共享且适合任意长度序列,但长距离依赖会遇到梯度消失或爆炸。
LSTM 引入 cell state 和 forget/input/output gates,让模型能保留、写入或暴露长期记忆。
GRU 用 update/reset gates 合并部分 LSTM 结构,参数更少、训练更轻,常在中等规模序列任务中表现稳定。
02 / 数学
从递归状态到门控记忆
01 / 普通 RNN 状态
输入序列逐步进入同一个递归单元。所有时间步共享 W_x、W_h 和 b,因此模型可以处理可变长度序列。
h_t = phi(W_x x_t + W_h h_{t-1} + b)02 / 输出层
根据任务,可以在每个时间步输出 y_t,也可以只用最后一个隐藏状态做 sequence-level 预测。
p(y_t|x_{<=t}) = softmax(W_y h_t)03 / BPTT 梯度链
反向传播需要穿过多个时间步,梯度包含一串 Jacobian 乘积。如果谱半径长期小于 1 会消失,大于 1 会爆炸。
dL/dh_t includes prod_s W_h^T diag(phi_s)04 / LSTM 门控
LSTM 用 forget gate 决定旧记忆保留多少,用 input gate 决定新候选写入多少,用 output gate 决定暴露多少 hidden state。
c_t=f_t*c_{t-1}+i_t*g_t, h_t=o_t*tanh(c_t)05 / GRU 更新
GRU 用 update gate 在旧状态和候选状态之间插值,用 reset gate 控制候选状态是否读取旧记忆。
h_t=(1-z_t)*h_{t-1}+z_t*h_tilde06 / 序列表示
对于分类、预测或事件检测,可使用最后状态、所有状态的池化,或注意力加权的状态汇总。
r = pool(h_1,...,h_T) or alpha_t h_t03 / 代码
NumPy 演示:普通 RNN、LSTM 和 GRU 的一步更新
下面代码把三个序列单元的状态更新写成显式矩阵运算,方便比较普通递归和门控递归的差别。
import numpy as np
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def rnn_step(x_t, h_prev, p):
return np.tanh(x_t @ p["Wx"] + h_prev @ p["Wh"] + p["b"])
def lstm_step(x_t, h_prev, c_prev, p):
joined = x_t @ p["Wx"] + h_prev @ p["Wh"] + p["b"]
i, f, o, g = np.split(joined, 4)
i = sigmoid(i)
f = sigmoid(f)
o = sigmoid(o)
g = np.tanh(g)
c_t = f * c_prev + i * g
h_t = o * np.tanh(c_t)
return h_t, c_t
def gru_step(x_t, h_prev, p):
z = sigmoid(x_t @ p["Wxz"] + h_prev @ p["Whz"] + p["bz"])
r = sigmoid(x_t @ p["Wxr"] + h_prev @ p["Whr"] + p["br"])
candidate = np.tanh(x_t @ p["Wxh"] + (r * h_prev) @ p["Whh"] + p["bh"])
return (1 - z) * h_prev + z * candidate
rng = np.random.default_rng(9)
T, d_in, d_h = 6, 3, 4
X = rng.normal(size=(T, d_in))
h = np.zeros(d_h)
c = np.zeros(d_h)
rnn_params = {
"Wx": rng.normal(size=(d_in, d_h)) / np.sqrt(d_in),
"Wh": rng.normal(size=(d_h, d_h)) / np.sqrt(d_h),
"b": np.zeros(d_h),
}
lstm_params = {
"Wx": rng.normal(size=(d_in, 4 * d_h)) / np.sqrt(d_in),
"Wh": rng.normal(size=(d_h, 4 * d_h)) / np.sqrt(d_h),
"b": np.zeros(4 * d_h),
}
gru_params = {
name: rng.normal(size=shape) / np.sqrt(shape[0])
for name, shape in {
"Wxz": (d_in, d_h), "Whz": (d_h, d_h),
"Wxr": (d_in, d_h), "Whr": (d_h, d_h),
"Wxh": (d_in, d_h), "Whh": (d_h, d_h),
}.items()
}
gru_params.update({"bz": np.zeros(d_h), "br": np.zeros(d_h), "bh": np.zeros(d_h)})
for x_t in X:
h_rnn = rnn_step(x_t, h, rnn_params)
h_lstm, c = lstm_step(x_t, h, c, lstm_params)
h_gru = gru_step(x_t, h, gru_params)
h = h_gru
print("RNN state:", h_rnn.round(3))
print("LSTM state:", h_lstm.round(3))
print("GRU state:", h_gru.round(3))04 / 案例
案例:研究日志、宏观时间序列和文本序列
- 在实证研究工作流里,序列不只是一句话。它可以是每日政策新闻、季度宏观指标、用户行为日志、论文段落序列或 Agent 工具调用轨迹。
- 普通 RNN 适合展示“状态如何递推”的基本机制,例如根据过去几期指标预测下一期风险;但当依赖跨越很长时间,梯度问题会让训练不稳定。
- LSTM 适合需要长期记忆的任务,例如在一长段审稿意见中记住早先提到的数据限制,再在后文生成一致的修改计划。
- GRU 适合较轻量的序列特征提取,例如把清洗日志、命令序列或短文本事件流编码成固定长度状态,再交给下游分类或异常检测模型。
05 / 风险