
Attention: Q/K/V, Softmax Weights, and Context Vectors

Turn “where to look” into a differentiable similarity matrix: queries ask, keys are searched, and values are aggregated.

Mechanism Lab

Animation: Q/K/V turns tokens into a context vector

The animation projects tokens into Q/K/V, shows one query row forming a score matrix against every key, then applies scaling, masking, softmax, and a weighted sum over values.


01 / Intuition

Core Intuition

Attention is a learned retrieval operation: the current query is compared with every key to produce one row of weights.

After softmax, the weights are nonnegative and sum to one, so the output is a relevance-weighted mixture of value vectors.
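The retrieval view for a single query can be sketched as follows. The keys and values are hand-picked toy arrays, not outputs of a trained model. Because the weights form a convex combination, every coordinate of the output lies between the smallest and largest value in that coordinate.

```python
import numpy as np

def attend_one_query(q, K, V):
    """Score one query against every key, softmax the scores, and mix the values."""
    scores = K @ q / np.sqrt(q.shape[0])    # one scaled dot product per key
    scores = scores - scores.max()           # stabilize exp without changing softmax
    weights = np.exp(scores) / np.exp(scores).sum()
    return weights, weights @ V              # relevance-weighted mixture of value rows

# Toy keys and values chosen by hand for illustration (assumed, not learned).
K = np.eye(4)                                # four orthogonal keys in 4 dimensions
V = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0],
              [-1.0, 0.0]])
q = np.array([0.0, 4.0, 0.0, 0.0])           # a query aligned with key 1

weights, z = attend_one_query(q, K, V)
print(weights.round(3))   # key 1 receives most of the mass
print(z.round(3))         # pulled toward V[1] but blended with the other rows
```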

Dividing by sqrt(d_k) keeps the dot-product variance near one. Without it, a large key dimension d_k inflates the scores and pushes softmax toward a near one-hot output whose gradients vanish.
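The saturation effect can be checked numerically. This sketch draws random unit-variance queries and keys (an assumption standing in for learned projections) and reports the average largest softmax weight over 16 keys, with and without the 1/sqrt(d_k) factor. `mean_max_weight` is a hypothetical helper written only for this demo.

```python
import numpy as np

def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)   # shift for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

def mean_max_weight(d_k, scale, n_keys=16, trials=200, seed=0):
    """Average largest softmax weight for random unit-variance queries and keys."""
    rng = np.random.default_rng(seed)
    q = rng.normal(size=(trials, 1, d_k))
    K = rng.normal(size=(trials, n_keys, d_k))
    scores = (q @ K.transpose(0, 2, 1))[:, 0, :]   # shape (trials, n_keys)
    if scale:
        scores = scores / np.sqrt(d_k)
    return softmax(scores).max(axis=-1).mean()

for d_k in (4, 64, 512):
    unscaled = mean_max_weight(d_k, scale=False)
    scaled = mean_max_weight(d_k, scale=True)
    print(d_k, round(unscaled, 3), round(scaled, 3))
```

Without scaling, the largest weight drifts toward one as d_k grows, which is the near one-hot regime described above; with scaling it stays roughly constant across dimensions.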

Masks define visibility: encoder self-attention can usually see all tokens, decoder self-attention sees only the current and earlier positions, and padding masks hide meaningless filler tokens.
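The three visibility rules can be expressed as one boolean matrix per sequence. In this sketch `build_mask` is a hypothetical helper, and the convention that `True` marks a real token (and, in the output, a visible key) is an assumption; libraries differ on whether `True` means "keep" or "mask out".

```python
import numpy as np

def build_mask(n, causal, pad):
    """Boolean visibility matrix: True where query i may attend to key j.

    pad is a length-n boolean array, True for real tokens and False for padding
    (an assumed convention for this illustration).
    """
    visible = np.ones((n, n), dtype=bool)    # encoder default: every key visible
    if causal:
        visible &= np.tril(visible)          # decoder: only keys j <= i
    visible &= pad[None, :]                  # padding keys are never visible
    return visible

pad = np.array([True, True, True, False])    # last position is padding
print(build_mask(4, causal=False, pad=pad).astype(int))
print(build_mask(4, causal=True, pad=pad).astype(int))
```

Rows at padding positions still see real keys here; their outputs are typically ignored downstream rather than masked as queries.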

02 / Math

Deriving scaled dot-product attention from token representations

01 / Linear projections

Let H in R^{n x d_model} be n token representations. Three learned matrices project the same tokens into query, key, and value spaces.

Q=H W_Q, K=H W_K, V=H W_V, with W_Q, W_K in R^{d_model x d_k} and W_V in R^{d_model x d_v}
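A shape check for the three projections, assuming toy sizes and random matrices in place of learned weights. Each row of Q depends only on the matching row of H, so the projections treat tokens independently; tokens interact only in the score step.

```python
import numpy as np

n, d_model, d_k, d_v = 5, 16, 8, 8      # toy sizes, assumed for illustration
rng = np.random.default_rng(1)
H = rng.normal(size=(n, d_model))       # one row per token

# Learned in practice; random here. Dividing by sqrt(d_model) keeps outputs near unit variance.
W_Q = rng.normal(size=(d_model, d_k)) / np.sqrt(d_model)
W_K = rng.normal(size=(d_model, d_k)) / np.sqrt(d_model)
W_V = rng.normal(size=(d_model, d_v)) / np.sqrt(d_model)

Q, K, V = H @ W_Q, H @ W_K, H @ W_V
print(Q.shape, K.shape, V.shape)
print(np.allclose(Q[2], H[2] @ W_Q))    # row 2 of Q comes only from token 2
```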

02 / Similarity scores

The dot product between query i and key j gives an unnormalized score for how much position i should attend to position j.

S_ij = q_i^T k_j, S = QK^T
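The matrix form and the entry-wise definition agree. This sketch builds S both ways from random toy queries and keys (random values are an assumption, used only to exercise the shapes); rows index queries and columns index keys.

```python
import numpy as np

rng = np.random.default_rng(2)
n, d_k = 5, 3                       # toy sizes
Q = rng.normal(size=(n, d_k))       # one query per position
K = rng.normal(size=(n, d_k))       # one key per position

S = Q @ K.T                         # S[i, j] = q_i^T k_j, shape (n, n)

# The same matrix built entry by entry from the definition.
S_loop = np.array([[Q[i] @ K[j] for j in range(n)] for i in range(n)])
print(S.shape, np.allclose(S, S_loop))
```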

03 / Why scale by sqrt(d_k)

If the components of q_i and k_j are independent with mean zero and variance one, then q_i^T k_j is a sum of d_k independent terms, each with mean zero and variance one, so its variance is d_k. Dividing by sqrt(d_k) brings the variance back to one.

Var(q_i^T k_j)=d_k -> Var(S_ij/sqrt(d_k))=1
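A quick Monte Carlo check of the variance argument, assuming i.i.d. standard-normal components exactly as in the derivation (trained projections need not satisfy this). The raw dot-product variance should track d_k, while the scaled version should stay near one.

```python
import numpy as np

rng = np.random.default_rng(3)
samples = 10_000
results = {}
for d_k in (4, 64, 256):
    q = rng.normal(size=(samples, d_k))      # i.i.d. components, mean 0, variance 1
    k = rng.normal(size=(samples, d_k))
    dots = np.einsum("sd,sd->s", q, k)       # one q^T k per sample
    results[d_k] = (dots.var(), (dots / np.sqrt(d_k)).var())
    print(d_k, round(results[d_k][0], 1), round(results[d_k][1], 3))
```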

04 / Softmax normalization

Softmax is applied row by row to the scaled scores. The resulting weights A_ij are nonnegative and each row sums to one.

A_ij = exp(S_ij/sqrt(d_k)) / sum_l exp(S_il/sqrt(d_k))
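A sketch of the row-wise normalization on scaled scores from random toy queries and keys. Subtracting each row's maximum before `np.exp` does not change the result, because the common factor exp(-max) cancels between numerator and denominator; the comparison with a shifted copy illustrates that invariance.

```python
import numpy as np

def row_softmax(S):
    """Softmax along each row; subtracting the row max leaves the result unchanged."""
    shifted = S - S.max(axis=1, keepdims=True)
    E = np.exp(shifted)
    return E / E.sum(axis=1, keepdims=True)

rng = np.random.default_rng(4)
d_k = 8
Q, K = rng.normal(size=(3, d_k)), rng.normal(size=(3, d_k))
S = Q @ K.T / np.sqrt(d_k)                 # scaled scores
A = row_softmax(S)

naive = np.exp(S) / np.exp(S).sum(axis=1, keepdims=True)   # direct formula, fine for small scores
print(A.round(3))
print("row sums:", A.sum(axis=1))
print("matches direct formula:", np.allclose(A, naive))
print("shift invariant:", np.allclose(A, row_softmax(S + 100.0)))
```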

05 / Weighted sum

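The output z_i = sum_j A_ij v_j is a weighted average of the value vectors. A mask M has entries 0 at visible positions and -infinity at invisible ones; adding it to the scaled scores before softmax gives invisible positions exactly zero weight.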

Z = softmax(QK^T/sqrt(d_k)+M)V
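A sketch of the masked formula with a genuinely additive M of zeros and -infinity, assuming a causal visibility pattern and random toy inputs. Because exp(-inf) is exactly zero, masked weights vanish exactly. The sketch assumes every row keeps at least one visible key; a row that is entirely -infinity would make softmax undefined (NaN).

```python
import numpy as np

def masked_attention(Q, K, V, visible):
    """Z = softmax(QK^T/sqrt(d_k) + M) V with M = 0 where visible and -inf elsewhere."""
    d_k = Q.shape[-1]
    M = np.where(visible, 0.0, -np.inf)     # additive mask
    S = Q @ K.T / np.sqrt(d_k) + M
    S = S - S.max(axis=-1, keepdims=True)   # assumes each row has a visible key
    A = np.exp(S)                           # exp(-inf) = 0, so masked weights are exactly zero
    A = A / A.sum(axis=-1, keepdims=True)
    return A @ V, A

rng = np.random.default_rng(5)
n, d_k, d_v = 4, 8, 2
Q, K, V = rng.normal(size=(n, d_k)), rng.normal(size=(n, d_k)), rng.normal(size=(n, d_v))
causal = np.tril(np.ones((n, n), dtype=bool))
Z, A = masked_attention(Q, K, V, causal)

print(A.round(3))                                                  # upper triangle is 0
print(np.allclose(Z[2], sum(A[2, j] * V[j] for j in range(n))))   # z_i = sum_j A_ij v_j
```

The first row can only see itself, so z_1 equals v_1 exactly.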

06 / Multi-head decomposition

Multiple heads learn several similarity spaces in parallel: each head applies scaled dot-product attention with its own projections, the head outputs are concatenated, and an output projection W_O mixes them.

MHA(H)=Concat(head_1,...,head_m)W_O
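A compact sketch of multi-head attention, assuming the common layout in which W_Q, W_K, and W_V are d_model x d_model matrices whose column blocks serve as the per-head projections, so each head has dimension d_model/m. The function name and toy sizes are illustrative; the last lines recompute head 1 directly from its own column slice to show it is ordinary scaled dot-product attention.

```python
import numpy as np

def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(H, W_Q, W_K, W_V, W_O, m):
    """MHA(H) = Concat(head_1, ..., head_m) W_O, splitting d_model evenly across m heads."""
    n, d_model = H.shape
    d_head = d_model // m

    def split(X):
        # (n, d_model) -> (m, n, d_head): one block of columns per head
        return X.reshape(n, m, d_head).transpose(1, 0, 2)

    Q, K, V = split(H @ W_Q), split(H @ W_K), split(H @ W_V)
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_head))   # (m, n, n), one weight matrix per head
    heads = A @ V                                             # (m, n, d_head)
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)     # Concat(head_1, ..., head_m)
    return concat @ W_O, A

rng = np.random.default_rng(6)
n, d_model, m = 5, 12, 3
H = rng.normal(size=(n, d_model))
W_Q, W_K, W_V, W_O = (rng.normal(size=(d_model, d_model)) / np.sqrt(d_model) for _ in range(4))
out, A = multi_head_attention(H, W_Q, W_K, W_V, W_O, m)
print(out.shape, A.shape)

# Head 1 computed directly from its own column slice of the projections.
d_head = d_model // m
Q0, K0 = H @ W_Q[:, :d_head], H @ W_K[:, :d_head]
print(np.allclose(A[0], softmax(Q0 @ K0.T / np.sqrt(d_head))))
```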

03 / Code

NumPy demo: scaled dot-product attention from scratch

This snippet explicitly computes Q/K/V, scaled scores, a causal mask, softmax weights, and the final context vectors.

import numpy as np
import pandas as pd

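def softmax(a, axis=-1):
    # Subtract the max for numerical stability; softmax is shift-invariant.
    a = a - np.max(a, axis=axis, keepdims=True)
    exp = np.exp(a)
    return exp / exp.sum(axis=axis, keepdims=True)

def scaled_dot_attention(X, Wq, Wk, Wv, mask=None):
    # Project the same tokens into query, key, and value spaces.
    Q = X @ Wq
    K = X @ Wk
    V = X @ Wv
    d_k = Q.shape[-1]

    # Scaled scores: one row per query, one column per key.
    scores = (Q @ K.T) / np.sqrt(d_k)
    if mask is not None:
        # Invisible positions (mask == False) get a large negative score,
        # so they receive numerically zero weight after softmax.
        scores = np.where(mask, scores, -1e9)

    weights = softmax(scores, axis=-1)   # each row is nonnegative and sums to one
    context = weights @ V                # weighted average of value vectors
    return context, weights, scores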

tokens = ["policy", "raises", "wages", "jobs"]
rng = np.random.default_rng(7)
X = rng.normal(size=(len(tokens), 6))

Wq = rng.normal(size=(6, 4)) / np.sqrt(6)
Wk = rng.normal(size=(6, 4)) / np.sqrt(6)
Wv = rng.normal(size=(6, 4)) / np.sqrt(6)

# Decoder-style causal mask: token i cannot attend to future tokens j > i.
causal_mask = np.tril(np.ones((len(tokens), len(tokens)), dtype=bool))
context, weights, scores = scaled_dot_attention(X, Wq, Wk, Wv, causal_mask)

print(pd.DataFrame(weights, index=tokens, columns=tokens).round(3))
print("row sums:", weights.sum(axis=1).round(6))
print("context shape:", context.shape)
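The demo above handles one sequence with a causal mask; real batches usually also carry padding. This sketch extends the same idea to a batch of right-padded sequences by combining the causal mask with a padding mask. `batched_attention`, the sequence lengths, and the right-padding layout are assumptions for illustration.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - np.max(a, axis=axis, keepdims=True)
    exp = np.exp(a)
    return exp / exp.sum(axis=axis, keepdims=True)

def batched_attention(X, Wq, Wk, Wv, mask):
    """X: (batch, n, d_model); mask: (batch, n, n) booleans, True = visible."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv                        # (batch, n, d_k) each
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(Q.shape[-1])
    scores = np.where(mask, scores, -1e9)                    # same large-negative convention as above
    weights = softmax(scores, axis=-1)
    return weights @ V, weights

rng = np.random.default_rng(7)
batch, n, d_model, d_k = 2, 4, 6, 4
X = rng.normal(size=(batch, n, d_model))
Wq, Wk, Wv = (rng.normal(size=(d_model, d_k)) / np.sqrt(d_model) for _ in range(3))

# Hypothetical right-padded batch: sequence 0 has 4 real tokens, sequence 1 has 2.
lengths = np.array([4, 2])
pad = np.arange(n)[None, :] < lengths[:, None]              # (batch, n), True = real token
causal = np.tril(np.ones((n, n), dtype=bool))
mask = causal[None, :, :] & pad[:, None, :]                 # causal AND non-padding keys

context, weights = batched_attention(X, Wq, Wk, Wv, mask)
print(weights[1].round(3))   # padding columns 2 and 3 get zero weight
```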

04 / Case

Case: how a research assistant links questions, evidence, and conclusions

  • Suppose the user asks, “How does a minimum-wage policy affect employment?” The query comes from the current task state, while keys and values come from the question, paper paragraphs, variable definitions, regression tables, and citation context.
  • When the model writes that the employment effect depends on identification design, the query can compare against keys such as “minimum wage,” “employment,” “DID table,” and “control group,” then aggregate the relevant value vectors.
  • For table understanding, a cell query can attend to column names, row labels, units, standard-error notes, and footnotes, reducing confusion between coefficients, standard errors, and sample size.
  • In an empirical agent, attention is not causal evidence. It is a differentiable retrieval operation inside the representation; credibility still comes from source data, identification assumptions, rerunnable code, and diagnostics.

05 / Risks

Common Pitfalls

Treating attention weights as causal importance. They are internal model computations, not mechanisms in the data-generating process.
Omitting the sqrt(d_k) scale, which can saturate softmax and destabilize training in high dimensions.
Forgetting the causal mask in decoder or time-series settings, allowing future-token leakage.
Forgetting the padding mask, letting blank tokens receive attention mass.
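Assuming self-attention encodes token order by itself. Without positional information it is permutation-equivariant (reordering the tokens only reorders the outputs), so positional encodings or relative-position mechanisms are needed.
Underestimating the O(n^2) time and memory cost in sequence length n; long text, large tables, and high-frequency time series may need sparse, chunked, or retrieval-augmented variants.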
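The order pitfall can be checked directly. With no positional information added, shuffling the input tokens only shuffles the self-attention outputs in the same way, so attention alone cannot distinguish token orders. Random toy weights and the particular permutation are assumptions for illustration.

```python
import numpy as np

def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(X, Wq, Wk, Wv):
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

rng = np.random.default_rng(8)
X = rng.normal(size=(5, 6))                  # 5 tokens, no positional information added
Wq, Wk, Wv = (rng.normal(size=(6, 4)) for _ in range(3))
perm = np.array([3, 0, 4, 1, 2])             # an arbitrary reordering of the tokens

out = self_attention(X, Wq, Wk, Wv)
out_perm = self_attention(X[perm], Wq, Wk, Wv)
# Shuffling the input rows only shuffles the output rows the same way.
print(np.allclose(out_perm, out[perm]))
```

On the length cost: if the full n x n score matrix is materialized in float32, n = 32,768 needs 32,768^2 x 4 bytes = 4 GiB per head and per sequence, before counting the softmax weights or gradients.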
