Cheap Attention in JAX/Flax NNX


Explainer companion: Cheap Attention: Linear-Time Kernel Approximation. Want the full intuition first? This is the runnable companion to the explainer.

The theory post argued that cheap attention is not a separate trick from attention. It is the same kernel sum, evaluated through a finite feature map and parenthesized in the useful order. This is the implementation companion: the JAX arrays, the Flax NNX module, the causal recurrence, and the GIFs that show the bookkeeping changing from an all-pairs ledger to a shared state.

Animated comparison of gravity and attention: all-pairs ledgers versus shared fields or feature states, with compute bars for each side
The analogy in one GIF. Gravity has a direct N-body ledger and a field solve. Attention has exact softmax pair scores and a feature-state approximation. The details differ; the bookkeeping shape is the same.

The companion to Cheap Attention: Linear-Time Kernel Approximation should be boring in the best way. No custom CUDA, no special kernel, no hidden magic. Just three JAX operations:

  1. Project tokens into $Q, K, V$.
  2. Replace $\exp(q\cdot k)$ with a positive random feature dot product $\phi(q)^\top\phi(k)$.
  3. Compute $\phi(K)^\top V$ before multiplying by $\phi(Q)$.

That third line is the whole implementation story. Exact softmax attention does

$$\mathrm{softmax}(QK^\top)V,$$

so the $N\times N$ score table sits in the middle. Linear attention does

$$\widehat Y = D^{-1}\phi(Q)\big(\phi(K)^\top V\big), \qquad D_i = \phi(q_i)^\top\sum_j\phi(k_j),$$

so the middle object is no longer token-by-token. It is a feature-space state.
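
As a quick sanity check on that parenthesization, the sketch below uses arbitrary positive arrays as stand-ins for the query and key features (they are not real feature maps; the sizes and variable names are assumptions for illustration) and confirms that both orderings give the same normalized output up to float error.

```python
import jax
import jax.numpy as jnp

# Stand-in sizes for illustration: N tokens, m features, d value dims.
N, m, d = 64, 16, 8
kq, kk, kv_key = jax.random.split(jax.random.PRNGKey(0), 3)

# Any positive arrays work here; they play the role of phi(Q) and phi(K).
phi_q = jnp.exp(0.1 * jax.random.normal(kq, (N, m)))
phi_k = jnp.exp(0.1 * jax.random.normal(kk, (N, m)))
v = jax.random.normal(kv_key, (N, d))

# Ledger order: build the N x N table of pair weights, then normalize rows.
scores = phi_q @ phi_k.T                               # [N, N]
y_pairs = (scores @ v) / scores.sum(axis=-1, keepdims=True)

# State order: build the m x d feature state first, never touching N x N.
state = phi_k.T @ v                                    # [m, d]
z = phi_k.sum(axis=0)                                  # [m]
y_state = (phi_q @ state) / (phi_q @ z)[:, None]

# The two results should differ only by float32 rounding.
max_gap = jnp.max(jnp.abs(y_pairs - y_state))
print(max_gap)
```

The only difference between the two paths is which matrix product runs first; the largest intermediate is `[N, N]` in one and `[m, d]` in the other.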

Animated JAX rendering of causal linear attention: token stream, key features, query features, running state S, and readout y
The implementation path. Each token writes $\phi(k_i)v_i^\top$ into the running state $S_i$ and $\phi(k_i)$ into the normalizer $z_i$. The query reads $\phi(q_i)^\top S_i / \phi(q_i)^\top z_i$.

The Feature Map

For the Performer-style approximation, use positive random features:

$$\phi^+(x) = \frac{\exp(-\|x\|^2/2)}{\sqrt m} \left[ \exp(w_1^\top x), \ldots, \exp(w_m^\top x) \right], \qquad w_i\sim\mathcal N(0,I).$$

The positivity matters because these numbers become attention weights before normalization. A trig feature map can be unbiased and still put negative mass into a row. Positive features keep the denominator meaningful.
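
To make the positivity point concrete, here is a small sketch comparing the two estimators on random inputs. The trigonometric map follows the sin/cos construction described in the Performer paper; the function names, the sizes, and the 0.5 input scale are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

def positive_features(x, omega):
    """exp(w.x - |x|^2 / 2) / sqrt(m); every entry is strictly positive."""
    proj = x @ omega.T
    norm = 0.5 * jnp.sum(x * x, axis=-1, keepdims=True)
    return jnp.exp(proj - norm) / jnp.sqrt(omega.shape[0])

def trig_features(x, omega):
    """exp(|x|^2 / 2) / sqrt(m) * [sin(w.x), cos(w.x)]; entries can be negative."""
    proj = x @ omega.T
    scale = jnp.exp(0.5 * jnp.sum(x * x, axis=-1, keepdims=True))
    feats = jnp.concatenate([jnp.sin(proj), jnp.cos(proj)], axis=-1)
    return scale * feats / jnp.sqrt(omega.shape[0])

# Illustrative sizes: 64 tokens, 16 dims, only 4 random directions.
kq, kk, kw = jax.random.split(jax.random.PRNGKey(0), 3)
q = 0.5 * jax.random.normal(kq, (64, 16))
k = 0.5 * jax.random.normal(kk, (64, 16))
omega = jax.random.normal(kw, (4, 16))

# Both products estimate exp(q . k), which is always positive.
pos_scores = positive_features(q, omega) @ positive_features(k, omega).T
trig_scores = trig_features(q, omega) @ trig_features(k, omega).T

print("min positive-feature score:", pos_scores.min())
print("min trig-feature score:    ", trig_scores.min())
```

With so few features, some trig scores come out negative even though the quantity being estimated is positive, while the positive map cannot produce a negative score by construction.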

In JAX, the feature map is one einsum plus the norm correction:

import jax.numpy as jnp

def positive_features(x, omega):
    """x: [..., d_head], omega: [m, d_head] -> [..., m]."""
    proj = jnp.einsum("...d,md->...m", x, omega)
    norm = 0.5 * jnp.sum(x * x, axis=-1, keepdims=True)
    return jnp.exp(proj - norm) / jnp.sqrt(omega.shape[0])
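
One way to check the map is to compare the feature dot product against `exp(q . k)` directly. The sketch below repeats `positive_features` so it runs on its own; the test vectors and the feature count of 32768 are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

def positive_features(x, omega):
    """x: [..., d_head], omega: [m, d_head] -> [..., m]."""
    proj = jnp.einsum("...d,md->...m", x, omega)
    norm = 0.5 * jnp.sum(x * x, axis=-1, keepdims=True)
    return jnp.exp(proj - norm) / jnp.sqrt(omega.shape[0])

# Small-norm test vectors keep the Monte Carlo variance low.
d_head, m = 8, 32768
q = jnp.linspace(0.1, 0.3, d_head)
k = jnp.linspace(-0.3, 0.3, d_head)
omega = jax.random.normal(jax.random.PRNGKey(0), (m, d_head))

# Target kernel value and its random-feature estimate.
exact = jnp.exp(jnp.dot(q, k))
approx = jnp.dot(positive_features(q, omega), positive_features(k, omega))
rel_err = jnp.abs(approx - exact) / exact
print(exact, approx, rel_err)
```

The relative variance of a single feature term is `exp(|q + k|^2) - 1`, so the estimate gets noisier quickly as the inputs grow; that is why the vectors here are kept small.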

For a multi-head implementation, each head can own its own random matrix:

def head_features(x, omega):
    """x: [batch, tokens, heads, d_head], omega: [heads, m, d_head]."""
    proj = jnp.einsum("bthd,hmd->bthm", x, omega)
    norm = 0.5 * jnp.sum(x * x, axis=-1, keepdims=True)
    return jnp.exp(proj - norm) / jnp.sqrt(omega.shape[1])

The NNX Module

Flax NNX modules are ordinary Python objects that subclass nnx.Module; submodules like nnx.Linear are assigned as attributes in __init__. Randomness is handled with nnx.Rngs, which is passed at initialization to create parameters.

Here is a compact non-causal linear-attention layer. It is the direct translation of the algebra above.

from flax import nnx
import jax
import jax.numpy as jnp

class LinearAttention(nnx.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_features: int,
        *,
        rngs: nnx.Rngs,
    ):
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.n_features = n_features

        self.q = nnx.Linear(d_model, d_model, rngs=rngs)
        self.k = nnx.Linear(d_model, d_model, rngs=rngs)
        self.v = nnx.Linear(d_model, d_model, rngs=rngs)
        self.o = nnx.Linear(d_model, d_model, rngs=rngs)

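        # Random projection directions w_i ~ N(0, I), one matrix per head.
        # Stored as nnx.Param, so an optimizer built with wrt=nnx.Param will
        # train omega too; a non-Param nnx.Variable is one way to keep it fixed.
        omega = jax.random.normal(
            rngs.params(),
            (n_heads, n_features, self.d_head),
        )
        self.omega = nnx.Param(omega)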

    def split_heads(self, x):
        b, t, _ = x.shape
        return x.reshape(b, t, self.n_heads, self.d_head)

    def positive_features(self, x):
        proj = jnp.einsum("bthd,hmd->bthm", x, self.omega[...])
        norm = 0.5 * jnp.sum(x * x, axis=-1, keepdims=True)
        return jnp.exp(proj - norm) / jnp.sqrt(self.n_features)

    def __call__(self, x, eps=1e-6):
        q = self.split_heads(self.q(x))
        k = self.split_heads(self.k(x))
        v = self.split_heads(self.v(x))

        phi_q = self.positive_features(q)  # [b, t, h, m]
        phi_k = self.positive_features(k)  # [b, t, h, m]

        # Shared feature-state. This is the line where N x N disappears.
        kv = jnp.einsum("bthm,bthd->bhmd", phi_k, v)
        z = jnp.sum(phi_k, axis=1)

        num = jnp.einsum("bthm,bhmd->bthd", phi_q, kv)
        den = jnp.einsum("bthm,bhm->bth", phi_q, z)[..., None]
        y = num / (den + eps)

        return self.o(y.reshape(x.shape))

The important shape is kv: [batch, heads, m, d_head]. It does not contain a token-token axis. Exact attention’s middle object is [batch, heads, tokens, tokens]; linear attention’s middle object is [batch, heads, features, head_dim].
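
Those two shapes can be inspected without allocating anything by running `jax.eval_shape` on abstract inputs. The sizes below are illustrative assumptions (4096 tokens is an arbitrary choice), and the helper names are hypothetical.

```python
import math
import jax
import jax.numpy as jnp

b, t, h, m, d = 1, 4096, 8, 128, 64  # illustrative sizes

def spec(*shape):
    # Abstract array: shape and dtype only, no memory is allocated.
    return jax.ShapeDtypeStruct(shape, jnp.float32)

def exact_scores(q, k):
    # Middle object of exact attention: one score per token pair.
    return jnp.einsum("bthd,bshd->bhts", q, k)

def feature_state(phi_k, v):
    # Middle object of linear attention: one feature state per head.
    return jnp.einsum("bthm,bthd->bhmd", phi_k, v)

scores = jax.eval_shape(exact_scores, spec(b, t, h, d), spec(b, t, h, d))
state = jax.eval_shape(feature_state, spec(b, t, h, m), spec(b, t, h, d))

print(scores.shape, math.prod(scores.shape))  # element count grows with t * t
print(state.shape, math.prod(state.shape))    # element count independent of t
```

Doubling `t` quadruples the first element count and leaves the second unchanged.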

Causal Decoding

For autoregressive decoding, the recurrence is even more explicit. At step $i$:

$$S_i = S_{i-1} + \phi(k_i)v_i^\top, \qquad z_i = z_{i-1} + \phi(k_i), \qquad y_i = \frac{\phi(q_i)^\top S_i}{\phi(q_i)^\top z_i}.$$

For a whole training sequence, cumsum is the easiest way to write it, at the cost of materializing every prefix state as a [batch, tokens, heads, m, d] array:

def causal_linear_attention(phi_q, phi_k, v, eps=1e-6):
    """phi_*: [batch, tokens, heads, m], v: [batch, tokens, heads, d]."""
    kv = jnp.einsum("bthm,bthd->bthmd", phi_k, v)
    s = jnp.cumsum(kv, axis=1)
    z = jnp.cumsum(phi_k, axis=1)

    num = jnp.einsum("bthm,bthmd->bthd", phi_q, s)
    den = jnp.einsum("bthm,bthm->bth", phi_q, z)[..., None]
    return num / (den + eps)
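
As a sketch of an alternative to the cumsum form (not the method used in this post), `jax.lax.scan` can carry only the current S and z along the token axis instead of storing a prefix state per token. The name `causal_linear_attention_scan` is hypothetical, and the cumsum version is repeated under a different name purely as a reference for the comparison.

```python
import jax
import jax.numpy as jnp

def causal_linear_attention_scan(phi_q, phi_k, v, eps=1e-6):
    """phi_*: [batch, tokens, heads, m], v: [batch, tokens, heads, d]."""
    b, _, h, m = phi_k.shape
    d = v.shape[-1]

    def step(carry, inputs):
        s, z = carry                      # s: [b, h, m, d], z: [b, h, m]
        q_i, k_i, v_i = inputs            # one token per step
        s = s + jnp.einsum("bhm,bhd->bhmd", k_i, v_i)
        z = z + k_i
        num = jnp.einsum("bhm,bhmd->bhd", q_i, s)
        den = jnp.einsum("bhm,bhm->bh", q_i, z)[..., None]
        return (s, z), num / (den + eps)

    init = (jnp.zeros((b, h, m, d), v.dtype), jnp.zeros((b, h, m), v.dtype))
    # scan iterates over the leading axis, so move tokens to the front.
    xs = tuple(jnp.moveaxis(a, 1, 0) for a in (phi_q, phi_k, v))
    _, ys = jax.lax.scan(step, init, xs)
    return jnp.moveaxis(ys, 0, 1)         # back to [batch, tokens, heads, d]

def causal_linear_attention_cumsum(phi_q, phi_k, v, eps=1e-6):
    """Reference: the cumsum formulation, materializing all prefix states."""
    s = jnp.cumsum(jnp.einsum("bthm,bthd->bthmd", phi_k, v), axis=1)
    z = jnp.cumsum(phi_k, axis=1)
    num = jnp.einsum("bthm,bthmd->bthd", phi_q, s)
    den = jnp.einsum("bthm,bthm->bth", phi_q, z)[..., None]
    return num / (den + eps)

# Positive stand-ins for the feature arrays, small illustrative sizes.
keys = jax.random.split(jax.random.PRNGKey(0), 3)
phi_q = jnp.exp(0.1 * jax.random.normal(keys[0], (2, 16, 4, 8)))
phi_k = jnp.exp(0.1 * jax.random.normal(keys[1], (2, 16, 4, 8)))
v = jax.random.normal(keys[2], (2, 16, 4, 5))

y_scan = causal_linear_attention_scan(phi_q, phi_k, v)
y_cumsum = causal_linear_attention_cumsum(phi_q, phi_k, v)
print(jnp.max(jnp.abs(y_scan - y_cumsum)))
```

One caveat: under autodiff, scan may still save per-step residuals for the backward pass, so the memory saving is clearest in the forward pass unless rematerialization is also used.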

For one-token-at-a-time generation, you do not need the full prefix arrays:

def decode_step(phi_q_i, phi_k_i, v_i, state, eps=1e-6):
    """One recurrent decoding step.

    phi_q_i, phi_k_i: [batch, heads, m]
    v_i:              [batch, heads, d]
    state.s:          [batch, heads, m, d]
    state.z:          [batch, heads, m]
    """
    s = state.s + jnp.einsum("bhm,bhd->bhmd", phi_k_i, v_i)
    z = state.z + phi_k_i
    num = jnp.einsum("bhm,bhmd->bhd", phi_q_i, s)
    den = jnp.einsum("bhm,bhm->bh", phi_q_i, z)[..., None]
    return num / (den + eps), state.replace(s=s, z=z)

That is the recurrent-network view: the cache is no longer all past keys and values. It is a fixed feature state.

Training Loop

The layer is an nnx.Module, so it fits into the normal NNX training pattern. With current NNX, nnx.Optimizer is constructed with wrt=nnx.Param, and update receives both the model and the gradients:

import optax

model = LinearAttention(
    d_model=512,
    n_heads=8,
    n_features=128,
    rngs=nnx.Rngs(0),
)
optimizer = nnx.Optimizer(
    model,
    optax.adamw(3e-4),
    wrt=nnx.Param,
)

@nnx.jit
def train_step(model, optimizer, batch):
    x, target = batch

    def loss_fn(model):
        pred = model(x)
        return jnp.mean((pred - target) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss

The implementation choice to watch is not jit; it is the tensor you permit yourself to create. If your forward pass forms scores = q @ k.T, you are back in the ledger world. If it forms kv = phi_k.T @ v, you have crossed into the shared-state world.

Rendering The GIFs

Both GIFs in this post are generated with Python, JAX, and matplotlib:

python scripts/render_gravity_attention_gif.py
python scripts/render_linear_attention_pipeline_gif.py

The first renderer draws the gravity and attention comparison shown at the top of this post.

The second renderer follows the causal linear-attention recurrence token by token. It is deliberately small: the point is not a benchmark, but a visual audit of the shapes. If the GIF cannot show where the $N\times N$ object disappeared, the implementation is probably not explaining itself.

What This Leaves Out

This is a teaching implementation, not a production kernel. A production version would involve further engineering choices that are not covered here.

Those choices matter. But they are second-order to the core idea. Linear attention is not “attention but faster” by accident. It is attention after you choose a finite feature map and refuse to materialize the pairwise ledger.


References: Flax NNX Module API and randomness guide; Performer / FAVOR+ from Choromanski et al. (2021); linear-transformer recurrence from Katharopoulos et al. (2020).

Cite as

Bouhsine, T. (2026). Cheap Attention in JAX/Flax NNX. Records of the !mmortal Data Scientist. https://tahabouhsine.com/blog/linear-attention-jax-flax-nnx/

BibTeX
@misc{bouhsine2026linearattentionjaxflaxnnx,
  author       = {Bouhsine, Taha},
  title        = {Cheap Attention in JAX/Flax NNX},
  year         = {2026},
  month        = {jun},
  howpublished = {\url{https://tahabouhsine.com/blog/linear-attention-jax-flax-nnx/}},
  note         = {Blog post, Records of the !mmortal Data Scientist}
}

References

  1. Katharopoulos, A., Vyas, A., Pappas, N., Fleuret, F. (2020). Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. ICML 2020. arXiv:2006.16236
  2. Choromanski, K., et al. (2021). Rethinking Attention with Performers. ICLR 2021. arXiv:2009.14794