Organizing Randomness: Contrastive Learning in JAX


#ml#contrastive#jax#optax#embeddings#contrastive-learning#infonce#supcon#siglip#triplet-loss#implementation#representation-learning#self-supervised-learning

The runnable companion to Untangling the Moons. There I raced the losses; here I build them.

Six contrastive losses organizing the same random 2D points side by side: pair, triplet, InfoNCE, SupCon, SigLIP, cosine-to-zero
Where we’re headed: six contrastive losses, the same random start, the same step counter. Every loss below is the ~15 lines of JAX that produces one of these panels. (Full walkthrough in the race.)

In Untangling the Moons I argued that the whole contrastive-learning lineage disagrees about one geometric question — how far apart should different-class points be? — and let you watch eight losses answer it in an interactive playground. This post is the other half: the actual JAX code that moves the points, and the same chaos-to-order trajectory rendered as a real GIF you can regenerate yourself.

The setup is the same trick that made the visualisations legible. Skip the encoder entirely. Take 60 points in 2D, assign them random labels, and treat the positions themselves as the embeddings. Then run gradient descent on a contrastive loss and watch the points organize. Random labels on random positions means there is no spatial signal at the start — the loss has to impose every bit of geometry itself. That is the cleanest possible test of what a loss actually wants.

Every loss below is about fifteen lines of jax.numpy. The complete, runnable generator — data, losses, training loop, and the matplotlib renderer that made these GIFs — lives in scripts/jax-contrastive/ in the repo.

The shared scaffolding

Everything that isn’t the loss is identical across all six. Start with the data: uniform points, balanced labels, then a shuffle so position tells you nothing about class.

import numpy as np

def make_random(n=60, k=2, seed=7):
    """Uniform points in [-1.5, 1.5]^2 with balanced, shuffled labels.

    No spatial signal: labels are assigned i % k then shuffled, so the loss
    must impose all of the geometry itself.
    """
    rng = np.random.default_rng(seed)
    pts = rng.uniform(-1.5, 1.5, size=(n, 2)).astype(np.float32)
    labels = np.array([i % k for i in range(n)], dtype=np.int32)
    rng.shuffle(labels)
    return pts, labels

Four of the six losses (InfoNCE, SupCon, SigLIP and cosine→0) live on the unit circle: they compare directions, not positions, so cosine similarity is the natural metric. That means one helper, applied after every step:

import jax.numpy as jnp

def l2normalize(z):
    return z / (jnp.linalg.norm(z, axis=1, keepdims=True) + 1e-9)

The training loop is the payoff of writing each loss as a single differentiable scalar: the optimizer, the gradient, and the JIT never change. Swap the loss_fn and you have a different algorithm.

import jax
import optax

def make_step(loss_fn, lr, on_sphere, masks):
    opt = optax.sgd(lr)

    @jax.jit
    def step(z, opt_state, key, param):
        loss, grads = jax.value_and_grad(loss_fn)(z, None, key, param, m=masks)
        updates, opt_state = opt.update(grads, opt_state)
        z = optax.apply_updates(z, updates)
        if on_sphere:
            z = l2normalize(z)          # re-project onto the unit circle
        return z, opt_state, loss

    return step, opt

Two details are worth flagging. The pairwise label masks (same, eye, triu) depend only on the labels, never on the positions, so they are built once on the host and closed over as constants. That keeps the loss bodies loop-free and lets jax.jit compile each step as a single XLA computation with no host round-trips:

def make_masks(labels):
    same = labels[:, None] == labels[None, :]
    eye = np.eye(len(labels), dtype=bool)
    triu = np.triu(np.ones_like(same), k=1)   # each unordered pair once
    return {"same": jnp.asarray(same), "eye": jnp.asarray(eye),
            "triu": jnp.asarray(triu)}

And the key is threaded through every step with jax.random.split — only the two samplers (triplet and InfoNCE) actually use it, but passing it uniformly keeps the step signature identical for all losses. To measure progress we use the same nearest-centroid accuracy as the visualisations: assign each point to its closest class centroid, report the fraction that lands right. It is a cheap linear-separability proxy, computed in NumPy on a host snapshot every few steps along with a frame for the GIF.

Now the six losses. Each is a pure function of (z, labels, key, param), plus the keyword-only mask dict m, returning one scalar: the quantity whose gradient organizes the points.

1. Pair contrastive — Hadsell, Chopra & LeCun, 2006

The original. Pull same-class pairs together quadratically; push different-class pairs apart until they hit a margin, then go silent.

$$\mathcal{L}_{\text{pair}}(i,j) = \begin{cases} \|z_i - z_j\|^2 & y_i = y_j \\ \bigl[\,m - \|z_i - z_j\|\,\bigr]_+^2 & y_i \neq y_j \end{cases}$$
def loss_pair(z, labels, key, margin, *, m):
    n = z.shape[0]
    diff = z[:, None, :] - z[None, :, :]
    d2 = jnp.sum(diff ** 2, axis=-1)
    d = jnp.sqrt(d2 + 1e-9)
    upper = m["triu"] > 0
    pos = jnp.where(m["same"] & upper, d2, 0.0)
    neg_active = (~m["same"]) & upper & (d < margin)
    neg = jnp.where(neg_active, (margin - d) ** 2, 0.0)
    return (jnp.sum(pos) + jnp.sum(neg)) / n
Pair contrastive loss organizing random 2D points into two tight clusters a margin apart
Pair contrastive (margin 1.2). Each class collapses to a point; the two points sit a margin apart, and then the gradient vanishes. The margin knows when to stop — its virtue and its limitation.

This is Euclidean — no sphere. The whole thing is one masked sum over the pairwise distance matrix. Note the / n normalization: it matches the update scale the interactive visualisations used, and dividing by the point count rather than the pair count keeps the gradient strong enough to actually move 60 points in a couple hundred steps.
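
To make "then go silent" concrete, here is a small check that reuses loss_pair and make_masks verbatim on a toy set of six points (the two configurations are illustrative assumptions). With each class collapsed to a single point and the two points 2.0 apart, beyond the 1.2 margin, both the loss and its gradient are exactly zero. Move the classes to 0.6 apart and the hinge wakes up: 9 active pairs x (1.2 - 0.6)^2 / 6 = 0.54.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_masks(labels):
    same = labels[:, None] == labels[None, :]
    eye = np.eye(len(labels), dtype=bool)
    triu = np.triu(np.ones_like(same), k=1)
    return {"same": jnp.asarray(same), "eye": jnp.asarray(eye),
            "triu": jnp.asarray(triu)}


def loss_pair(z, labels, key, margin, *, m):
    n = z.shape[0]
    diff = z[:, None, :] - z[None, :, :]
    d2 = jnp.sum(diff ** 2, axis=-1)
    d = jnp.sqrt(d2 + 1e-9)
    upper = m["triu"] > 0
    pos = jnp.where(m["same"] & upper, d2, 0.0)
    neg_active = (~m["same"]) & upper & (d < margin)
    neg = jnp.where(neg_active, (margin - d) ** 2, 0.0)
    return (jnp.sum(pos) + jnp.sum(neg)) / n


labels = np.array([0, 0, 0, 1, 1, 1], dtype=np.int32)
masks = make_masks(labels)

# Each class collapsed to one point, the two points 2.0 apart (beyond the margin).
z_done = jnp.array([[-1.0, 0.0]] * 3 + [[1.0, 0.0]] * 3)
# Same collapse, but only 0.6 apart: every cross-class pair violates the margin.
z_close = jnp.array([[-0.3, 0.0]] * 3 + [[0.3, 0.0]] * 3)

loss_and_grad = jax.value_and_grad(loss_pair)
loss_done, g_done = loss_and_grad(z_done, None, None, 1.2, m=masks)
loss_close, g_close = loss_and_grad(z_close, None, None, 1.2, m=masks)
print(float(loss_done), float(jnp.abs(g_done).max()))
print(float(loss_close), float(jnp.abs(g_close).max()))
```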

2. Triplet — FaceNet, 2015

Make the comparison relative. For each anchor, sample one positive and one negative; require the negative to sit at least a margin farther than the positive. Only violating triplets contribute.

$$\mathcal{L}_{\text{trip}}(a,p,n) = \bigl[\,\|z_a - z_p\|^2 - \|z_a - z_n\|^2 + m\,\bigr]_+$$
def loss_triplet(z, labels, key, margin, *, m):
    kp, kn = jax.random.split(key)
    pos_mask = m["same"] & (~m["eye"])
    neg_mask = ~m["same"]
    pidx = jax.random.categorical(kp, jnp.where(pos_mask, 0.0, -1e9), axis=1)
    nidx = jax.random.categorical(kn, jnp.where(neg_mask, 0.0, -1e9), axis=1)
    d_ap = jnp.sum((z - z[pidx]) ** 2, axis=1)
    d_an = jnp.sum((z - z[nidx]) ** 2, axis=1)
    hinge = jnp.maximum(d_ap - d_an + margin, 0.0)   # the [.]_+ hinge
    violators = jnp.sum(hinge > 0)
    return jnp.sum(hinge) / jnp.maximum(violators, 1.0)
Triplet loss organizing random 2D points, with a wobbly trajectory from stochastic sampling
Triplet loss. The trajectory wobbles because each step samples a fresh anchor / positive / negative — the gradient is stochastic. Same destination as pair contrastive, noisier path.

The sampling is the interesting part of the JAX. There are no Python loops: jax.random.categorical over a masked logit matrix draws one positive per anchor in a single vectorized call, and a second call does the same for negatives, all inside jit. Setting disallowed entries to -1e9 makes them unreachable. This is the idiom for “sample from a per-row set” in JAX: no per-row loops, no host round-trips.
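
A standalone sketch of that idiom on eight toy labels (the labels and the PRNG seed are arbitrary choices): every row gets logit 0 on its allowed set and -1e9 elsewhere, so each anchor draws uniformly from its own positives, or its own negatives, and each drawn index can be checked against the mask.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Toy labels; every class has at least two members, so no row is empty.
labels = np.array([0, 1, 0, 1, 2, 2, 0, 1], dtype=np.int32)
same = jnp.asarray(labels[:, None] == labels[None, :])
eye = jnp.eye(len(labels), dtype=bool)

pos_mask = same & ~eye      # other points of my class
neg_mask = ~same            # points of any other class

kp, kn = jax.random.split(jax.random.PRNGKey(0))
# Allowed entries share logit 0 (uniform among them); the rest get -1e9.
pidx = jax.random.categorical(kp, jnp.where(pos_mask, 0.0, -1e9), axis=1)
nidx = jax.random.categorical(kn, jnp.where(neg_mask, 0.0, -1e9), axis=1)
print(pidx, nidx)           # one positive and one negative index per anchor
```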

3. InfoNCE / NT-Xent — van den Oord 2018, SimCLR 2020

Replace the hard margin with a softmax over cosine similarities. Every other point is a negative; the gradient never fully vanishes.

$$\mathcal{L}_{\text{InfoNCE}}(a) = -\log \frac{\exp(\text{sim}(z_a, z_p)/\tau)}{\sum_{k \neq a} \exp(\text{sim}(z_a, z_k)/\tau)}$$
def loss_infonce(z, labels, key, tau, *, m):
    sim = (z @ z.T) / tau
    logits = jnp.where(m["eye"], -1e9, sim)          # mask self
    log_z = jax.nn.logsumexp(logits, axis=1)          # over k != i
    pos_mask = m["same"] & (~m["eye"])
    pidx = jax.random.categorical(key, jnp.where(pos_mask, 0.0, -1e9), axis=1)
    pos_sim = jnp.take_along_axis(sim, pidx[:, None], axis=1)[:, 0]
    return jnp.mean(log_z - pos_sim)
InfoNCE organizing random 2D points into two arcs on the unit circle
InfoNCE on the unit circle. The two classes settle into opposing arcs and keep micro-adjusting — the softmax always has a little more gradient to give.

Because z is normalized, z @ z.T is the cosine matrix directly. The whole loss is jax.nn.logsumexp doing the numerically stable denominator and one sampled positive in the numerator. That softmax-over-similarities is exactly the Nadaraya–Watson kernel operator I unpacked in Attention is Explainable Because it is a Kernel — InfoNCE is that operator pointed at a different objective.
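
One way to convince yourself that the logsumexp form matches the formula is to compare it against a literal per-anchor transcription. The sketch below fixes the positives instead of sampling them so the two can be compared exactly; infonce_given_positives and infonce_naive are hypothetical helpers for this check, and tau = 0.5 and the 12 random points are arbitrary. Because the naive version computes each cosine explicitly, it also confirms that z @ z.T is the cosine matrix once the rows are normalized.

```python
import jax
import jax.numpy as jnp
import numpy as np


def infonce_given_positives(z, pidx, tau):
    # Same body as loss_infonce, but with the positives passed in.
    n = z.shape[0]
    eye = jnp.eye(n, dtype=bool)
    sim = (z @ z.T) / tau
    logits = jnp.where(eye, -1e9, sim)                 # mask self
    log_z = jax.nn.logsumexp(logits, axis=1)           # over k != a
    pos_sim = jnp.take_along_axis(sim, pidx[:, None], axis=1)[:, 0]
    return jnp.mean(log_z - pos_sim)


def infonce_naive(z, pidx, tau):
    # Direct transcription of the formula, one anchor at a time, in float64.
    z = np.asarray(z, dtype=np.float64)
    n = z.shape[0]
    total = 0.0
    for a in range(n):
        cos = [z[a] @ z[k] / (np.linalg.norm(z[a]) * np.linalg.norm(z[k]))
               for k in range(n)]
        num = np.exp(cos[pidx[a]] / tau)
        den = sum(np.exp(cos[k] / tau) for k in range(n) if k != a)
        total += -np.log(num / den)
    return total / n


rng = np.random.default_rng(0)
z = rng.normal(size=(12, 2)).astype(np.float32)
z = z / np.linalg.norm(z, axis=1, keepdims=True)       # onto the unit circle
pidx = np.array([(a + 1) % 12 for a in range(12)])     # any fixed k != a
tau = 0.5

vec = float(infonce_given_positives(jnp.asarray(z), jnp.asarray(pidx), tau))
ref = infonce_naive(z, pidx, tau)
print(vec, ref)
```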

4. SupCon — Khosla et al., 2020

When you have labels, every same-class point is a positive. Average the InfoNCE term over all of them — no sampling at all.

$$\mathcal{L}_{\text{SupCon}}(a) = \frac{-1}{|P(a)|}\sum_{p \in P(a)} \log \frac{\exp(\text{sim}(z_a, z_p)/\tau)}{\sum_{k \neq a} \exp(\text{sim}(z_a, z_k)/\tau)}$$
def loss_supcon(z, labels, key, tau, *, m):
    sim = (z @ z.T) / tau
    logits = jnp.where(m["eye"], -1e9, sim)
    log_p = logits - jax.nn.logsumexp(logits, axis=1, keepdims=True)
    pos = (m["same"] & (~m["eye"])).astype(z.dtype)
    p_count = jnp.sum(pos, axis=1)
    per_anchor = -jnp.sum(pos * log_p, axis=1) / jnp.maximum(p_count, 1.0)
    return jnp.mean(per_anchor)
SupCon organizing four classes of random 2D points into four tight clusters on the unit circle
SupCon on four classes. Averaging over every positive pulls each point toward its class centroid — the classes condense to near-Dirac clusters. The harshest collapser of the six.

The shape from the visualisations is visible immediately: SupCon is the tightest of all the losses. Averaging over many positives is, in effect, pulling each point toward its class centroid every step. Great for downstream classification; brutal to any within-class variation you might have cared about.
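
The collapse has a clean explanation. For each anchor, the |P(a)| positive softmax probabilities share at most one unit of mass, so by Jensen's inequality the per-anchor term is at least log|P(a)|, and it gets close to that floor only when the positives soak up nearly all the mass, which on the circle in practice means each class sitting on a single direction. A sketch that checks this numerically, reusing loss_supcon and make_masks verbatim; the four axis directions, 15 points per class and tau = 0.1 are illustrative assumptions:

```python
import math

import jax
import jax.numpy as jnp
import numpy as np


def make_masks(labels):
    same = labels[:, None] == labels[None, :]
    eye = np.eye(len(labels), dtype=bool)
    triu = np.triu(np.ones_like(same), k=1)
    return {"same": jnp.asarray(same), "eye": jnp.asarray(eye),
            "triu": jnp.asarray(triu)}


def loss_supcon(z, labels, key, tau, *, m):
    sim = (z @ z.T) / tau
    logits = jnp.where(m["eye"], -1e9, sim)
    log_p = logits - jax.nn.logsumexp(logits, axis=1, keepdims=True)
    pos = (m["same"] & (~m["eye"])).astype(z.dtype)
    p_count = jnp.sum(pos, axis=1)
    per_anchor = -jnp.sum(pos * log_p, axis=1) / jnp.maximum(p_count, 1.0)
    return jnp.mean(per_anchor)


k, per_class, tau = 4, 15, 0.1
labels = np.repeat(np.arange(k), per_class).astype(np.int32)   # 60 points
masks = make_masks(labels)

# Every class collapsed onto one of the four axis directions.
dirs = np.array([[1, 0], [0, 1], [-1, 0], [0, -1]], dtype=np.float32)
z_collapsed = jnp.asarray(dirs[labels])

# The same labels on random unit vectors, for comparison.
rng = np.random.default_rng(0)
z_rand = rng.normal(size=(k * per_class, 2)).astype(np.float32)
z_rand = jnp.asarray(z_rand / np.linalg.norm(z_rand, axis=1, keepdims=True))

floor = math.log(per_class - 1)        # log |P(a)| with 14 positives per anchor
loss_collapsed = float(loss_supcon(z_collapsed, None, None, tau, m=masks))
loss_random = float(loss_supcon(z_rand, None, None, tau, m=masks))
print(f"floor {floor:.4f}  collapsed {loss_collapsed:.4f}  random {loss_random:.4f}")
```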

5. SigLIP — Zhai et al., 2023

Drop the softmax entirely. Score each pair independently with a sigmoid, and — crucially — put a bias on it so you choose where different-class pairs stop being pushed apart.

$$\mathcal{L}_{\text{SigLIP}} = \sum_{i<j} \log\!\bigl(1 + \exp(-y_{ij}\,(t \cdot \text{sim}_{ij} + b))\bigr)$$
def loss_siglip(z, labels, key, target, *, m):
    n = z.shape[0]
    t, b = 10.0, -10.0 * target
    sim = z @ z.T                                    # cosine; z on the sphere
    sign = jnp.where(m["same"], 1.0, -1.0)           # y_ij = +1 same / -1 diff
    per_pair = jax.nn.softplus(-sign * (t * sim + b))
    return jnp.sum(jnp.where(m["triu"] > 0, per_pair, 0.0)) / n
SigLIP organizing random 2D points into two arcs that stop near orthogonality
SigLIP with the negative-pair target near cosine 0. The pairwise sigmoid saturates once each negative is past the bias threshold, so the classes stop spreading at orthogonality instead of overshooting to opposition.

jax.nn.softplus(x) is the numerically stable log(1 + e^x), and target sets the bias b = -t·target, which puts the zero crossing of a negative pair's logit at cosine = target. Below that cosine the sigmoid saturates and the loss for the pair flattens out, so the push fades. Set it near zero and negatives equilibrate at orthogonality, not at the diametric opposition every softmax loss implicitly chases. That is the geometric fix I argued for in Opposite Is Not Different.
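
To see where the push switches off, differentiate a single negative-pair term with respect to its cosine s. With y_ij = -1 the term is softplus(t·(s - target)), whose derivative is t·sigmoid(t·(s - target)): exactly t/2 at the target, decaying exponentially as the cosine drops below it. A minimal sketch, where neg_pair_loss is a hypothetical one-pair helper using the same t = 10 as loss_siglip and target = 0:

```python
import jax
import jax.numpy as jnp

t, target = 10.0, 0.0       # same t as loss_siglip; target near cosine 0
b = -t * target


def neg_pair_loss(s):
    # Negative pair (y_ij = -1): softplus(t * s + b) = softplus(t * (s - target)).
    return jax.nn.softplus(t * s + b)


grad_fn = jax.grad(neg_pair_loss)
for s in (0.5, 0.0, -0.3, -0.5):
    print(f"cos = {s:+.1f}   loss = {float(neg_pair_loss(s)):.4f}   "
          f"dL/dcos = {float(grad_fn(s)):.4f}")
```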

6. Cosine→0 — the orthogonality objective

The simplest objective consistent with that argument. Pull same-class pairs to cosine 1; push different-class pairs to cosine 0 — orthogonality, not opposition.

$$\mathcal{L}_{\text{cos→0}}(i,j) = \begin{cases} 1 - \cos(z_i, z_j) & y_i = y_j \\ \cos(z_i, z_j)^2 & y_i \neq y_j \end{cases}$$
def loss_orthog(z, labels, key, _unused, *, m):
    n = z.shape[0]
    c = z @ z.T
    per_pair = jnp.where(m["same"], 1.0 - c, c ** 2)
    return jnp.sum(jnp.where(m["triu"] > 0, per_pair, 0.0)) / n
Cosine-to-zero loss organizing two classes of random points onto perpendicular axes
Cosine→0 on two classes. Same-class collapses to a point, different-class settles at a right angle — orthogonality, the dimension-independent stop this loss always targets. (The strict max-margin arrangement for two classes is actually antipodal; cos→0 stops at orthogonality on purpose, since that is the target that generalizes to many classes.)

Three lines of math, and the cleanest geometry of the six: two classes land on perpendicular directions and stop, because the loss is genuinely minimized there rather than fighting the spherical geometry forever. The target is a dimension-independent approximation to the simplex optimum -1/(k-1): nearly exact when classes are many, deliberately conservative when they are few. The companion post works through why that is the right trade, and why it only holds when the class count doesn’t exceed the dimension.
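
The simplex number is easy to verify. Centring the k one-hot vectors and normalizing them gives k equiangular unit vectors: each pair has dot product -1/k and each squared norm is (k-1)/k, so every pairwise cosine is -1/(k-1). The sketch below (simplex_directions is a hypothetical helper) shows how fast that optimum closes in on the cos→0 target: -1 for k = 2, -0.5 for k = 3, about -0.11 for k = 10 and about -0.01 for k = 100.

```python
import numpy as np


def simplex_directions(k):
    # Centre the k one-hot vectors, then normalize: k unit vectors in R^k
    # (spanning a (k-1)-dim subspace) with equal pairwise angles.
    v = np.eye(k) - 1.0 / k
    return v / np.linalg.norm(v, axis=1, keepdims=True)


for k in (2, 3, 4, 10, 100):
    u = simplex_directions(k)
    c = u @ u.T
    off_diag = c[~np.eye(k, dtype=bool)]
    print(k, off_diag.min(), off_diag.max(), -1.0 / (k - 1))
```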

One honest caveat the code made me confront. From a perfectly balanced random start, this loss can stall at a symmetric configuration — each class smeared uniformly around the circle, where same-class attraction and different-class repulsion cancel. The softmax losses break that symmetry through their sampled positive or their normalization; bare orthogonality has nothing to break it with. The fix in the generator is mundane — a slightly smaller point count and a seed that isn’t perfectly symmetric — but it is a real lesson: an objective can be correct at its minimum and still need help escaping a saddle.
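
The stall is easy to reproduce in isolation. Place 60 points evenly around the circle with the two classes interleaved, so each class is itself 30 evenly spaced points. Each class then sums to zero and has scatter matrix 15·I, so for point a the same-class pull contributes z_a and the different-class push contributes 30·z_a; after the 1/n the gradient is (31/60)·z_a, purely radial, and the re-projection onto the circle erases it. A sketch reusing loss_orthog and make_masks verbatim; the evenly spaced, interleaved configuration is an illustrative assumption:

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_masks(labels):
    same = labels[:, None] == labels[None, :]
    eye = np.eye(len(labels), dtype=bool)
    triu = np.triu(np.ones_like(same), k=1)
    return {"same": jnp.asarray(same), "eye": jnp.asarray(eye),
            "triu": jnp.asarray(triu)}


def loss_orthog(z, labels, key, _unused, *, m):
    n = z.shape[0]
    c = z @ z.T
    per_pair = jnp.where(m["same"], 1.0 - c, c ** 2)
    return jnp.sum(jnp.where(m["triu"] > 0, per_pair, 0.0)) / n


n = 60
angles = 2 * np.pi * np.arange(n) / n
z_sym = jnp.asarray(np.stack([np.cos(angles), np.sin(angles)], axis=1).astype(np.float32))
labels = (np.arange(n) % 2).astype(np.int32)   # classes interleaved around the circle
masks = make_masks(labels)

g = jax.grad(loss_orthog)(z_sym, None, None, None, m=masks)
# Only the tangential part of the gradient survives re-projection onto the circle.
radial = jnp.sum(g * z_sym, axis=1, keepdims=True) * z_sym
tangential = g - radial
print(float(jnp.abs(g).max()), float(jnp.abs(tangential).max()))
```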

The two we skipped

For brevity this guide builds six of the eight losses from the prior post. CLIP is symmetrized InfoNCE — run loss_infonce in both directions (anchor→positive and positive→anchor) and average; the geometry is identical. Alignment + Uniformity (Wang & Isola, 2020) splits InfoNCE into an explicit pull-positives-together term and a spread-everyone-out term; both are a few more lines in the same style. Both are explored in Untangling the Moons.
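
For completeness, here is one way Alignment + Uniformity could look in the same style. It follows Wang & Isola's defaults (squared distance for alignment, a Gaussian potential with t = 2 over distinct pairs for uniformity) but treats every same-class pair as a positive pair, which is an assumption for this labelled setting; loss_align_uniform and its weight lam are hypothetical names rather than the generator's code. The signature matches the other losses, so the same make_step could drive it on the sphere, with t passed as param.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_masks(labels):
    same = labels[:, None] == labels[None, :]
    eye = np.eye(len(labels), dtype=bool)
    triu = np.triu(np.ones_like(same), k=1)
    return {"same": jnp.asarray(same), "eye": jnp.asarray(eye),
            "triu": jnp.asarray(triu)}


def loss_align_uniform(z, labels, key, t, *, m, lam=1.0):
    d2 = jnp.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    upper = m["triu"] > 0
    pos = m["same"] & upper
    # Alignment: mean squared distance over same-class pairs (alpha = 2).
    align = jnp.sum(jnp.where(pos, d2, 0.0)) / jnp.maximum(jnp.sum(pos), 1)
    # Uniformity: log of the mean Gaussian potential over all distinct pairs.
    potential = jnp.where(upper, jnp.exp(-t * d2), 0.0)
    uniform = jnp.log(jnp.sum(potential) / jnp.sum(upper))
    return align + lam * uniform


labels = np.array([0, 0, 1, 1], dtype=np.int32)
masks = make_masks(labels)
z_same = jnp.array([[1.0, 0.0]] * 4)                      # everything on one point
z_anti = jnp.array([[1.0, 0.0]] * 2 + [[-1.0, 0.0]] * 2)  # classes antipodal
l_same = float(loss_align_uniform(z_same, None, None, 2.0, m=masks))
l_anti = float(loss_align_uniform(z_anti, None, None, 2.0, m=masks))
print(l_same, l_anti)
```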

The race

Same losses, same random initialization, same step counter — the six panels side by side, which is where the disagreements become obvious.

Six contrastive losses organizing the same random points side by side: pair, triplet, InfoNCE, SupCon, SigLIP, cosine-to-zero
All six on the same random start. Pair and triplet (Euclidean) collapse to points and freeze; InfoNCE and SigLIP fan into arcs; SupCon condenses four classes to four clusters; cosine→0 lands two classes on perpendicular axes. The margin family stops early, the softmax family keeps fine-tuning, and SigLIP and cosine→0 stop at the geometrically correct place.

The pattern is the one from the interactive version, now in committed, reproducible pixels: the margin family freezes once satisfied, the softmax family never stops, and only SigLIP and cosine→0 halt where the geometry actually wants them to.

Run it yourself

Everything here is deterministic given the seed. The generator lives in scripts/jax-contrastive/ as a handful of small files, including data.py, losses.py, train.py, render.py and generate.py:

pip install -r requirements.txt        # jax, optax, matplotlib, imageio
python generate.py --loss all --grid

Every GIF on this page lands in public/jax-contrastive/. Change --loss, --dataset (random, random-4, moons), --seed, or --steps and watch a different trajectory. The optimization is reproducible to the seed; the rendered pixels may shift slightly across matplotlib versions, so the guarantee is on the trajectory, not the bytes.

If you want the why behind these six geometries — twenty years of why different-class points kept getting pushed too far apart — that is Untangling the Moons. This was the how.

Cite as

Bouhsine, T. (2026). Organizing Randomness: Contrastive Learning in JAX. Records of the !mmortal Data Scientist. https://tahabouhsine.com/blog/organizing-randomness-jax/

BibTeX
@misc{bouhsine2026organizingrandomnessjax,
  author       = {Bouhsine, Taha},
  title        = {Organizing Randomness: Contrastive Learning in JAX},
  year         = {2026},
  month        = {may},
  howpublished = {\url{https://tahabouhsine.com/blog/organizing-randomness-jax/}},
  note         = {Blog post, Records of the !mmortal Data Scientist}
}