Self-Attention
Self-attention as a learned bilinear relation and a Nadaraya–Watson kernel smoother: why Q and K projections matter and how heads become readable.
-
Q and K Projections in JAX/Flax NNX
A runnable companion to Why Attention Needs Q and K Projections: build scaled dot-product attention with separate query and key projections in Flax NNX, extract the bilinear form B = W_Q W_Kᵀ from the module, split it into a symmetric metric and an antisymmetric directed part, wire up a toy induction head, add RoPE, and measure the low-rank budget and the gauge freedom, with all of the linear algebra done in plain JAX.
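As a rough sketch of the bilinear-form extraction described above, the plain-JAX snippet below stands in for the Flax NNX module with two bias-free projection matrices. The toy sizes, the random initialization, and the invertible matrix `M` used for the gauge check are illustrative assumptions, not the companion's actual code. It checks that the scores equal xᵢᵀBxⱼ/√d_head, that B splits into symmetric and antisymmetric parts, that B has rank at most d_head, and that replacing W_Q with W_Q M and W_K with W_K M⁻ᵀ leaves every score unchanged.

```python
import jax
import jax.numpy as jnp

# Illustrative toy sizes (assumed): 6 tokens, d_model = 8, d_head = 2.
n_tokens, d_model, d_head = 6, 8, 2
k_x, k_q, k_k, k_m = jax.random.split(jax.random.PRNGKey(0), 4)

X = jax.random.normal(k_x, (n_tokens, d_model))    # rows are residual vectors x_i
# Bias-free projections standing in for the Q and K layers of the module.
W_Q = jax.random.normal(k_q, (d_model, d_head)) / jnp.sqrt(d_model)
W_K = jax.random.normal(k_k, (d_model, d_head)) / jnp.sqrt(d_model)


def attention_scores(X, W_Q, W_K):
    """Scaled dot-product scores q_i . k_j / sqrt(d_head)."""
    Q, K = X @ W_Q, X @ W_K
    return Q @ K.T / jnp.sqrt(W_Q.shape[1])


# The bilinear form implied by the two projections, and its decomposition.
B = W_Q @ W_K.T                      # (d_model, d_model), rank <= d_head
S = 0.5 * (B + B.T)                  # symmetric "metric" part
A = 0.5 * (B - B.T)                  # antisymmetric "directed" part

scores = attention_scores(X, W_Q, W_K)
scores_from_B = X @ B @ X.T / jnp.sqrt(d_head)
sym_scores = X @ S @ X.T             # unchanged when i and j are swapped
anti_scores = X @ A @ X.T            # flips sign when i and j are swapped

# Low-rank budget: singular values beyond the first d_head are numerically zero.
singular_values = jnp.linalg.svd(B, compute_uv=False)

# Gauge freedom: for invertible M, (W_Q M)(W_K M^{-T})^T = W_Q W_K^T = B.
# M is shifted toward 3*I so that it is comfortably invertible.
M = jax.random.normal(k_m, (d_head, d_head)) + 3.0 * jnp.eye(d_head)
scores_gauged = attention_scores(X, W_Q @ M, W_K @ jnp.linalg.inv(M).T)

print("singular values of B:", singular_values)
print("max |scores - gauged|:", float(jnp.max(jnp.abs(scores - scores_gauged))))
```

The gauge check is why the individual matrices W_Q and W_K are not meaningful on their own: only their product B is identifiable from the scores.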
-
Why Attention Needs Q and K Projections
The dot product in attention is not enough by itself. Without learned query and key projections, attention can only compare tokens in the residual stream’s native geometry. With a single shared projection W, it learns a symmetric metric, WWᵀ. With separate Q and K projections, the score becomes a learned bilinear form x_iᵀ W_Q W_Kᵀ x_j: directional, role-aware, low-rank, and different per head. That bilinearity is what lets a query ask one kind of question while keys advertise a different kind of answer.
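A minimal sketch of the three regimes in that paragraph, using random matrices as stand-ins for trained weights (sizes and seeds are assumptions, and `asymmetry` is a hypothetical helper). It measures how much of each score matrix is antisymmetric: the native dot product and a shared projection give symmetric scores, so score(i, j) always equals score(j, i), while separate W_Q and W_K give a directional score.

```python
import jax
import jax.numpy as jnp


def scores(X, W_Q, W_K):
    """Raw bilinear scores x_i^T W_Q W_K^T x_j (the 1/sqrt(d) scaling is omitted)."""
    return (X @ W_Q) @ (X @ W_K).T


def asymmetry(M):
    """Hypothetical helper: Frobenius norm of M - M^T relative to that of M."""
    return float(jnp.linalg.norm(M - M.T) / jnp.linalg.norm(M))


# Assumed toy sizes: 5 tokens, d_model = 8, d_head = 2.
n, d_model, d_head = 5, 8, 2
kx, kw, kq, kk = jax.random.split(jax.random.PRNGKey(1), 4)
X = jax.random.normal(kx, (n, d_model))

# No projection: the residual stream's native dot product x_i . x_j.
native = X @ X.T
# One shared projection W: scores are x_i^T (W W^T) x_j, a symmetric PSD metric.
W = jax.random.normal(kw, (d_model, d_head))
shared = scores(X, W, W)
# Separate projections: a general (non-symmetric) low-rank bilinear form.
separate = scores(X, jax.random.normal(kq, (d_model, d_head)),
                  jax.random.normal(kk, (d_model, d_head)))

for name, M in [("native", native), ("shared", shared), ("separate", separate)]:
    print(f"{name:9s} asymmetry = {asymmetry(M):.3f}")
```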
-
Cheap Attention: Linear-Time Kernel Approximation
A 128K-token context creates more than 16 billion pairwise questions per attention head. But the N×N matrix is not the essence of attention; it is the receipt for an infinite feature map we never wrote down. Approximate that feature map with random features, reassociate the sum, and softmax attention becomes linear-time kernel attention. The whole argument is built from live in-browser visualizations.
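One way to see the reassociation trick concretely is the JAX sketch below. It swaps in positive random features, the construction introduced with Performer (FAVOR+), here without its orthogonalization step: φ(x) = exp(ωᵀx − ‖x‖²/2)/√m with ω ~ N(0, I), whose inner products are unbiased estimates of exp(qᵀk). The series may use a different feature map, and all sizes and seeds are assumptions. Computing φ(K)ᵀV once and then multiplying by φ(Q) costs O(N·m·d) and never materializes the N×N weight matrix.

```python
import jax
import jax.numpy as jnp


def softmax_attention(Q, K, V):
    """Exact O(N^2) attention: builds the full N x N weight matrix."""
    d = Q.shape[-1]
    weights = jax.nn.softmax(Q @ K.T / jnp.sqrt(d), axis=-1)
    return weights @ V


def positive_random_features(X, Omega):
    """phi(x) = exp(w.x - |x|^2 / 2) / sqrt(m), so E[phi(q) . phi(k)] = exp(q . k)."""
    m = Omega.shape[0]
    proj = X @ Omega.T                                   # (N, m)
    half_sq_norm = 0.5 * jnp.sum(X**2, axis=-1, keepdims=True)  # (N, 1)
    return jnp.exp(proj - half_sq_norm) / jnp.sqrt(m)


def random_feature_attention(Q, K, V, Omega):
    """Linear-time approximation: phi(Q) @ (phi(K)^T V), normalized row-wise."""
    d = Q.shape[-1]
    scale = d ** -0.25                  # splits the 1/sqrt(d) between q and k
    phi_q = positive_random_features(Q * scale, Omega)  # (N, m)
    phi_k = positive_random_features(K * scale, Omega)  # (N, m)
    kv = phi_k.T @ V                    # (m, d_v): one pass over the keys
    normalizer = phi_k.sum(axis=0)      # (m,): stands in for the softmax denominator
    return (phi_q @ kv) / (phi_q @ normalizer)[:, None]


# Assumed toy sizes: 16 tokens, head dim 4, 4096 random features.
n, d, m = 16, 4, 4096
kq, kk, kv, kw = jax.random.split(jax.random.PRNGKey(0), 4)
Q = 0.5 * jax.random.normal(kq, (n, d))
K = 0.5 * jax.random.normal(kk, (n, d))
V = jax.random.normal(kv, (n, d))
Omega = jax.random.normal(kw, (m, d))   # one Gaussian direction per feature

exact = softmax_attention(Q, K, V)
approx = random_feature_attention(Q, K, V, Omega)
print("max |exact - approx| =", float(jnp.max(jnp.abs(exact - approx))))
```

The error shrinks roughly like 1/√m, and it grows with the norms of q and k, which is one reason the toy inputs above are kept small.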
-
Attention is Explainable Because it is a Kernel
Self-attention in transformers is a Nadaraya–Watson kernel smoother. That fact, not "we visualize the matrix", is why attention heads are readable while MLPs are not.
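A small check of the kernel-smoother reading, under assumed toy shapes, with hypothetical helper names: writing the unnormalized kernel K(q, k) = exp(q·k/√d) and forming the Nadaraya–Watson estimate Σⱼ K(q, kⱼ) vⱼ / Σⱼ K(q, kⱼ) reproduces softmax attention for a single query. The weights come out nonnegative and summing to one, so each output is a convex combination of the value vectors, with one inspectable weight per token.

```python
import jax
import jax.numpy as jnp


def exp_kernel(q, k, d):
    """Unnormalized similarity kernel K(q, k) = exp(q . k / sqrt(d))."""
    return jnp.exp(jnp.dot(q, k) / jnp.sqrt(d))


def nadaraya_watson(q, K, V):
    """Kernel-weighted average: sum_j K(q, k_j) v_j / sum_j K(q, k_j)."""
    d = q.shape[-1]
    kernel_values = jax.vmap(lambda k: exp_kernel(q, k, d))(K)   # (N,)
    weights = kernel_values / kernel_values.sum()
    return weights @ V, weights


def softmax_attention_row(q, K, V):
    """The usual attention output for a single query."""
    weights = jax.nn.softmax(K @ q / jnp.sqrt(q.shape[-1]))
    return weights @ V, weights


# Assumed toy shapes: one query of dim 4, six keys, values of dim 3.
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(2), 3)
q = jax.random.normal(key_q, (4,))
K = jax.random.normal(key_k, (6, 4))
V = jax.random.normal(key_v, (6, 3))

out_nw, w_nw = nadaraya_watson(q, K, V)
out_attn, w_attn = softmax_attention_row(q, K, V)
print("same output:", bool(jnp.allclose(out_nw, out_attn, atol=1e-5)))
print("weights:", w_nw)
```

jax.nn.softmax subtracts the maximum score before exponentiating, so for large scores it is the numerically safer way to compute the same Nadaraya–Watson weights.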