
4. How SSL Works: DINO + SparK (And Dealing with Plateauing)

Stop Blindly Applying ImageNet SSL to Medical Data. What DINO and SparK Taught Me.

Part 4 of 6 in the SCARCE-CXR series

4.1 DINO: Self-Distillation and Architectural Guardrails

Contrastive learning works by comparison: pull two views of the same image together, push views of different images apart. That second part needs different images in the batch to push against, which is why batch size and memory queues are so critical. Self-distillation drops the negatives entirely. A "student" network does not compare against other images at all. Instead, it learns to match the output of a "teacher" network for the exact same image.

The problem is that without negatives pushing embeddings apart, the network finds the laziest solution: output the same constant vector for every image and drive the loss to zero. Early methods like BYOL dealt with this through architectural guardrails: an asymmetric predictor head on the "student" side only, plus a stop-gradient on the "teacher". These worked empirically; remove either one and training collapses. But nobody fully understood why the asymmetry prevented collapse. We needed a mathematically transparent way to prevent collapse, not one that just "worked" somehow.

DINO replaces those tricks with a mechanism one can actually reason about: a centring buffer (more on that next). If it fails, you know exactly what to look at. On ImageNet with ViT backbones, it worked well enough to produce attention maps that trace object boundaries without any human supervision. That result convinced me it should generalize anywhere.

4.2 DINO: Model Internals

Compared to MoCo v2, the infrastructure needed drops significantly. MoCo v2 holds a 65,536-entry queue in GPU memory, needs large batches to keep those negatives diverse, and computes a loss that pits every query against all 65,536 stored embeddings. DINO has none of that:

|  | MoCo v2 | DINO |
|---|---|---|
| Negatives | 65,536-entry queue of past embeddings | None |
| Batch-size sensitivity | High; queue needs diverse images | Low |
| Loss | compares query against all 65,536 queue entries | "teacher"/"student" crop pairs from current batch |
| Collapse prevention | negatives push embeddings apart | centring buffer + temperature asymmetry |
| Extra GPU memory | 65,536 × 128-dim queue | K-dim centring vector |

The architecture is simple. Both the "student" and "teacher" share the same structure: a backbone followed by a DINOHead projection head outputting out_dim=65,536 logits. The "teacher" never receives gradients; its weights slowly track the "student" via EMA.

# From ssl_methods/dino/model.py

# Both share the same architecture: backbone → DINOHead (3-layer MLP → L2-norm → 65536-dim linear)
self.student = nn.Sequential(backbone_s, DINOHead(feature_dim, out_dim=65536))
self.teacher = nn.Sequential(backbone_t, DINOHead(feature_dim, out_dim=65536))

# 1. EMA "teacher": weights drift toward "student" via EMA
for p in self.teacher.parameters():
    p.requires_grad = False # "teacher" never updated by gradient descent

@torch.no_grad()
def update_teacher(self, momentum: float):
    for p_s, p_t in zip(self.student.parameters(), self.teacher.parameters()):
        p_t.data = p_t.data * momentum + p_s.data * (1.0 - momentum) # at momentum=0.996, only 0.4% per step

# Forward: "student" sees all crops, "teacher" only sees global crops
student_out = [self.student(v) for v in views]
with torch.no_grad():
    teacher_out = [self.teacher(views[i]) for i in range(n_global)]
DINO multi-crop augmented views: original X-ray beside two global crops the teacher and student each see
DINO's multi-crop augmentation: two large global views (Global 1, Global 2) at 224px and six smaller local crops (Local 1-6) at 96px. The "student" sees all eight; the "teacher" only sees the two global views.
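To make the crop bookkeeping concrete, here is a minimal sketch in plain Python that counts which "teacher"/"student" view pairs contribute to the loss. The names `N_GLOBAL`, `N_LOCAL`, and `loss_pairs` are illustrative, not taken from the project code, and the sketch assumes the "student"'s crop list starts with the same two global views the "teacher" sees, so "same view" simply means equal index.

```python
N_GLOBAL = 2  # 224px global crops (seen by both teacher and student)
N_LOCAL = 6   # 96px local crops (seen by the student only)

def loss_pairs(n_global: int, n_local: int) -> list[tuple[int, int]]:
    """Return every (teacher_view, student_view) index pair that enters the loss."""
    n_views = n_global + n_local
    return [
        (t, s)
        for t in range(n_global)  # teacher: global crops only
        for s in range(n_views)   # student: all crops
        if t != s                 # skip pairs where both networks saw the same view
    ]

pairs = loss_pairs(N_GLOBAL, N_LOCAL)
print(len(pairs))  # 2 * 8 - 2 = 14 cross-view terms per image
```

Every one of those terms asks the "student" to predict the "teacher"'s output from a different crop of the same X-ray, which is where the learning signal is supposed to come from.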

DINO's loss is a cross-entropy between two probability distributions. The "teacher" passes each global crop through a projection head to get K=65,536 logits, applies centring and a low temperature to produce a sharp, peaked distribution, then detaches from the graph. The "student" processes all crops and must match those distributions.

# From ssl_methods/dino/loss.py

# 2. Temperature asymmetry: teacher_temp=0.04
# 3. Centring buffer: subtract center
teacher_probs = [
    F.softmax((t - self.center) / self.teacher_temp, dim=-1).detach()
    for t in teacher_output # peaked distribution, no gradients
]
# 2. Temperature asymmetry: student_temp=0.1 for softer distribution
student_log_probs = [
    F.log_softmax(s / self.student_temp, dim=-1)
    for s in student_output # all crops; higher temp keeps distribution softer
]

# Cross-entropy over all (teacher_i, student_j) pairs, skip same-view
total_loss -= torch.mean(torch.sum(t_prob * s_lp, dim=-1))

# 3. Centring buffer: update running mean
batch_center = torch.cat(teacher_output).mean(dim=0, keepdim=True)
self.center = self.center * self.center_momentum + batch_center * (1.0 - self.center_momentum)

Three structural choices keep the system from collapsing:

1. EMA "teacher": the "teacher" never receives gradients. Its weights drift toward the "student" each step by exponential moving average. A fast-updating "teacher" would produce an inconsistent training signal. At momentum=0.996 it moves only 0.4% per step, providing a stable target.

2. Temperature asymmetry: teacher_temp=0.04 makes the "teacher"'s distribution sharp and peaked. student_temp=0.1 keeps the "student" softer. Matching a peaked target forces the "student" to concentrate probability mass on the same prototypes, not approximate a uniform distribution.

3. Centring buffer: a running mean is subtracted from the "teacher" logits before softmax. If all mass concentrates on one output dimension, the running mean tracks and subtracts it, flattening that bias out and preventing single-prototype collapse.
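The first two guardrails can be checked with a few lines of arithmetic. The sketch below is plain Python, independent of the project code; the toy logits are made up for illustration, and the half-life figure assumes a "student" that stands still, which real training does not do.

```python
import math

def softmax(logits, temperature):
    """Temperature-scaled softmax over a list of floats."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# EMA: after n steps, the remaining student-teacher gap is momentum ** n.
momentum = 0.996
half_life = math.log(0.5) / math.log(momentum)
print(f"steps for the teacher to close half the gap: {half_life:.0f}")

# Temperature asymmetry: same toy logits, two temperatures.
logits = [0.10, 0.05, 0.00, -0.05]
teacher_probs = softmax(logits, temperature=0.04)
student_probs = softmax(logits, temperature=0.1)
print("teacher max prob:", round(max(teacher_probs), 3))
print("student max prob:", round(max(student_probs), 3))
```

The lower temperature turns small logit differences into a visibly more peaked distribution, which is the sharpening effect the "teacher" relies on.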

4.3 DINO: Plateauing Loss

The centring buffer assumes that different images produce meaningfully different "teacher" outputs. That assumption holds on ImageNet. It does not hold here.

On ImageNet, a dog produces a wildly different "teacher" output than a car. But chest X-rays share the same modality, gross anatomy, and grayscale range, so the "teacher" assigns near-identical probability distributions to everything. The running centre therefore converges to that shared output, t - self.center approaches zero for every image, and softmax(0 / 0.04) returns a uniform distribution. The "student" is asked to match a flat line, which carries no useful gradient signal. We confirmed this by running collapse_monitor.py on the DINO checkpoints: mean pairwise cosine similarity climbed steadily while effective rank fell, both signatures of collapse.
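A toy version of this failure mode, in plain Python. It is a sketch under simplifying assumptions: the centre is taken as the mean of the current batch instead of a running average, K is 4 instead of 65,536, and the logit values are invented for illustration.

```python
import math

def softmax(logits, temperature):
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0)

def centred_teacher(batch_logits, temperature=0.04):
    """Subtract the batch-mean logit vector, then apply the teacher softmax."""
    k = len(batch_logits[0])
    centre = [sum(row[i] for row in batch_logits) / len(batch_logits) for i in range(k)]
    return [softmax([x - c for x, c in zip(row, centre)], temperature) for row in batch_logits]

# Diverse batch: each image prefers a different output dimension.
diverse = [[4.0, 0.0, 0.0, 0.0], [0.0, 4.0, 0.0, 0.0],
           [0.0, 0.0, 4.0, 0.0], [0.0, 0.0, 0.0, 4.0]]
# Homogeneous batch: every image gets almost the same logits.
homogeneous = [[4.0, 1.0, 0.5, 0.0], [4.0, 1.0, 0.5, 0.001],
               [4.0, 1.001, 0.5, 0.0], [4.001, 1.0, 0.5, 0.0]]

uniform_entropy = math.log(4)  # entropy of a flat distribution over K=4
for name, batch in [("diverse", diverse), ("homogeneous", homogeneous)]:
    mean_h = sum(entropy(p) for p in centred_teacher(batch)) / len(batch)
    print(f"{name}: mean target entropy {mean_h:.3f} (uniform = {uniform_entropy:.3f})")
```

In the diverse batch, centring leaves each target sharply peaked. In the homogeneous batch, centring removes almost everything the images have in common, and what is left is close to a flat target even at a temperature of 0.04.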

DINO collapse monitor: mean_cos rising, eff_rank falling

Our loss flatlined at 5.42 and never moved beyond that. I tried the standard fixes: reducing the output dimension, raising the centre momentum to 0.99, adjusting the "teacher" temperature warmup, and disabling local crops. None of them moved the loss.

DINO training loss: flatlined from epoch 0
The rise is expected: teacher_temp warmup sharpens the target distribution, raising cross-entropy mechanically. The flat ceiling is the collapse: chest X-ray homogeneity drives the centring buffer to output uniform distributions, giving the "student" nothing to learn from.

Takeaway: DINO's self-distillation requires visual diversity that chest X-rays simply do not have.

4.4 SparK: Generative Models and Convolution Leakage

Contrastive learning has a blind spot when applied to X-rays: augmentation invariance. The model is rewarded for ignoring differences between two crops of the same image, but a fibrotic band or a tiny nodule is so subtle that two crops of the same X-ray look nearly identical with or without it. Contrastive learning has little incentive to preserve a microscopic feature.

In 2018, BERT transformed NLP by masking random words in a sentence and training a model to predict them from context. That idea became the dominant pretraining recipe in language, and MAE applied the same logic to vision: mask out 75% of image patches, then reconstruct them. It worked extremely well. But ViTs treat images as sequences of discrete patch tokens, so masking is clean and no information leaks across boundaries. CNNs were left out of this entire line of progress. SparK changed that by essentially designing BERT-style pretraining for convolutional networks.

Lineage: BERT to SparK
flowchart LR
    bert["BERT (2018)\nNLP\n─────────────\nmask random words\npredict from context"]
    mae["MAE (2021)\nVision Transformer\n─────────────\nmask patches\nreconstruct image"]
    spark["SparK (2023)\nConvolutional Network\n─────────────\nmask pixel space\nreconstruct via U-Net"]

    bert --> mae
    mae --> spark
                    

A ViT was not an option at our scale. ViTs have far weaker spatial inductive bias than CNNs and typically need very large datasets (on the order of a million images or more) to learn useful patch relationships. Given our smaller dataset, where spatial relations are essential, a ViT would train poorly. A ResNet was the only viable backbone. But simply bolting a ResNet onto an MAE introduces two problems:

1. Convolutional leakage: standard convolutions slide their kernels across the image and leak visible pixel data into masked regions, allowing the model to cheat.

2. Spatial resolution loss: ResNets compress a 224x224 image down to a 7x7 grid, permanently destroying the high-frequency detail needed to reconstruct smaller features.

SparK is the architectural fix for both.
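The resolution loss is easy to quantify. The short sketch below assumes the standard ResNet50 downsampling pattern (a 4x reduction in the stem, then a further 2x in each of the last three stages); `STAGE_STRIDES` and `sizes` are illustrative names.

```python
INPUT_SIZE = 224

# Cumulative downsampling factor at the output of each residual stage.
# Assumption: standard ResNet50 (stride-2 conv + stride-2 max-pool in the stem,
# then stride 2 at the start of stages 2, 3, and 4).
STAGE_STRIDES = {"f1": 4, "f2": 8, "f3": 16, "f4": 32}

sizes = {name: INPUT_SIZE // stride for name, stride in STAGE_STRIDES.items()}

for name, side in sizes.items():
    stride = STAGE_STRIDES[name]
    print(f"{name}: {side}x{side} grid, one cell per {stride}x{stride} block of input pixels")
```

At the bottleneck, one cell stands in for a 32x32 block of pixels, so anything finer than that has to be recovered from the higher-resolution stages.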

4.5 SparK: Model Internals

SparK asks the question: given some patches of an image, can the network reconstruct the rest? The whole architecture follows from that question.

The forward pass has four stages.

1. Masking. Before the first convolution, the 224x224 input is divided into non-overlapping 32x32 patches. 60% are randomly selected and zeroed out in raw pixel space. This is the key architectural difference from MAE: a ViT only processes tokens for visible patches, so masked content never enters the computation at all. A CNN has no such luxury: convolutional kernels slide over the entire image regardless. Masking in pixel space before the first convolution is necessary, but it does not prevent leakage on its own. The paper's solution is sparse convolutions that physically skip masked positions entirely.

Original chest X-ray beside the same image with 60% of 32x32 patches zeroed out

2. Encoding. The masked image passes through the standard ResNet50 stem and four residual stages, producing hierarchical feature maps at progressively coarser spatial resolutions: f1 at 56x56, f2 at 28x28, f3 at 14x14, and f4 at 7x7 (the bottleneck).

3. Decoding. A lightweight U-Net decoder upsamples from the 7x7 bottleneck back to 224x224 using skip connections from f1, f2, and f3. Without those skips, the decoder would have to hallucinate sharp ribs and vasculature from a 7x7 grid (geometrically impossible).

SparK: ResNet50 encoder + U-Net decoder
flowchart TB
    inp["input 224x224 (masked)"]
    inp --> f1["f1: 56x56, 256ch"]
    f1  --> f2["f2: 28x28, 512ch"]
    f2  --> f3["f3: 14x14, 1024ch"]
    f3  --> f4["f4: 7x7, 2048ch (bottleneck)"]
    f4  --> d3["upsample to 14x14 + f3 skip"]
    d3  --> d2["upsample to 28x28 + f2 skip"]
    d2  --> d1["upsample to 56x56 + f1 skip"]
    d1  --> out["reconstruction 224x224"]
    f3 -.->|skip| d3
    f2 -.->|skip| d2
    f1 -.->|skip| d1
                    

4. Loss. The target is the original unmasked image, with each 32x32 patch normalised to zero mean and unit variance before computing the loss. MSE is evaluated on masked pixels only:

# From ssl_methods/spark/model.py

# Normalise each 32x32 patch to zero mean, unit variance
# Guessing the average grey value now yields 0 on mean but still fails on variance
# To minimise MSE, the network must reconstruct the actual high-frequency edges
target = self._patchwise_normalize(x)

# MSE on masked regions only (same idea as BERT: loss only on [MASK] tokens)
# visible pixels are in the input; including them lets the decoder copy for free
loss = ((pred - target) ** 2 * mask_px).sum() / (mask_px.sum() * x.shape[1])

4.6 SparK: When Blind Implementations Don't Work

Deploying the original theory within our compute constraints (a single 22GB GPU) required three specific implementation fixes.

Fix 1: Input-level masking instead of per-stage sparse simulation. The original SparK simulates sparse convolution using standard PyTorch operators: after each residual stage, it re-applies a downsampled version of the mask to the feature map, zeroing masked positions before the next stage sees them. This prevents masked-region values from accumulating across layers. This implementation takes the simpler route: zero masked patches once at input and run standard Conv2d through all four stages. The leakage cost is real but bounded, and on 112k images rather than ImageNet's 1.28M the easier pretext task is the right call.

# From ssl_methods/spark/model.py

def _random_mask(self, imgs):
    B, C, H, W = imgs.shape
    p = self.patch_size # 32
    nh, nw = H // p, W // p # 7×7 = 49 patches total
    n_mask = int(nh * nw * self.mask_ratio) # 60% → 29 patches masked

    # Same trick as MAE: random scores + argsort = random permutation without replacement
    noise = torch.rand(B, nh * nw, device=imgs.device)
    ids = torch.argsort(noise, dim=1)
    mask = torch.zeros(B, nh * nw, device=imgs.device)
    mask.scatter_(1, ids[:, :n_mask], 1.0) # 1 = masked, 0 = visible

    # Upsample patch-level mask (7×7) to pixel-level mask (224×224)
    mask_px = mask.reshape(B, 1, nh, nw)
    mask_px = F.interpolate(mask_px, scale_factor=float(p), mode="nearest")
    return imgs * (1.0 - mask_px), mask_px
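For readers without a GPU to hand, the same selection logic can be traced in plain Python. This is a sketch, not the project code: `random_patch_mask` and its `seed` argument are hypothetical, and Python's `random` module stands in for `torch.rand`.

```python
import random

def random_patch_mask(grid_h=7, grid_w=7, mask_ratio=0.6, seed=0):
    """Return a grid_h x grid_w grid of 0/1 flags (1 = masked)."""
    rng = random.Random(seed)
    n_patches = grid_h * grid_w
    n_mask = int(n_patches * mask_ratio)  # 49 * 0.6 = 29.4, truncated to 29

    # One random score per patch; sorting the indices by score is a random
    # permutation, so the first n_mask indices are a sample without replacement.
    noise = [rng.random() for _ in range(n_patches)]
    ids = sorted(range(n_patches), key=noise.__getitem__)
    masked = set(ids[:n_mask])

    return [[1 if r * grid_w + c in masked else 0 for c in range(grid_w)]
            for r in range(grid_h)]

mask = random_patch_mask()
for row in mask:
    print(" ".join("#" if m else "." for m in row))
print(sum(map(sum, mask)), "of 49 patches masked")
```

Because the count is truncated, the effective mask ratio is 29/49, slightly under the nominal 60%.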

With sparse convolutions, feature maps are zero at every masked position, so the decoder must reconstruct masked regions from features computed on visible pixels only. With standard convolutions, the kernel still slides over zeroed patches, so visible neighbours bleed through: the decoder gets a free spatial hint at every masked boundary before reconstruction runs. This makes the task slightly easier than true SparK. On 112k images that trade-off is acceptable; an easier pretext task is less likely to leave the encoder under-trained. For the same reason I set the mask ratio to 60% rather than MAE's 75%.
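The difference between the two behaviours shows up even in one dimension. The sketch below uses a toy 9-value signal and a 3-tap averaging kernel, both made up for illustration, and imitates the sparse behaviour by re-applying the mask after each layer.

```python
def conv1d_same(signal, kernel):
    """Zero-padded 1D sliding-window filter with 'same' output length."""
    half = len(kernel) // 2
    padded = [0.0] * half + list(signal) + [0.0] * half
    return [sum(k * padded[i + j] for j, k in enumerate(kernel))
            for i in range(len(signal))]

def remask(values, mask):
    """Zero every masked position (1 = masked)."""
    return [v * (1 - m) for v, m in zip(values, mask)]

signal = [0.7, 0.5, 0.8, 0.6, 0.4, 0.7, 0.4, 0.7, 0.6]
mask = [0, 0, 0, 1, 1, 1, 0, 0, 0]
kernel = [1 / 3, 1 / 3, 1 / 3]

masked_input = remask(signal, mask)

# Standard convolution: mask once at the input, then let the kernel slide.
dense1 = conv1d_same(masked_input, kernel)
dense2 = conv1d_same(dense1, kernel)

# Sparse-style: zero the masked positions again after every layer.
sparse1 = remask(conv1d_same(masked_input, kernel), mask)
sparse2 = remask(conv1d_same(sparse1, kernel), mask)

print("dense, layer 1, masked span: ", [round(v, 3) for v in dense1[3:6]])
print("dense, layer 2, masked span: ", [round(v, 3) for v in dense2[3:6]])
print("sparse, layer 2, masked span:", [round(v, 3) for v in sparse2[3:6]])
```

After one standard layer only the edges of the masked span carry leaked values; after two, the centre does as well. With re-masking the span stays at zero no matter how many layers are stacked.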

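[Interactive figure: a 3×3 kernel sliding over a 3×9 grid whose middle three columns are masked (zeroed)]
0.7 0.5 0.8 0 0 0 0.6 0.4 0.7
0.4 0.7 0.6 0 0 0 0.7 0.8 0.5
0.8 0.6 0.5 0 0 0 0.5 0.6 0.8
Centre input: 0.7, output: 0.63. All 9 inputs visible. Normal output.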

Fix 2: U-Net OOM. The U-Net skip connections require f1 through f4 to stay alive in VRAM simultaneously during the forward pass. On an L4 GPU, this immediately OOMs.

# From ssl_methods/spark/model.py

masked, mask_px = self._random_mask(x)

# Gradient checkpointing trades compute for VRAM
# PyTorch drops each stage's internal activations on the forward pass and recomputes them during backward
h = self.stem(masked)
f1 = ckpt.checkpoint(self.layer1, h, use_reentrant=False) # (B, 256, 56, 56)
f2 = ckpt.checkpoint(self.layer2, f1, use_reentrant=False) # (B, 512, 28, 28)
f3 = ckpt.checkpoint(self.layer3, f2, use_reentrant=False) # (B, 1024, 14, 14)
f4 = ckpt.checkpoint(self.layer4, f3, use_reentrant=False) # (B, 2048, 7, 7)

pred = self.decoder(f1, f2, f3, f4)

By wrapping each encoder stage in ckpt.checkpoint, PyTorch keeps only each stage's input and output during the forward pass, discards the intermediate activations inside the stage, and recomputes them during loss.backward(). Trading compute time for memory was the only way to fit the U-Net on 22GB.
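A back-of-the-envelope calculation shows how much memory the four skip tensors take on their own. The shapes are the ones in the code comments above; the batch size of 256 and 2 bytes per element (half-precision activations under mixed precision) are assumptions for this sketch.

```python
BATCH = 256
BYTES_PER_ELEMENT = 2  # assumption: half-precision activations

# (channels, height, width) of each stage output
STAGE_SHAPES = {
    "f1": (256, 56, 56),
    "f2": (512, 28, 28),
    "f3": (1024, 14, 14),
    "f4": (2048, 7, 7),
}

def stage_megabytes(shape, batch=BATCH, bytes_per_element=BYTES_PER_ELEMENT):
    """Memory of one stage output, in MB (10^6 bytes)."""
    channels, height, width = shape
    return batch * channels * height * width * bytes_per_element / 1e6

sizes_mb = {name: stage_megabytes(shape) for name, shape in STAGE_SHAPES.items()}
for name, mb in sizes_mb.items():
    print(f"{name}: {mb:.1f} MB")
print(f"total: {sum(sizes_mb.values()):.1f} MB")
```

Under these assumptions the stage outputs sum to well under 1 GB, which suggests that most of the pressure comes from the activations saved inside each stage's residual blocks. Those are exactly what checkpointing recomputes instead of storing.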

Fix 3: Patchwise normalisation. If you ask a network to fill in a missing 32x32 patch of a lung, the cheapest way to minimise MSE is to guess the average grey value of the surrounding tissue. It paints a flat grey square. Loss drops but the model isn't actually learning any features. Normalising each patch independently closes that shortcut. The cost is the visible patch grid in the reconstructions: each patch is predicted with its own statistics and no operation couples adjacent patches, so brightness boundaries between them don't align.

# From ssl_methods/spark/model.py

def _patchwise_normalize(self, imgs):
    B, C, H, W = imgs.shape # (256, 1, 224, 224)
    p = 32 # patch size
    nh = H // p # 224 / 32 = 7 patches vertically
    nw = W // p # 224 / 32 = 7 patches horizontally

    # Goal: one mean and one variance per 32x32 patch.
    # Approach: regroup the tensor so that all pixels of one patch sit in the
    # trailing dims, then reduce over those dims.
    #
    # Step 1: split H and W into (grid position, within-patch offset)
    #   (256, 1, 224, 224) → (256, 1, 7, 32, 7, 32)
    #   the 7s index which patch; the 32s index pixels inside that patch
    #
    # Step 2: move grid dims forward, patch content to the back
    #   (256, 1, 7, 32, 7, 32) → (256, 7, 7, 1, 32, 32)
    #   last 3 dims (1, 32, 32) are now exactly one patch's pixels
    x = imgs.reshape(B, C, nh, p, nw, p).permute(0, 2, 4, 1, 3, 5)

    # Step 3: reduce over last 3 dims → shape (256, 7, 7, 1, 1, 1)
    #   one scalar per patch, independent of all neighbours
    mean = x.mean(dim=(-1, -2, -3), keepdim=True)
    var = x.var(dim=(-1, -2, -3), keepdim=True)
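    x = (x - mean) / (var + 1e-6).sqrt() # each patch now has zero mean, unit variance

    # Step 4: undo the permute and reshape so the target lines up with pred and mask_px
    #   (256, 7, 7, 1, 32, 32) → (256, 1, 7, 32, 7, 32) → (256, 1, 224, 224)
    return x.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)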
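A small numerical check of why normalisation closes the flat-grey shortcut, in plain Python. The toy patch (a brightness ramp plus a little noise) and the helper names `normalise_patch` and `mse` are invented for this sketch.

```python
import random
import statistics

def normalise_patch(pixels, eps=1e-6):
    """Zero-mean, unit-variance version of one flattened patch."""
    mean = statistics.fmean(pixels)
    var = statistics.variance(pixels)  # sample variance, as torch.var computes by default
    return [(p - mean) / (var + eps) ** 0.5 for p in pixels]

def mse(pred, target):
    return statistics.fmean((a - b) ** 2 for a, b in zip(pred, target))

rng = random.Random(0)
n = 32 * 32
# Toy patch: mid-grey background with a gentle brightness ramp and slight noise.
patch = [0.5 + 0.2 * (i / (n - 1)) + rng.gauss(0, 0.01) for i in range(n)]

# Shortcut 1: predict the patch's average grey value against the raw target.
flat_grey = [statistics.fmean(patch)] * n
raw_loss = mse(flat_grey, patch)

# Shortcut 2: the best constant guess (zero) against the normalised target.
target = normalise_patch(patch)
norm_loss = mse([0.0] * n, target)

print(f"flat guess vs raw target:        {raw_loss:.4f}")
print(f"flat guess vs normalised target: {norm_loss:.4f}")
```

Against the raw target, a flat guess is already almost free. Against the normalised target, the same strategy is stuck near a loss of 1 regardless of how low-contrast the patch is, so the only way down is to predict the structure inside the patch.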

4.7 SparK: Results That Aren't Just a Loss Graph

SparK training loss across 200 epochs

To visualise what the network was learning, I ran the model on validation images and composited its predictions back onto the originals. First, the model predicts in per-patch-normalised space, so predictions are re-scaled using the mean and standard deviation of the visible pixels to put all patches on the same brightness reference frame. Second, the composite feathers the edges with a Gaussian blur.

# From data/viz/spark_reconstructions.py

# Re-scale predictions to match visible pixel brightness
vis_pixels = orig_full[mask == 0]
global_mean = float(vis_pixels.mean())
global_std = float(vis_pixels.std())
recon_display = (pred_gray * global_std + global_mean).clip(0, 1)

# Feather mask edges for smooth composite
recon_smooth = gaussian_filter(recon_display, sigma=2.0)
soft_mask = gaussian_filter(mask.astype(np.float32), sigma=3.0)
composite = orig_full * (1 - soft_mask) + recon_smooth * soft_mask
The same chest X-ray before and after masking: 60% of 32x32 patches zeroed out
What the network sees. 60% of 32×32 patches are zeroed out before the encoder. The task is to reconstruct the missing regions from the visible 40%.

At epoch 50, blurry outlines were visible in the reconstructions.

SparK reconstructions at epoch 50
Epoch 50.

By epoch 100, coarse lung field boundaries and rib cage structure were visible in the reconstructions.

SparK reconstructions at epoch 100
Epoch 100. Left: original. Right: composite (visible patches + network predictions blended).

By epoch 199, finer details like individual rib edges and lung vasculature emerged.

SparK reconstructions at epoch 199
Epoch 199.

The patch grid stays visible at every epoch. This is a direct consequence of patchwise normalisation: the loss normalises each 32x32 patch to zero mean and unit variance independently before computing MSE. Adjacent predicted patches therefore end up with slightly different overall exposure, and nothing in the loss penalises that mismatch at the boundary, nor should it.

4.8 Cloud Training, Budgets, and Optimizers

Each method uses the optimizer from its original paper since the hyperparameters were derived under those assumptions.

1. MoCo v2: SGD, momentum 0.9, lr=4.5e-2 (linear-scaled from batch size: 0.03 × 384/256), weight decay 1e-4, 10-epoch warmup. SGD is kept because the linear scaling rule in the original work assumes isotropic gradient steps; AdamW's per-parameter rescaling breaks that formula.

2. BarlowTwins: AdamW, lr=2.4e-3, weight decay 1e-4, 10-epoch warmup.

3. SparK: AdamW, lr=3.0e-4, weight decay 5e-2, 20-epoch warmup.

4. DINO: AdamW, lr=5.0e-4, weight decay 4e-2, 10-epoch warmup.

All runs use cosine annealing to 0 after warmup and mixed-precision training (torch.cuda.amp) throughout.
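The schedule shape and the MoCo v2 scaling arithmetic can be sketched in a few lines of plain Python. The warmup is assumed to be linear and applied per epoch; the actual runs may step the schedule per iteration, and `lr_at_epoch` is an illustrative name, not a function from the project.

```python
import math

def lr_at_epoch(epoch, base_lr, warmup_epochs, total_epochs):
    """Linear warmup to base_lr, then cosine annealing towards 0."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))

# MoCo v2: linear scaling of the reference lr (0.03 at batch 256) to batch 384.
moco_lr = 0.03 * 384 / 256
print(f"MoCo v2 base lr: {moco_lr:.3f}")

# SparK: lr=3.0e-4, 20-epoch warmup, 200 epochs.
spark = [lr_at_epoch(e, 3.0e-4, 20, 200) for e in range(200)]
for e in (0, 19, 100, 199):
    print(f"SparK epoch {e:3d}: lr = {spark[e]:.2e}")
```

With a per-epoch schedule the final epoch still has a small non-zero rate; a per-iteration schedule would get closer to exactly 0 by the last step.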

MoCo v2 learning rate schedule: warmup then cosine decay over 800 epochs
BarlowTwins learning rate schedule: warmup then cosine decay over 200 epochs
SparK learning rate schedule: 20-epoch warmup then cosine decay over 200 epochs
DINO learning rate schedule: loss plateaued at epoch 15 of 100, halfway through warmup
Learning rate schedules of MoCo v2, BarlowTwins, SparK, and DINO.

All pretraining ran on a Google Cloud L4: 22GB VRAM, 16 vCPU, 64GB RAM at ~$1.25/hr (g2-standard-16).

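| Method | Backbone | Batch | Epochs | Failed runs | Final run | ~Total time | ~Cost |
|---|---|---|---|---|---|---|---|
| MoCo v2 | ResNet50 | 384 | 800 | 7 | 106h | ~149h | ~$187 |
| BarlowTwins | ResNet18 | 512 | 200 | 6 | 17h | ~18h | ~$22 |
| SparK | ResNet50 | 256 | 200 | 5 | 40h | ~40h | ~$50 |
| DINO | ResNet50 | 64 | ~5 (failed) | 10+ |  | ~30h | ~$38 |
| Total |  |  |  |  |  | ~237h | ~$297 |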
© 2026 Ryan Zhou