
3. Bucketing and Sharding as the Second and Third Scaling Primitives

Reducing communication overhead and memory usage.

Part 3 of 6 in the Distributed Training series

3.1 Bucketed Gradient Reduction as the Second Scaling Primitive

Even after moving to explicit gradient synchronization, latency dominated because each parameter triggered a separate all-reduce. To reduce the number of collectives, I implemented bucketed gradient reduction.

Gradients are grouped into contiguous buckets (e.g., 25 MB each), flattened into a single tensor, and reduced asynchronously. Bucketing parameters in reverse order aligns with PyTorch’s backprop order, allowing communication for early layers to overlap with gradient computation in later layers.

class DDPBucketed:
    def __init__(self, model, bucket_size_mb=25.0):
        self.bucket_size_bytes = int(bucket_size_mb * 1024 * 1024)
        self.buckets = []
        self._pending_ops = []  # (handle, bucket, flat_grads) tuples awaiting wait()

        current_bucket = []
        current_size = 0

        for param in reversed(list(model.parameters())):
            param_size = param.numel() * param.element_size()

            if current_size + param_size > self.bucket_size_bytes and current_bucket:
                self.buckets.append(current_bucket)
                current_bucket = []
                current_size = 0

            current_bucket.append(param)
            current_size += param_size

        if current_bucket:
            self.buckets.append(current_bucket)

When a bucket is ready, gradients are flattened into a single tensor and reduced asynchronously:

def _reduce_bucket(self, bucket_idx: int):
    bucket = self.buckets[bucket_idx]
    
    # Flatten all gradients into a single contiguous tensor
    flat_grads = torch.cat([p.grad.view(-1) for p in bucket])
    
    # Launch async all-reduce
    handle = dist.all_reduce(flat_grads, op=dist.ReduceOp.AVG, async_op=True)
    
    # Store handle to wait on later
    self._pending_ops.append((handle, bucket, flat_grads))

Impact: far fewer collectives per step, communication overlapped with computation, and no change to optimizer semantics. This was a critical step toward scaling transformers with hundreds of millions of parameters.

3.2 Optimizer Sharding as the Third Scaling Primitive

Even with data parallelism and bucketed reduction, each process still held a full copy of optimizer state. For Adam-style optimizers, this quickly dominates memory usage:

Component        Memory per Parameter
Parameters       4 bytes (fp32)
Gradients        4 bytes
Adam momentum    4 bytes
Adam variance    4 bytes
Total            16 bytes

For a 1B-parameter model: 1B × 16 bytes = 16 GB for parameters, gradients, and optimizer state, before accounting for activations.
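That arithmetic generalizes to any parameter count; a tiny helper (the name is illustrative, not part of the implementation) makes the fully replicated budget explicit:

```python
def full_replication_gb(num_params: int) -> float:
    # 4 B params + 4 B grads + 4 B momentum + 4 B variance = 16 B/param
    return num_params * 16 / 1e9

full_replication_gb(1_000_000_000)  # 16.0 GB, before activations
```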

To address this, I implemented ZeRO-style optimizer state sharding:

  • Shard optimizer state across ranks.
  • Reduce gradients locally per shard, then perform optimizer steps only on owned parameters.
  • All-gather parameters post-update to keep replicas consistent.

class ShardedOptimizer:
    def __init__(self, params, optimizer_cls, **optimizer_kwargs):
        self.params = list(params)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        
        # Shard parameters across ranks
        self.param_groups = self._shard_params()
        
        # Each rank creates optimizer only for its shard
        self.optimizer = optimizer_cls(
            self.param_groups[self.rank], 
            **optimizer_kwargs
        )
    
    def step(self):
        # 1. Reduce gradients (each rank gets its shard's grads)
        for rank_idx, shard in enumerate(self.param_groups):
            for param in shard:
                if param.grad is not None:
                    dist.reduce(param.grad, dst=rank_idx, op=dist.ReduceOp.SUM)
                    if self.rank == rank_idx:
                        param.grad.div_(self.world_size)
        
        # 2. Local optimizer step (only on my shard)
        self.optimizer.step()
        
        # 3. Broadcast updated parameters from their owning rank so
        #    all replicas stay consistent (shard index == owner rank)
        for rank_idx, shard in enumerate(self.param_groups):
            for param in shard:
                dist.broadcast(param.data, src=rank_idx)
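The `_shard_params` helper referenced in the constructor is not shown; one plausible sketch is a greedy, size-balanced partition (a hypothetical standalone version with an illustrative name, not necessarily the original implementation):

```python
def shard_params_greedy(params, world_size):
    # Assign each parameter (largest first) to the currently lightest
    # rank, keeping shard sizes roughly balanced in bytes.
    shards = [[] for _ in range(world_size)]
    shard_bytes = [0] * world_size
    for p in sorted(params, key=lambda t: t.numel() * t.element_size(),
                    reverse=True):
        lightest = shard_bytes.index(min(shard_bytes))
        shards[lightest].append(p)
        shard_bytes[lightest] += p.numel() * p.element_size()
    return shards
```

Balancing by bytes rather than by parameter count matters when a model mixes large embedding matrices with small bias vectors.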

This reduced per-process optimizer state memory roughly in proportion to the world size, trading additional communication for lower memory pressure.

With 4 GPUs, the 1B-parameter model's optimizer state drops from 8 GB to ~2 GB per GPU, making billion-parameter training feasible on hardware that would otherwise run out of memory.
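The quoted saving follows directly from splitting the 8 bytes/param of Adam state across ranks; a back-of-the-envelope helper (illustrative, not part of the implementation) reproduces both numbers:

```python
def adam_state_gb_per_rank(num_params: int, world_size: int) -> float:
    # momentum (4 B) + variance (4 B) per parameter, split across ranks
    return num_params * 8 / 1e9 / world_size

adam_state_gb_per_rank(1_000_000_000, 1)  # 8.0 GB, unsharded
adam_state_gb_per_rank(1_000_000_000, 4)  # 2.0 GB per GPU
```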