
5. Training Validation on a Single GPU

Distributed code without distributed hardware

Part 5 of 6 in the Distributed Training series

5.1 The Strategy Abstraction

As multiple approaches emerged (standard data parallelism, bucketed reduction, optimizer sharding, tensor parallelism), I introduced a strategy abstraction to separate training logic from distributed mechanics:

class DistributedStrategy:
    def wrap_model(self, model: nn.Module) -> nn.Module:
        raise NotImplementedError
    
    def wrap_optimizer(self, optimizer_cls, model, **kwargs) -> Optimizer:
        return optimizer_cls(model.parameters(), **kwargs)
    
    def sync_gradients(self, model: nn.Module):
        pass  # Override if manual sync needed

Each strategy implements these hooks differently. The DDPBucketedStrategy wraps the model in my custom bucketed gradient sync and requires an explicit sync call after backward:

class DDPBucketedStrategy(DistributedStrategy):
    def wrap_model(self, model):
        return DDPBucketed(model, bucket_size_mb=25.0)
    
    def sync_gradients(self, model):
        model.finish_gradient_synchronization()

The ZeROStrategy combines standard DDP for gradient sync with sharded optimizer state:

class ZeROStrategy(DistributedStrategy):
    def wrap_model(self, model):
        return DDP(model, device_ids=[local_rank])
    
    def wrap_optimizer(self, optimizer_cls, model, **kwargs):
        return ShardedOptimizer(model.parameters(), optimizer_cls, **kwargs)
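The ShardedOptimizer internals aren't shown here, but the core idea behind ZeRO stage 1, each rank owning the optimizer state for only a slice of the parameters, can be sketched with a hypothetical partitioning helper (greedy load balancing is one common choice; this is an illustration, not the post's actual implementation):

```python
import torch
import torch.nn as nn

def partition_parameters(params, world_size):
    """Assign each parameter to one rank, balancing total element count.

    Each rank builds its optimizer over shards[rank] only, so optimizer
    state (e.g. AdamW moments) is split roughly 1/world_size per rank.
    """
    shards = [[] for _ in range(world_size)]
    loads = [0] * world_size
    # Largest parameters first, each to the currently least-loaded rank.
    for p in sorted(params, key=lambda p: p.numel(), reverse=True):
        rank = loads.index(min(loads))
        shards[rank].append(p)
        loads[rank] += p.numel()
    return shards

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
shards = partition_parameters(model.parameters(), world_size=2)
```

After `step()`, each rank would broadcast its updated shard so all ranks see consistent weights; that collective is omitted here.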

This made it possible to switch between approaches without rewriting the training loop:

# Single GPU
python train.py --strategy single

# PyTorch DDP
torchrun --nproc_per_node=4 train.py --strategy ddp

# Custom bucketed DDP
torchrun --nproc_per_node=4 train.py --strategy ddp_bucketed

# ZeRO optimizer sharding
torchrun --nproc_per_node=4 train.py --strategy zero
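The training loop itself only ever touches the three hooks. A minimal runnable sketch of that loop, using a hypothetical SingleDeviceStrategy and a toy model (neither is from the actual train.py):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleDeviceStrategy:
    """No-op strategy: the base-class defaults already suffice on one device."""
    def wrap_model(self, model):
        return model
    def wrap_optimizer(self, optimizer_cls, model, **kwargs):
        return optimizer_cls(model.parameters(), **kwargs)
    def sync_gradients(self, model):
        pass

def train(strategy, steps=5):
    torch.manual_seed(0)
    model = strategy.wrap_model(nn.Linear(8, 2))
    opt = strategy.wrap_optimizer(torch.optim.SGD, model, lr=0.1)
    x, y = torch.randn(32, 8), torch.randint(0, 2, (32,))
    losses = []
    for _ in range(steps):
        opt.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        strategy.sync_gradients(model)  # no-op here; DDP strategies sync grads
        opt.step()
        losses.append(loss.item())
    return losses

train(SingleDeviceStrategy())
```

Swapping in a distributed strategy changes only the object passed to `train`, which is the whole point of the abstraction.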

5.2 Simulating Multi-GPU with Gloo

All development was done with access to a single RTX 3060. To validate multi-process correctness without multi-GPU hardware, I used PyTorch's gloo backend, which supports CPU-based collective communication. The key insight: torchrun spawns multiple processes regardless of GPU count, and gloo allows those processes to communicate over shared memory.

def setup_distributed(backend="auto"):
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    num_gpus = torch.cuda.device_count()
    
    # Auto-select backend based on hardware
    if backend == "auto":
        backend = "nccl" if num_gpus >= world_size else "gloo"
    
    if backend == "gloo":
        # All processes share GPU 0 (simulation mode)
        dist.init_process_group(backend="gloo")
        torch.cuda.set_device(0)
        device = torch.device("cuda:0")
    else:
        # Real multi-GPU: each process gets its own device
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
        dist.init_process_group(backend="nccl", device_id=device)
    
    return {"rank": rank, "world_size": world_size, "device": device}

This gives us 4 processes that believe they're on separate GPUs, but all share the same physical device. The collective operations (all_reduce, broadcast, all_gather) work identically; only the transport layer differs. This means the distributed logic is fully exercised even without distributed hardware.

To run with simulated multi-GPU:

# Spawns 4 processes, all sharing GPU 0
torchrun --nproc_per_node=4 benchmark_strategies.py --quick
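The same mechanism can be exercised without torchrun at all. A self-contained sketch (hypothetical worker function and port, using torch.multiprocessing to spawn the processes) that all-reduces over gloo on CPU:

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size, port="29501"):
    # Rendezvous over localhost; port is an arbitrary free one.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # Each rank contributes its own id; after all_reduce every rank
    # holds 0 + 1 + ... + (world_size - 1).
    t = torch.tensor([float(rank)])
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    assert t.item() == sum(range(world_size))
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)  # spawn raises if a rank fails
```

No GPU is involved, yet the collective semantics are exactly those the DDP implementations rely on.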

5.3 The Benchmark Harness

To validate all implementations systematically, I built a benchmark harness that tests each strategy under identical conditions. The harness measures three properties:

Correctness: After N training steps, are weights identical across all ranks? For data-parallel strategies, all ranks should converge to identical weights. For tensor parallelism and FSDP, weights are intentionally sharded, so differing per-rank weight sums are expected behavior.

Memory: Peak GPU memory during training. This reveals the overhead of each approach.

Throughput: Steps per second and tokens per second. Not meaningful in simulation mode (all processes compete for one GPU), but the harness supports real multi-GPU benchmarking.

The core benchmark loop:

def run_benchmark(cfg, dist_info):
    # Create model based on strategy
    needs_sync = False  # strategies with manual gradient sync set this below
    if cfg.strategy == "ddp_flat":
        wrapper = DDPIndividualParameters(model)
        needs_sync = True
    elif cfg.strategy == "ddp_bucketed":
        wrapper = DDPBucketed(model, bucket_size_mb=25.0)
        needs_sync = True
    elif cfg.strategy == "tensor_parallel":
        model = TensorParallelTransformerLM(..., tp_group)
    
    # Optional: wrap optimizer with ZeRO sharding
    if cfg.use_sharded_optimizer:
        opt = ShardedOptimizer(params, torch.optim.AdamW, ...)
    
    # Training loop
    for step in range(cfg.num_steps):
        opt.zero_grad()
        for _ in range(grad_accum):
            loss = F.cross_entropy(model(x), y)
            loss.backward()
        
        if needs_sync:
            wrapper.finish_gradient_synchronization()
        
        opt.step()
    
    # Verify weight synchronization: gather each rank's total weight sum
    weight_sum = sum(p.sum() for p in model.parameters()).reshape(1)
    all_sums = [torch.zeros_like(weight_sum) for _ in range(world_size)]
    dist.all_gather(all_sums, weight_sum)
    synced = all(abs(all_sums[0] - s).item() < 0.5 for s in all_sums)

The strategies tested:

Strategy               Description
ddp_flat               Async per-parameter all-reduce with gradient hooks
ddp_bucketed           Bucketed gradient reduction, reverse parameter order
ddp_flat + ZeRO        Per-param sync with optimizer state sharding
ddp_bucketed + ZeRO    Bucketed sync with optimizer state sharding
tensor_parallel        Column/row parallel layers, attention head sharding
pytorch_ddp            Baseline comparison
pytorch_fsdp           Baseline comparison (parameter sharding)