Distributed Training Infrastructure
Overview
In our previous lesson, we explored parameter-efficient fine-tuning techniques that enable working with large language models on limited hardware. However, as model sizes continue to grow beyond even what PEFT methods can handle on a single device, distributed training becomes essential. This lesson explores how to scale training across multiple GPUs, multiple machines, and even multiple data centers.
Distributed training is what makes training models with hundreds of billions of parameters possible. Understanding these techniques allows you to work with state-of-the-art models and contribute to pushing the boundaries of what's possible in natural language processing.
Learning Objectives
After completing this lesson, you will be able to:
- Understand the fundamental challenges of distributed training for large language models
- Identify and implement appropriate parallelism strategies based on model and hardware constraints
- Set up and configure popular distributed training frameworks like DeepSpeed and PyTorch FSDP
- Optimize distributed training performance through proper hyperparameter tuning
- Implement effective monitoring and debugging strategies for distributed training
- Design resilient training systems that can recover from failures
The Need for Distributed Training
Scale Drives Progress
The past few years have demonstrated a clear trend: larger models, trained on more data, perform better on a wide range of tasks. This scaling law presents a technical challenge—how do we train these enormous models efficiently?
Analogy: Distributed Training as Coordinated Construction
Think of training a large language model like constructing a massive skyscraper:
- Single-device Training: One construction team trying to build the entire structure—impossible beyond a certain size
- Distributed Training: Multiple specialized teams working on different sections simultaneously
- Coordination Overhead: Teams need to communicate, synchronize, and integrate their work
- Resource Planning: Different tasks require different equipment and expertise
Just as a skyscraper can only reach new heights through coordinated teamwork, today's largest language models can only be trained through sophisticated distributed systems.
The Fundamental Challenges
Memory Constraints
A fundamental challenge in training large language models is memory:
- Model Parameters: FP16 parameters require 2 bytes each
  - 1B parameters = 2 GB
  - 100B parameters = 200 GB
  - 1T parameters = 2 TB
- Optimizer States: Optimizers like Adam require additional memory
  - Adam needs roughly 8 bytes per parameter for its momentum and variance states (4x the FP16 parameter size)
  - 1B parameters = ~10 GB total (weights plus optimizer states)
  - 100B parameters = ~1 TB total
- Activation Memory: Forward-pass outputs that must be kept for backpropagation
  - Scales with batch size and sequence length
  - Can exceed parameter memory for large batches
- Gradient Accumulation: Lets you reach a large effective batch with small micro-batches, reducing activation memory at the cost of more sequential steps
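A quick way to internalize these numbers is to compute them. The helper below is a back-of-the-envelope sketch (the function name and defaults are illustrative, assuming FP16 weights and FP32 Adam states as above); gradients and activations add to whatever it reports.

```python
def estimate_static_memory_gb(num_params: float,
                              bytes_per_param: float = 2.0,   # FP16/BF16 weights
                              bytes_per_optim: float = 8.0):  # Adam momentum + variance
    """Rough per-device memory for weights plus optimizer states, in GB.
    Gradients and activations are not included and often dominate."""
    return num_params * (bytes_per_param + bytes_per_optim) / 1e9

print(estimate_static_memory_gb(1e9))    # ~10 GB, matching the figures above
print(estimate_static_memory_gb(100e9))  # ~1000 GB, about 1 TB
```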
Parallelism Strategies: Dividing the Problem
Types of Parallelism
To overcome these challenges, we use multiple forms of parallelism:
Data Parallelism
In data parallelism, the entire model is replicated across devices, but each processes different batches of data:
- Each device maintains a complete copy of the model
- Each device processes different data samples
- Gradients are synchronized across devices
- Model weights are updated synchronously
Distributed Data Parallel (DDP)
The standard approach to data parallelism in PyTorch:
{"tool": "code-editor", "defaultValue": "import torch import torch.nn as nn import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data.distributed import DistributedSampler
def setup(rank, world_size): """Initialize distributed process group.""" dist.init_process_group(backend='nccl', init_method='tcp://localhost:12355', world_size=world_size, rank=rank)
def cleanup(): """Clean up distributed process group.""" dist.destroy_process_group()
class SimpleModel(nn.Module): def init(self): super().init() self.layers = nn.Sequential( nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768) )
1 def forward(self, x): 2 return self.layers(x)
def train(rank, world_size): # Initialize distributed training setup(rank, world_size)
1 # Create model and move to GPU 2 model = SimpleModel().to(rank) 3
4 # Wrap model with DDP 5 ddp_model = DDP(model, device_ids=[rank]) 6
7 # Create optimizer 8 optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=5e-5) 9
10 # Create dataset and sampler 11 dataset = YourDataset() # Replace with your dataset 12 sampler = DistributedSampler(dataset, 13 num_replicas=world_size, 14 rank=rank) 15 dataloader = DataLoader(dataset, 16 batch_size=32, 17 sampler=sampler) 18
19 # Training loop 20 for epoch in range(num_epochs): 21 sampler.set_epoch(epoch) # Important for proper shuffling 22 for batch in dataloader: 23 inputs, labels = batch 24 inputs, labels = inputs.to(rank), labels.to(rank) 25 26 # Forward pass 27 outputs = ddp_model(inputs) 28 loss = criterion(outputs, labels) 29 30 # Backward pass and optimize 31 optimizer.zero_grad() 32 loss.backward() 33 optimizer.step() 34
35 cleanup()
Launch with: torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)"}
Zero Redundancy Optimizer (ZeRO)
ZeRO enhances data parallelism by partitioning optimizer states, gradients, and even parameters across GPUs:
- ZeRO Stage 1: Partitions optimizer states (momentum, variance)
- ZeRO Stage 2: Partitions gradients as well
- ZeRO Stage 3: Partitions model parameters too
This strategy dramatically reduces memory overhead while maintaining the simplicity of data parallelism.
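To see why the stages matter, the toy calculation below (an illustrative sketch using the same FP16/Adam byte counts as earlier) estimates per-GPU memory for a 10B-parameter model on 64 GPUs as each additional piece of state is partitioned.

```python
def zero_memory_per_gpu_gb(num_params, num_gpus, stage):
    """Approximate per-GPU memory (GB) for FP16 weights, FP16 gradients,
    and FP32 Adam states under ZeRO stages 0-3."""
    params_b, grads_b, optim_b = 2 * num_params, 2 * num_params, 8 * num_params
    if stage >= 1:
        optim_b /= num_gpus    # Stage 1: shard optimizer states
    if stage >= 2:
        grads_b /= num_gpus    # Stage 2: also shard gradients
    if stage >= 3:
        params_b /= num_gpus   # Stage 3: also shard parameters
    return (params_b + grads_b + optim_b) / 1e9

for stage in range(4):
    print(f"ZeRO-{stage}: {zero_memory_per_gpu_gb(10e9, 64, stage):.1f} GB per GPU")
# Roughly: 120.0, 41.3, 21.6, 1.9 GB per GPU
```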
Model Parallelism
When a model is too large to fit on a single device, we can split it across multiple devices:
Tensor Parallelism (TP)
Tensor parallelism splits individual layers across devices:
- Divides matrix operations across devices
- Particularly effective for large matrix multiplications in attention layers
- Requires frequent communication between devices
- Typically limited to devices within the same node (due to communication overhead)
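As a concrete illustration, the sketch below shows the idea behind a column-parallel linear layer, a simplified version of Megatron-style tensor parallelism: each rank stores only a slice of the weight matrix, computes its slice of the output, and an all-gather assembles the full result. Gradient flow through the collective is glossed over here; real implementations use autograd-aware collectives.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

class ColumnParallelLinear(nn.Module):
    """Illustrative column-parallel linear layer: each rank owns a slice of the
    output dimension, so the full weight matrix never lives on one device."""
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        assert out_features % world_size == 0
        self.local_out = out_features // world_size
        self.weight = nn.Parameter(torch.empty(self.local_out, in_features))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        local_y = F.linear(x, self.weight)                  # [..., local_out]
        # Assemble the full output from every rank's slice (autograd through
        # this collective is omitted for brevity).
        pieces = [torch.empty_like(local_y) for _ in range(dist.get_world_size())]
        dist.all_gather(pieces, local_y)
        return torch.cat(pieces, dim=-1)                    # [..., out_features]
```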
Pipeline Parallelism (PP)
Pipeline parallelism divides the model sequentially into stages:
- Different layers run on different devices
- Data flows through the pipeline in micro-batches
- Reduces communication overhead compared to tensor parallelism
- Can span across multiple nodes, even with slower interconnects
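The data flow can be sketched with two stages placed on two GPUs (the layer sizes are placeholders). A real pipeline engine such as GPipe or a 1F1B scheduler launches micro-batches asynchronously so both stages stay busy; this naive loop only illustrates how a batch is split and handed from stage to stage.

```python
import torch
import torch.nn as nn

# Two pipeline stages on two GPUs (illustrative layer sizes)
stage0 = nn.Sequential(nn.Linear(768, 768), nn.GELU()).to("cuda:0")
stage1 = nn.Sequential(nn.Linear(768, 768), nn.GELU()).to("cuda:1")

def pipeline_forward(batch, num_microbatches=4):
    """Split the batch into micro-batches and pass them stage to stage;
    a real scheduler overlaps these steps instead of running them serially."""
    outputs = []
    for micro in batch.chunk(num_microbatches):
        hidden = stage0(micro.to("cuda:0"))
        outputs.append(stage1(hidden.to("cuda:1")))
    return torch.cat(outputs)
```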
3D Parallelism
Modern training frameworks combine all three forms of parallelism:
- Data Parallelism: Process different batches in parallel
- Tensor Parallelism: Split individual layers across devices
- Pipeline Parallelism: Split the model into sequential stages across devices
This 3D approach allows for tremendous flexibility in scaling models across different hardware configurations.
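One way to picture the combination is as a 3D grid of ranks. The sketch below is a purely illustrative mapping (real frameworks such as Megatron-LM choose the ordering carefully so tensor-parallel ranks sit on the fastest interconnect); it decomposes a flat world size into data, pipeline, and tensor coordinates.

```python
def decompose_ranks(world_size, tp_size, pp_size):
    """Map each flat rank to (data, pipeline, tensor) coordinates.
    world_size must equal dp_size * pp_size * tp_size."""
    assert world_size % (tp_size * pp_size) == 0
    dp_size = world_size // (tp_size * pp_size)
    coords = {}
    for rank in range(world_size):
        tp_rank = rank % tp_size
        pp_rank = (rank // tp_size) % pp_size
        dp_rank = rank // (tp_size * pp_size)
        coords[rank] = (dp_rank, pp_rank, tp_rank)
    return dp_size, coords

# Example: 64 GPUs = 4-way data x 4-way pipeline x 4-way tensor parallelism
dp_size, coords = decompose_ranks(64, tp_size=4, pp_size=4)
print(dp_size, coords[0], coords[63])  # 4, (0, 0, 0), (3, 3, 3)
```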
Modern Distributed Training Frameworks
PyTorch Fully Sharded Data Parallel (FSDP)
PyTorch's native implementation of ZeRO-like sharding:
- Key Features:
  - Parameter sharding to minimize memory usage
  - Efficient communication patterns
  - Integration with the PyTorch ecosystem
  - Support for mixed precision training
- Implementation Approach:
  - Shards parameters, gradients, and optimizer states across ranks
  - Gathers full parameters just before each wrapped module's forward and backward pass, then reshards them
  - Uses efficient all-gather and reduce-scatter collectives
{"tool": "code-editor", "defaultValue": "import torch import torch.nn as nn import torch.distributed as dist from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import ( CPUOffload, BackwardPrefetch, ) from torch.distributed.fsdp.wrap import ( size_based_auto_wrap_policy, enable_wrap, wrap, )
Initialize distributed environment
dist.init_process_group(backend="nccl") local_rank = int(os.environ["LOCAL_RANK"]) torch.cuda.set_device(local_rank)
Define your model
model = YourLargeTransformerModel()
Configure FSDP wrapping policy
auto_wrap_policy = size_based_auto_wrap_policy( min_num_params=100_000, # Layers with >100K params will be wrapped separately excluded_wrap_modules={nn.Embedding} # Don't wrap embedding layers separately )
Wrap your model with FSDP
model = FSDP( model, auto_wrap_policy=auto_wrap_policy, device_id=torch.cuda.current_device(), cpu_offload=CPUOffload(offload_params=True), # Offload parameters to CPU when not in use backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Prefetch next layer's params during backward pass mixed_precision=True, # Enable mixed precision training )
Standard training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
for epoch in range(num_epochs): for batch in dataloader: inputs, labels = batch inputs, labels = inputs.cuda(), labels.cuda()
1 outputs = model(inputs) 2 loss = criterion(outputs, labels) 3 4 optimizer.zero_grad() 5 loss.backward() 6 optimizer.step()"}
DeepSpeed
Microsoft's comprehensive library for efficient distributed training:
- Key Features:
  - ZeRO optimizer for memory efficiency
  - 3D parallelism support
  - 1-bit Adam/LAMB for communication efficiency
  - Curriculum learning and other training optimizations
  - Inference optimizations
- ZeRO-Infinity:
  - Extends memory capacity to NVMe storage
  - Allows training models that exceed total GPU memory
  - Smart offloading of tensors to CPU and NVMe
{"tool": "code-editor", "defaultValue": "# DeepSpeed configuration ds_config = { "train_batch_size": 32 * torch.cuda.device_count(), "gradient_accumulation_steps": 4, "optimizer": { "type": "Adam", "params": { "lr": 1e-5, "weight_decay": 0.01, "bias_correction": True } }, "fp16": { "enabled": True, "loss_scale": 0, "loss_scale_window": 1000, "initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "zero_optimization": { "stage": 3, "offload_optimizer": { "device": "cpu", "pin_memory": True }, "offload_param": { "device": "cpu", "pin_memory": True }, "overlap_comm": True, "contiguous_gradients": True, "reduce_bucket_size": 5e8, "stage3_prefetch_bucket_size": 5e8, "stage3_param_persistence_threshold": 1e6 }, "gradient_clipping": 1.0, "steps_per_print": 50, "wall_clock_breakdown": False }
Import and initialize DeepSpeed
import deepspeed from transformers import AutoModelForCausalLM, AutoTokenizer from deepspeed.ops.adam import FusedAdam
Initialize model and tokenizer
model = AutoModelForCausalLM.from_pretrained('gpt2-large') tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
Prepare for DeepSpeed
parameters = filter(lambda p: p.requires_grad, model.parameters()) model_engine, optimizer, _, _ = deepspeed.initialize( model=model, model_parameters=parameters, config=ds_config )
Training loop
for epoch in range(num_epochs): for batch in train_dataloader: inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True).to(model_engine.device) outputs = model_engine(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], labels=inputs['input_ids']) loss = outputs.loss
1 # DeepSpeed handles loss scaling and backward 2 model_engine.backward(loss) 3 model_engine.step()"}
Comparison of Frameworks
Both frameworks shard parameters, gradients, and optimizer states in the spirit of ZeRO, but they sit at different points on the simplicity/feature spectrum:
- PyTorch FSDP: Native to PyTorch with no extra dependencies; sharding, CPU offload, and mixed precision are configured directly in code
- DeepSpeed: Configuration-file driven; adds ZeRO-Infinity (CPU and NVMe offload), 3D parallelism, and communication-efficient optimizers such as 1-bit Adam
Practical Implementation
Hardware Considerations
Choosing the right hardware configuration is crucial:
- GPU Interconnect:
  - NVLink (600 GB/s): Best for tensor parallelism
  - PCIe 4.0 (32 GB/s): Sufficient for pipeline parallelism
  - InfiniBand (200 Gb/s): Good for cross-node communication
- Memory Hierarchy:
  - HBM (GPU memory): Fastest, most limited
  - CPU RAM: Roughly 10x slower, much larger
  - NVMe: Roughly 100x slower, immense capacity
- GPU Selection:
  - A100s/H100s: Best for large-scale training
  - Consumer GPUs: Viable with memory-efficient techniques
Optimizing Communication
Communication often becomes the bottleneck in distributed training:
- Gradient Accumulation:
  - Reduces how often gradients must be synchronized
  - Trades a slightly longer effective step for lower communication overhead (see the DDP sketch after this list)
- Gradient Compression:
  - 1-bit Adam: Quantizes gradient updates
  - PowerSGD: Low-rank approximation of gradients
- Optimized Collectives:
  - NCCL: Optimized for GPU-to-GPU communication
  - AllReduce, AllGather, ReduceScatter operations
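The sketch below shows how gradient accumulation cuts synchronization frequency under DDP. It reuses the `ddp_model`, `criterion`, `optimizer`, `rank`, and `dataloader` names from the earlier DDP example; `no_sync()` is DDP's context manager for skipping the gradient all-reduce on non-accumulating steps.

```python
import contextlib

accumulation_steps = 8  # synchronize gradients once every 8 micro-batches

for step, (inputs, labels) in enumerate(dataloader):
    inputs, labels = inputs.to(rank), labels.to(rank)
    is_sync_step = (step + 1) % accumulation_steps == 0

    # On non-sync steps, no_sync() skips the all-reduce so gradients simply
    # accumulate locally; communication happens once per effective batch.
    ctx = contextlib.nullcontext() if is_sync_step else ddp_model.no_sync()
    with ctx:
        loss = criterion(ddp_model(inputs), labels) / accumulation_steps
        loss.backward()

    if is_sync_step:
        optimizer.step()
        optimizer.zero_grad()
```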
Monitoring Distributed Training
{"tool": "code-editor", "defaultValue": "import torch import time import wandb from torch.utils.tensorboard import SummaryWriter from torch.distributed import get_rank
class DistributedTrainingMonitor: def init(self, model, log_dir, project_name, config=None): self.start_time = time.time() self.rank = get_rank() if torch.distributed.is_initialized() else 0 self.is_main_process = self.rank == 0
1 # Only initialize loggers on the main process 2 if self.is_main_process: 3 self.tb_writer = SummaryWriter(log_dir) 4 wandb.init(project=project_name, config=config) 5 6 # For monitoring GPU metrics 7 self.model = model 8 self.step_times = [] 9 self.comm_times = [] 10 self.forward_times = [] 11 self.backward_times = [] 12 13 def log_step(self, loss, lr, step, comm_time=None, forward_time=None, backward_time=None): 14 # Basic metrics 15 metrics = { 16 'loss': loss, 17 'learning_rate': lr, 18 'step': step, 19 'samples_per_second': self.get_throughput(), 20 } 21 22 # Advanced timing metrics 23 if comm_time: 24 self.comm_times.append(comm_time) 25 metrics['communication_time'] = comm_time 26 metrics['communication_time_avg'] = sum(self.comm_times[-100:]) / len(self.comm_times[-100:]) 27 28 if forward_time: 29 self.forward_times.append(forward_time) 30 metrics['forward_time'] = forward_time 31 32 if backward_time: 33 self.backward_times.append(backward_time) 34 metrics['backward_time'] = backward_time 35 36 # GPU memory metrics 37 metrics.update(self.get_gpu_metrics()) 38 39 # Log only from main process 40 if self.is_main_process: 41 wandb.log(metrics) 42 for key, value in metrics.items(): 43 self.tb_writer.add_scalar(key, value, step) 44
45 def get_gpu_metrics(self): 46 metrics = {} 47 48 # Current GPU memory usage 49 metrics['gpu_memory_allocated'] = torch.cuda.memory_allocated() / 1e9 # GB 50 metrics['gpu_memory_reserved'] = torch.cuda.memory_reserved() / 1e9 # GB 51 52 # Peak memory stats 53 metrics['gpu_max_memory_allocated'] = torch.cuda.max_memory_allocated() / 1e9 # GB 54 55 # If using FSDP or DeepSpeed, add model-specific memory metrics 56 if hasattr(self.model, 'get_memory_footprint'): 57 metrics['model_memory_footprint'] = self.model.get_memory_footprint() / 1e9 # GB 58 59 return metrics 60
61 def get_throughput(self): 62 elapsed = time.time() - self.start_time 63 if not hasattr(self, 'total_samples'): 64 return 0 65 return self.total_samples / elapsed 66
67 def update_samples(self, batch_size): 68 if not hasattr(self, 'total_samples'): 69 self.total_samples = 0 70 self.total_samples += batch_size * torch.distributed.get_world_size() 71
72 def close(self): 73 if self.is_main_process: 74 self.tb_writer.close() 75 wandb.finish()"}
Key Metrics to Monitor
- Training Efficiency:
  - Samples per second (throughput)
  - GPU utilization
  - Communication overhead
  - Load imbalance across ranks
- Memory Usage:
  - GPU memory allocated
  - Memory breakdown (parameters, activations, gradients)
  - Out-of-memory events
  - Memory fragmentation
- Scaling Efficiency:
  - Linear scaling factor (see the small helper sketched below)
  - Communication-to-computation ratio
  - Device idle time
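Scaling efficiency is simply measured throughput divided by the throughput that ideal linear scaling would deliver. The helper below is an illustrative one-liner, not part of any framework API.

```python
def scaling_efficiency(throughput_n_gpus: float, throughput_1_gpu: float, n_gpus: int) -> float:
    """Fraction of ideal linear scaling achieved (1.0 means perfect scaling)."""
    return throughput_n_gpus / (n_gpus * throughput_1_gpu)

# Example: 8 GPUs delivering 6.4x single-GPU throughput -> 0.8 (80% efficiency)
print(scaling_efficiency(throughput_n_gpus=6400, throughput_1_gpu=1000, n_gpus=8))
```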
Advanced Techniques and Optimizations
Activation Checkpointing
Trades computation for memory by recomputing activations during backward pass:
- How it Works:
  - Save activations only at specific checkpoints
  - Recompute intermediate activations during the backward pass
  - Reduces memory at the cost of additional computation
- Implementation:
  - Available in PyTorch, DeepSpeed, and FSDP
  - Often applied to transformer blocks
{"tool": "code-editor", "defaultValue": "import torch from torch.utils.checkpoint import checkpoint
class CheckpointedTransformerLayer(nn.Module): def init(self, hidden_size, num_attention_heads, intermediate_size): super().init() self.attention = SelfAttention(hidden_size, num_attention_heads) self.intermediate = nn.Linear(hidden_size, intermediate_size) self.output = nn.Linear(intermediate_size, hidden_size)
1 def forward(self, hidden_states, attention_mask=None): 2 def custom_forward(hidden_states, attention_mask): 3 attention_output = self.attention(hidden_states, attention_mask) 4 intermediate_output = self.intermediate(attention_output) 5 layer_output = self.output(intermediate_output) 6 return layer_output 7 8 # Use checkpointing to save memory 9 return checkpoint(custom_forward, hidden_states, attention_mask)"}
Mixed Precision Training
Using lower precision (FP16/BF16) for most operations while maintaining stability:
- FP16 vs BF16:
  - FP16: More mantissa bits (higher precision), but a limited exponent range
  - BF16: Same exponent range as FP32, less precision
  - BF16 is preferred for LLMs because of its range
- Numeric Stability Techniques (see the sketch after this list):
  - Loss scaling to prevent gradient underflow in FP16
  - Master weights kept in FP32
  - Selective operations run in higher precision
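A minimal sketch of these ideas with PyTorch's automatic mixed precision is shown below; it reuses the generic `model`, `criterion`, `optimizer`, and `dataloader` placeholders from earlier examples. With BF16, the `GradScaler` can be dropped, since BF16's range makes loss scaling unnecessary.

```python
import torch

scaler = torch.cuda.amp.GradScaler()  # dynamic loss scaling for FP16

for inputs, labels in dataloader:
    inputs, labels = inputs.cuda(), labels.cuda()
    optimizer.zero_grad()

    # Forward pass runs in half precision; the parameters (master weights) stay in FP32
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = criterion(model(inputs), labels)

    scaler.scale(loss).backward()  # scale the loss so small gradients don't underflow
    scaler.step(optimizer)         # unscales gradients and skips the step on inf/nan
    scaler.update()                # adjusts the scale factor dynamically
```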
Optimized Kernels and Operators
Specialized implementations for common operations:
- Flash Attention:
  - Memory-efficient attention implementation
  - IO-aware algorithm that minimizes HBM access
  - Roughly 2-3x speedup for attention computation (see the sketch after this list)
- Fused Operators:
  - Combine multiple operations into single kernels
  - Reduce memory traffic and kernel launch overhead
  - Examples: fused GELU, fused LayerNorm, fused AdamW
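In recent PyTorch versions you rarely call FlashAttention directly: `torch.nn.functional.scaled_dot_product_attention` dispatches to a FlashAttention-style fused kernel when shapes, dtypes, and hardware allow. The tensor shapes below are illustrative.

```python
import torch
import torch.nn.functional as F

# q, k, v: [batch, num_heads, seq_len, head_dim]
q = torch.randn(2, 12, 2048, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# The fused kernel computes attention block by block, so the full
# [seq_len, seq_len] score matrix is never materialized in HBM.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```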
Fault Tolerance and Elastic Training
Strategies for resilient training at scale:
- Checkpoint-based Recovery:
  - Save training state at regular intervals (a minimal sketch follows this list)
  - Resume training from the last checkpoint after a failure
  - Built-in support in DeepSpeed and FSDP
- Elastic Training:
  - Adapt to changing resource availability
  - Add or remove nodes during training
  - Particularly useful for preemptible instances
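Below is a minimal single-process sketch of checkpoint-based recovery; the function names are illustrative. At scale you would use the frameworks' own utilities instead (for example DeepSpeed's `model_engine.save_checkpoint`), which handle sharded optimizer and parameter state, and typically only rank 0 or each shard owner writes to disk.

```python
import os
import torch

def save_checkpoint(path, model, optimizer, epoch, step):
    """Persist everything needed to resume: weights, optimizer state, and progress."""
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "step": step,
    }, path)

def load_checkpoint(path, model, optimizer):
    """Return the (epoch, step) to resume from, or (0, 0) if no checkpoint exists."""
    if not os.path.exists(path):
        return 0, 0
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"], state["step"]
```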
Case Studies and Real-world Examples
Training GPT-3 175B
OpenAI's approach used a combination of techniques:
- Data Parallelism: 96-way across 96 machines
- Model Parallelism: 8-way tensor parallelism
- Custom Distributed Communication: Optimized for Azure
- Mixed Precision: FP16 with dynamic loss scaling
- Memory Optimization: Attention optimizations
Training LLaMA 65B
Meta's approach:
- Distributed Training: 2048 A100 GPUs
- NCCL Communication: Optimized for GPU clusters
- Efficient Architecture: Fewer non-linearities, no biases in dense layers
- Data Processing: Tokenizer with larger vocabulary to reduce sequence lengths
Training Mixtral 8x7B
Mistral AI's mixture-of-experts approach:
- Sparse MoE Architecture: Only activates specific experts per token
- Expert Parallelism: Distribute experts across GPUs
- Load Balancing: Ensure even utilization of experts
- Memory Optimization: Share parameters across experts where appropriate
Practical Exercises
Exercise 1: Setting Up FSDP
Implement a distributed training setup using PyTorch FSDP:
- Set up a multi-GPU environment
- Configure FSDP with the appropriate policies
- Implement gradient accumulation and mixed precision
- Benchmark performance and memory usage
Exercise 2: DeepSpeed Configuration
Create a DeepSpeed configuration for different scenarios:
- Single-node, multi-GPU training with ZeRO-3
- Multi-node training with pipeline parallelism
- Configure offloading for training with limited GPU memory
- Measure throughput and compare with baseline PyTorch implementation
Exercise 3: Distributed Training Diagnostics
Build a diagnostic tool to identify bottlenecks in distributed training:
- Monitor communication vs. computation time
- Analyze GPU memory usage patterns
- Detect load imbalances across GPUs
- Recommend optimizations based on collected metrics
Conclusion
Distributed training infrastructure is what makes today's largest and most capable language models possible. As models continue to grow, efficient distributed training becomes increasingly important for both research and production deployments.
The field is rapidly evolving, with new techniques and optimizations emerging regularly. Understanding the fundamentals of data, tensor, and pipeline parallelism provides a foundation for adopting these new approaches as they become available.
In our next lesson, we'll explore preference alignment techniques that help ensure these powerful large models behave according to human preferences and values.
Additional Resources
Papers
- "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism" (Shoeybi et al., 2019)
- "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models" (Rajbhandari et al., 2020)
- "Efficiently Scaling Transformer Inference" (Pope et al., 2022)
- "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (Dao et al., 2022)
Libraries and Tools
- DeepSpeed
- PyTorch FSDP
- Megatron-LM
- Alpa (Automatic Parallelism)