Architecture / Sequence Transduction

Seq2Seq 与 Encoder-Decoder:从固定语义向量到注意力对齐

把一个输入序列映射成另一个输出序列:encoder 读入 source,decoder 自回归生成 target,attention 在每一步重新选择输入证据。

Mechanism Lab

动画:Encoder 读入 source,Attention 路由证据,Decoder 生成 target

动画从 source token 和 encoder states 开始,展示固定 context 的瓶颈,再打开 attention heatmap,让 decoder 在每个输出步按权重读取输入证据。

Step 1 / 5

Encode

Encoder 逐步读取 source token,形成位置级隐藏状态 h_i。

h_i=f_enc(x_i,h_{i-1})

Animation Control

Reduced-motion users receive the same step states without continuous motion.

01 / 直觉

核心直觉

Seq2Seq 解决的是“长度不一定相同的序列到序列”问题,例如翻译、摘要、代码生成、表格到文字、审稿意见到修改计划。

Encoder 把 x_1,...,x_T 编成隐藏状态;早期模型只用最后状态 c 作为固定语义向量,这会形成信息瓶颈。

Decoder 按 p(y_t | y_<t, x) 逐步生成输出。训练时通常用 teacher forcing 输入真实上一词,推理时只能输入自己刚生成的词。

Attention 让 decoder 在每个输出步计算 source 位置权重 alpha_{t,i},把固定向量 c 变成动态上下文 c_t,从而缓解长句和局部对齐问题。

02 / 数学

Seq2Seq 的概率分解、瓶颈和注意力推导

01 / 条件序列概率

目标是建模给定输入序列 x 后输出序列 y 的条件分布。链式法则把整句概率拆成逐步生成概率。

p(y|x) = prod_{t=1}^M p(y_t | y_{<t}, x)

02 / Encoder 状态

RNN encoder 逐步读取 source token,得到每个位置的隐藏状态 h_i;双向 encoder 还会拼接前向和后向状态。

h_i = f_enc(E_x[x_i], h_{i-1})

03 / 固定上下文瓶颈

原始 encoder-decoder 常把整句压成一个向量 c,例如最后状态 h_T。Decoder 每一步都依赖这个固定 c。

c = q(h_1,...,h_T), s_t=f_dec(E_y[y_{t-1}], s_{t-1}, c)

04 / 注意力打分

Attention 在第 t 个输出步用 decoder 状态 s_{t-1} 和每个 encoder 状态 h_i 计算对齐分数,再 softmax 成权重。

e_{t,i}=a(s_{t-1},h_i), alpha_{t,i}=softmax_i(e_{t,i})

05 / 动态上下文

上下文向量不再固定,而是 source hidden states 的加权平均。不同输出词可以看不同输入位置。

c_t = sum_i alpha_{t,i} h_i

06 / 训练目标与推理差异

训练最小化负对数似然,常用 teacher forcing;推理时用贪心或 beam search,会出现 exposure bias 和长度偏差。

L = -sum_t log p(y_t^* | y_{<t}^*, x)

03 / 代码

NumPy 演示:带 dot-product attention 的最小 Seq2Seq forward pass

下面代码不依赖深度学习框架,显式展示 encoder states、attention weights、dynamic context 和 decoder logits 的计算路径。

import numpy as np

def softmax(z):
    z = z - z.max()
    exp_z = np.exp(z)
    return exp_z / exp_z.sum()

rng = np.random.default_rng(13)
vocab_in, vocab_out = 18, 20
d_emb, d_h = 5, 6

source = np.array([2, 5, 7, 11])      # x_1 ... x_T
target_in = np.array([1, 4, 8])       # <bos>, y_1, y_2 under teacher forcing

E_src = rng.normal(size=(vocab_in, d_emb)) / np.sqrt(d_emb)
E_tgt = rng.normal(size=(vocab_out, d_emb)) / np.sqrt(d_emb)

Wxh = rng.normal(size=(d_emb, d_h)) / np.sqrt(d_emb)
Whh = rng.normal(size=(d_h, d_h)) / np.sqrt(d_h)
bh = np.zeros(d_h)

Watt = rng.normal(size=(d_h, d_h)) / np.sqrt(d_h)
Wdec = rng.normal(size=(d_emb + d_h, d_h)) / np.sqrt(d_emb + d_h)
Ws = rng.normal(size=(d_h, d_h)) / np.sqrt(d_h)
bd = np.zeros(d_h)
Wo = rng.normal(size=(2 * d_h, vocab_out)) / np.sqrt(2 * d_h)
bo = np.zeros(vocab_out)

# Encoder: produce one hidden state per source token.
h = np.zeros(d_h)
encoder_states = []
for token in source:
    x_i = E_src[token]
    h = np.tanh(x_i @ Wxh + h @ Whh + bh)
    encoder_states.append(h)
encoder_states = np.stack(encoder_states)  # [T, d_h]

# Decoder: teacher forcing plus attention at every output step.
s = encoder_states[-1]
logits = []
alignments = []
for token in target_in:
    y_prev = E_tgt[token]
    scores = encoder_states @ Watt @ s
    alpha = softmax(scores)
    context = alpha @ encoder_states

    decoder_input = np.concatenate([y_prev, context])
    s = np.tanh(decoder_input @ Wdec + s @ Ws + bd)
    logits_t = np.concatenate([s, context]) @ Wo + bo

    logits.append(logits_t)
    alignments.append(alpha)

print("encoder states:", encoder_states.shape)
print("decoder logits:", np.stack(logits).shape)
print("attention weights for step 2:", alignments[1].round(3))

04 / 案例

案例:把一种研究语言翻译成另一种研究语言

  • Seq2Seq 的经典例子是机器翻译,但在 StatsPAI 场景中,它也可以把“研究任务描述”翻译成“代码草稿”,把“回归表”翻译成“结果段落”,把“审稿意见”翻译成“修改清单”。
  • 例如输入是四段审稿意见:数据来源、识别假设、稳健性、写作结构。Encoder 产生每段的状态;Decoder 生成 response letter 的每个句子。
  • 没有 attention 时,所有证据都被压进一个固定向量,后面的句子容易忘记早先的审稿点。有 attention 时,每个生成句子可以重新对齐到对应意见。
  • 如果要用于实证研究助手,alignment heatmap 可以作为审计线索:生成某句 rebuttal 时,模型主要参考了哪条评论或哪张表。

05 / 风险

常见误区

把固定 context 向量当作足够表达长输入;长文本、表格和多段评论通常需要 attention 或检索。
忽略 teacher forcing 与推理时自回归输入之间的分布差异,导致训练 loss 低但生成质量差。
Beam search 宽度越大不一定越好;它可能放大常见短句、重复句或长度偏差。
没有正确 mask padding token,attention 会把概率质量分给不存在的输入位置。
把 attention heatmap 当作因果解释。它能帮助审计模型路由,但不能替代识别假设或人工核验。
在研究自动化场景中,只优化生成流畅度而不检查引用、数据、代码和结果是否一致。

参考资料