Query–Key
Q and K projections as a learned bilinear form: directional, low-rank, role-aware attention, and the symmetric/antisymmetric split.
-
Q and K Projections in JAX/Flax NNX
A runnable companion to Why Attention Needs Q and K Projections: build scaled dot-product attention with separate query and key projections in Flax NNX, pull the bilinear form B = W_Q W_Kᵀ out of the module, split it into a symmetric metric and an antisymmetric directed part, wire a toy induction head, add RoPE, and measure the low-rank budget and the gauge freedom, all in plain JAX.
-
Why Attention Needs Q and K Projections
The dot product in attention is not enough by itself. Without learned query and key projections, attention can only compare tokens in the residual stream’s native geometry. With a shared projection, it learns a symmetric metric. With separate Q and K projections, the score becomes a learned bilinear form x_iᵀ W_Q W_Kᵀ x_j: directional, role-aware, low-rank, and different per head. That bilinearity is what lets attention ask one kind of question and lets tokens advertise another kind of answer.
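The three regimes can be compared directly in plain JAX. This is a minimal sketch: the weights are random stand-ins rather than trained parameters, and the sizes, the `is_symmetric` check, and the hand-built `Wq`/`Wk`/`score` example are illustrative assumptions.

The first part shows that raw dot products and a shared projection always give a symmetric score matrix, while separate projections do not. It also shows that the separate-projection scores have rank at most d_head. The last part builds a two-token example in which token type e0 "asks" and token type e1 "answers". Here the score in one direction is 1 and in the other is 0, an asymmetry that no shared projection can produce.

```python
import jax
import jax.numpy as jnp

d_model, d_head, seq = 8, 2, 5
kx, kq, kk = jax.random.split(jax.random.key(0), 3)
X = jax.random.normal(kx, (seq, d_model))  # rows are residual-stream vectors x_i
W_Q = jax.random.normal(kq, (d_model, d_head))
W_K = jax.random.normal(kk, (d_model, d_head))

# 1. No projections: score_ij = x_i . x_j in the native geometry.
raw = X @ X.T
# 2. One shared projection W: score_ij = x_i^T W W^T x_j, a learned symmetric metric.
shared = (X @ W_Q) @ (X @ W_Q).T
# 3. Separate projections: score_ij = x_i^T W_Q W_K^T x_j, a general bilinear form.
separate = (X @ W_Q) @ (X @ W_K).T


def is_symmetric(m: jax.Array, atol: float = 1e-4) -> bool:
    return bool(jnp.allclose(m, m.T, atol=atol))


print(is_symmetric(raw), is_symmetric(shared), is_symmetric(separate))  # True True False

# Low rank: scores routed through d_head dimensions have rank <= d_head,
# so the last seq - d_head singular values are numerically ~0.
print(jnp.linalg.svd(separate, compute_uv=False))

# Role-aware, by construction: token type e0 asks, token type e1 answers.
e0, e1 = jnp.eye(2)
Wq = jnp.array([[1.0], [0.0]])  # only e0 emits a nonzero query
Wk = jnp.array([[0.0], [1.0]])  # only e1 emits a matching key


def score(xi: jax.Array, xj: jax.Array) -> jax.Array:
    return xi @ Wq @ Wk.T @ xj


print(float(score(e0, e1)), float(score(e1, e0)))  # 1.0 0.0
```

With a shared projection, `score(e0, e1)` and `score(e1, e0)` would always be equal. Separate Q and K weights are what let the question and the answer live in different directions.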