Part 3 of 6 in the Distributed Training series
3.1 Bucketed Gradient Reduction as the Second Scaling Primitive
Even after moving to explicit gradient synchronization, latency dominated because each parameter triggered a separate all-reduce. To reduce the number of collectives, I implemented bucketed gradient reduction.
Gradients are grouped into contiguous buckets (e.g., 25 MB each), flattened into a single tensor, and reduced asynchronously. Bucketing parameters in reverse order aligns with PyTorch’s backprop order, allowing communication for early layers to overlap with gradient computation in later layers.
```python
import torch
import torch.distributed as dist

class DDPBucketed:
    def __init__(self, model, bucket_size_mb=25.0):
        self.bucket_size_bytes = int(bucket_size_mb * 1024 * 1024)
        self.buckets = []
        self._pending_ops = []  # (handle, bucket, flat_grads) tuples awaiting wait()
        current_bucket = []
        current_size = 0
        # Reverse order matches backprop: the last layers' gradients arrive first
        for param in reversed(list(model.parameters())):
            param_size = param.numel() * param.element_size()
            if current_size + param_size > self.bucket_size_bytes and current_bucket:
                self.buckets.append(current_bucket)
                current_bucket = []
                current_size = 0
            current_bucket.append(param)
            current_size += param_size
        if current_bucket:
            self.buckets.append(current_bucket)
```

When a bucket is ready, gradients are flattened into a single tensor and reduced asynchronously:
def _reduce_bucket(self, bucket_idx: int):
bucket = self.buckets[bucket_idx]
# Flatten all gradients into a single contiguous tensor
flat_grads = torch.cat([p.grad.view(-1) for p in bucket])
# Launch async all-reduce
handle = dist.all_reduce(flat_grads, op=dist.ReduceOp.AVG, async_op=True)
# Store handle to wait on later
self._pending_ops.append((handle, bucket, flat_grads)) Impact: Fewer collectives per step, overlapping communication with computation, and no change to optimizer semantics. This was a critical step toward scaling transformers with hundreds of parameters per layer.
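One subtlety: `torch.cat` copies, so the all-reduce averages `flat_grads`, not the original `p.grad` tensors. Before the optimizer step, each pending handle must be waited on and the reduced values scattered back into the per-parameter gradients. The copy-back bookkeeping might look like the following sketch (the helper name `unflatten_into_grads` is my own, not from the original implementation):

```python
import torch

def unflatten_into_grads(flat_grads: torch.Tensor, bucket) -> None:
    """Scatter a reduced flat tensor back into each parameter's .grad."""
    offset = 0
    for p in bucket:
        n = p.numel()
        # Slice out this parameter's segment and copy it in place, reshaped
        p.grad.copy_(flat_grads[offset:offset + n].view_as(p.grad))
        offset += n
```

A drain method would then iterate over `self._pending_ops`, call `handle.wait()` on each, and apply this copy-back before `optimizer.step()`.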
3.2 Optimizer Sharding as the Third Scaling Primitive
Even with data parallelism and bucketed reduction, each process still held a full copy of optimizer state. For Adam-style optimizers, this quickly dominates memory usage:
| Component | Memory per Parameter |
|---|---|
| Parameters | 4 bytes (fp32) |
| Gradients | 4 bytes |
| Adam momentum | 4 bytes |
| Adam variance | 4 bytes |
| Total | 16 bytes |
For a 1B parameter model: 1B × 16 bytes = 16GB for parameters, gradients, and optimizer state before activations.
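The arithmetic generalizes to any parameter count. A quick sanity check in plain Python (assuming fp32 throughout, as in the table above):

```python
def training_state_bytes(num_params: int, bytes_per_value: int = 4) -> int:
    """Bytes for params + grads + Adam momentum + Adam variance (all fp32)."""
    # Four fp32 copies per parameter: weight, grad, exp_avg, exp_avg_sq
    return num_params * 4 * bytes_per_value

print(training_state_bytes(1_000_000_000))  # 16000000000 -> 16 GB
```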
To address this, I implemented ZeRO-style optimizer state sharding:
- Shard optimizer state across ranks.
- Reduce gradients locally per shard, then perform optimizer steps only on owned parameters.
- All-gather parameters post-update to keep replicas consistent.
```python
class ShardedOptimizer:
    def __init__(self, params, optimizer_cls, **optimizer_kwargs):
        self.params = list(params)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        # Shard parameters across ranks (partitioning logic elided)
        self.param_groups = self._shard_params()
        # Each rank creates an optimizer only for its own shard
        self.optimizer = optimizer_cls(
            self.param_groups[self.rank],
            **optimizer_kwargs,
        )

    def step(self):
        # 1. Reduce gradients: each shard's grads land on the owning rank
        for rank_idx, shard in enumerate(self.param_groups):
            for param in shard:
                if param.grad is not None:
                    dist.reduce(param.grad, dst=rank_idx, op=dist.ReduceOp.SUM)
                    if self.rank == rank_idx:
                        param.grad.div_(self.world_size)
        # 2. Local optimizer step (only on my shard)
        self.optimizer.step()
        # 3. Broadcast updated parameters from their owning rank
        for rank_idx, shard in enumerate(self.param_groups):
            for param in shard:
                dist.broadcast(param.data, src=rank_idx)
```

This reduced per-process optimizer memory roughly in proportion to world size, trading additional communication for lower memory pressure.
With 4 GPUs, Adam state drops from 8 GB to ~2 GB per GPU, making billion-parameter models trainable at the cost of that additional communication.
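The per-rank saving is easy to sanity-check. Adam keeps two fp32 tensors (momentum and variance) per parameter, i.e. 8 bytes each, and sharding divides that by the number of ranks (a back-of-the-envelope helper, not part of the implementation above):

```python
def optimizer_state_gb_per_rank(num_params: int, world_size: int,
                                state_bytes_per_param: int = 8) -> float:
    """Adam state (momentum + variance, fp32) per rank, in decimal GB."""
    return num_params * state_bytes_per_param / world_size / 1e9

print(optimizer_state_gb_per_rank(1_000_000_000, 1))  # 8.0 GB unsharded
print(optimizer_state_gb_per_rank(1_000_000_000, 4))  # 2.0 GB per rank
```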