Distributed Training Infrastructure

Overview

In our previous lesson, we explored parameter-efficient fine-tuning techniques that enable working with large language models on limited hardware. However, as model sizes continue to grow beyond even what PEFT methods can handle on a single device, distributed training becomes essential. This lesson explores how to scale training across multiple GPUs, multiple machines, and even multiple data centers.

Distributed training is what makes training models with hundreds of billions of parameters possible. Understanding these techniques allows you to work with state-of-the-art models and contribute to pushing the boundaries of what's possible in natural language processing.

Learning Objectives

After completing this lesson, you will be able to:

  • Understand the fundamental challenges of distributed training for large language models
  • Identify and implement appropriate parallelism strategies based on model and hardware constraints
  • Set up and configure popular distributed training frameworks like DeepSpeed and PyTorch FSDP
  • Optimize distributed training performance through proper hyperparameter tuning
  • Implement effective monitoring and debugging strategies for distributed training
  • Design resilient training systems that can recover from failures

The Need for Distributed Training

Scale Drives Progress

The past few years have demonstrated a clear trend: larger models, trained on more data, perform better on a wide range of tasks. This scaling law presents a technical challenge—how do we train these enormous models efficiently?

Analogy: Distributed Training as Coordinated Construction

Think of training a large language model like constructing a massive skyscraper:

  • Single-device Training: One construction team trying to build the entire structure—impossible beyond a certain size
  • Distributed Training: Multiple specialized teams working on different sections simultaneously
  • Coordination Overhead: Teams need to communicate, synchronize, and integrate their work
  • Resource Planning: Different tasks require different equipment and expertise

Just as a skyscraper can only reach new heights through coordinated teamwork, today's largest language models can only be trained through sophisticated distributed systems.

The Fundamental Challenges


Memory Constraints

A fundamental challenge in training large language models is memory:

  1. Model Parameters: FP16 parameters require 2 bytes each

    • 1B parameters = 2GB
    • 100B parameters = 200GB
    • 1T parameters = 2TB
  2. Optimizer States: Optimizers like Adam require additional memory

    • Adam keeps FP32 momentum and variance, roughly 8 bytes per parameter (4x the FP16 model size)
    • 1B parameters ≈ 10GB for parameters plus optimizer states
    • 100B parameters ≈ 1TB for parameters plus optimizer states (see the quick calculator after this list)
  3. Activation Memory: Forward pass outputs needed for backpropagation

    • Scales with batch size and sequence length
    • Can often exceed parameter memory for large batches
  4. Gradient Accumulation: Splitting each batch into micro-batches reduces activation memory, at the cost of more sequential steps per optimizer update
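
To make these numbers concrete, here is a minimal back-of-the-envelope calculator. It is an illustrative sketch only: it counts parameters and optimizer states, and ignores gradients, activations, and framework overhead.

def estimate_memory_gb(num_params,
                       bytes_per_param=2,              # FP16 weights
                       optimizer_bytes_per_param=8):   # Adam momentum + variance
    """Rough static memory estimate in GB: parameters plus optimizer states."""
    param_gb = num_params * bytes_per_param / 1e9
    optim_gb = num_params * optimizer_bytes_per_param / 1e9
    return param_gb, optim_gb, param_gb + optim_gb

for n in (1e9, 100e9, 1e12):
    params, optim, total = estimate_memory_gb(n)
    print(f"{n / 1e9:>6.0f}B params: {params:,.0f} GB weights + {optim:,.0f} GB optimizer = {total:,.0f} GB")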

Parallelism Strategies: Dividing the Problem

Types of Parallelism

To overcome these challenges, we use multiple forms of parallelism:


Data Parallelism

In data parallelism, the entire model is replicated across devices, but each processes different batches of data:

  1. Each device maintains a complete copy of the model
  2. Each device processes different data samples
  3. Gradients are synchronized across devices
  4. Model weights are updated synchronously

Distributed Data Parallel (DDP)

The standard approach to data parallelism in PyTorch:

{"tool": "code-editor", "defaultValue": "import torch import torch.nn as nn import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data.distributed import DistributedSampler

def setup(rank, world_size): """Initialize distributed process group.""" dist.init_process_group(backend='nccl', init_method='tcp://localhost:12355', world_size=world_size, rank=rank)

def cleanup(): """Clean up distributed process group.""" dist.destroy_process_group()

class SimpleModel(nn.Module): def init(self): super().init() self.layers = nn.Sequential( nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768) )

1
def forward(self, x):
2
return self.layers(x)

def train(rank, world_size): # Initialize distributed training setup(rank, world_size)

1
# Create model and move to GPU
2
model = SimpleModel().to(rank)
3
4
# Wrap model with DDP
5
ddp_model = DDP(model, device_ids=[rank])
6
7
# Create optimizer
8
optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=5e-5)
9
10
# Create dataset and sampler
11
dataset = YourDataset() # Replace with your dataset
12
sampler = DistributedSampler(dataset,
13
num_replicas=world_size,
14
rank=rank)
15
dataloader = DataLoader(dataset,
16
batch_size=32,
17
sampler=sampler)
18
19
# Training loop
20
for epoch in range(num_epochs):
21
sampler.set_epoch(epoch) # Important for proper shuffling
22
for batch in dataloader:
23
inputs, labels = batch
24
inputs, labels = inputs.to(rank), labels.to(rank)
25
26
# Forward pass
27
outputs = ddp_model(inputs)
28
loss = criterion(outputs, labels)
29
30
# Backward pass and optimize
31
optimizer.zero_grad()
32
loss.backward()
33
optimizer.step()
34
35
cleanup()

Launch with: torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)"}

Zero Redundancy Optimizer (ZeRO)

ZeRO enhances data parallelism by partitioning optimizer states, gradients, and even parameters across GPUs:

  1. ZeRO Stage 1: Partitions optimizer states (momentum, variance)
  2. ZeRO Stage 2: Partitions gradients as well
  3. ZeRO Stage 3: Partitions model parameters too

This strategy dramatically reduces memory overhead while maintaining the simplicity of data parallelism.
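
Following the accounting used in the ZeRO paper (2 bytes of FP16 parameters, 2 bytes of FP16 gradients, and roughly 12 bytes of Adam state per parameter), a rough sketch of per-GPU memory for each stage looks like this. The helper below is illustrative only and ignores activations and communication buffers.

def zero_per_gpu_gb(num_params, num_gpus, stage=0,
                    param_bytes=2, grad_bytes=2, optim_bytes=12):
    """Approximate per-GPU memory (GB) under ZeRO stages 0-3."""
    p, g, o = param_bytes, grad_bytes, optim_bytes
    if stage >= 1:       # Stage 1: optimizer states are partitioned
        o /= num_gpus
    if stage >= 2:       # Stage 2: gradients are partitioned as well
        g /= num_gpus
    if stage >= 3:       # Stage 3: parameters are partitioned too
        p /= num_gpus
    return num_params * (p + g + o) / 1e9

# 10B-parameter model on 64 GPUs
for stage in range(4):
    print(f"ZeRO stage {stage}: {zero_per_gpu_gb(10e9, 64, stage):,.1f} GB per GPU")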


Model Parallelism

When a model is too large to fit on a single device, we can split it across multiple devices:

Tensor Parallelism (TP)

Tensor parallelism splits individual layers across devices:

  1. Divides matrix operations across devices
  2. Particularly effective for large matrix multiplications in attention layers
  3. Requires frequent communication between devices
  4. Typically limited to devices within the same node (due to communication overhead)
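
As a single-process illustration of the idea (plain matrix math, no real multi-GPU communication), splitting a linear layer column-wise means each device holds a slice of the weight matrix and computes a slice of the output, which is then gathered:

import torch

torch.manual_seed(0)
x = torch.randn(4, 768)        # (batch, hidden)
w = torch.randn(768, 3072)     # full weight of a feed-forward projection

# "Two devices": each holds half of the output columns
w_shards = torch.chunk(w, chunks=2, dim=1)

# Each shard computes its slice of the output independently
partial_outputs = [x @ shard for shard in w_shards]   # two (4, 1536) tensors

# The concat stands in for the all-gather step of real tensor parallelism
y_parallel = torch.cat(partial_outputs, dim=1)
assert torch.allclose(y_parallel, x @ w, atol=1e-5)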

Pipeline Parallelism (PP)

Pipeline parallelism divides the model sequentially into stages:

  1. Different layers run on different devices
  2. Data flows through the pipeline in micro-batches
  3. Reduces communication overhead compared to tensor parallelism
  4. Can span across multiple nodes, even with slower interconnects
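
A deliberately naive two-stage sketch of the mechanics (it assumes two visible GPUs; real pipeline schedules such as GPipe or 1F1B overlap micro-batches so both stages stay busy instead of running them one after another):

import torch
import torch.nn as nn

# Stage 0 lives on GPU 0, stage 1 on GPU 1
stage0 = nn.Sequential(nn.Linear(768, 3072), nn.GELU()).to("cuda:0")
stage1 = nn.Sequential(nn.Linear(3072, 768)).to("cuda:1")

batch = torch.randn(32, 768)
micro_batches = batch.chunk(4)            # split the batch into micro-batches

outputs = []
for mb in micro_batches:
    hidden = stage0(mb.to("cuda:0"))      # compute on device 0
    out = stage1(hidden.to("cuda:1"))     # activation crosses to device 1
    outputs.append(out)

result = torch.cat(outputs)               # (32, 768), lives on cuda:1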

3D Parallelism

Modern training frameworks combine all three forms of parallelism:

  • Data Parallelism: Process different batches in parallel
  • Tensor Parallelism: Split individual layers across devices
  • Pipeline Parallelism: Split model vertically across devices

This 3D approach allows for tremendous flexibility in scaling models across different hardware configurations.
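
The three degrees multiply to cover the cluster, so choosing them is largely an exercise in factorization; a quick sanity check of the arithmetic (the specific split below is hypothetical, not a recommendation):

# Hypothetical layout for a 512-GPU cluster
tensor_parallel = 8     # within a node, over NVLink
pipeline_parallel = 8   # across nodes
data_parallel = 8       # remaining replicas

world_size = tensor_parallel * pipeline_parallel * data_parallel
assert world_size == 512  # every GPU maps to exactly one (dp, pp, tp) coordinate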

Modern Distributed Training Frameworks

PyTorch Fully Sharded Data Parallel (FSDP)

PyTorch's native implementation of ZeRO-like sharding:

  1. Key Features:

    • Parameter sharding to minimize memory usage
    • Efficient communication patterns
    • Integration with PyTorch ecosystem
    • Supports mixed precision training
  2. Implementation Approach:

    • Shards parameters, gradients, and optimizer states
    • Reshards during forward and backward passes
    • Uses efficient all-gather and reduce-scatter operations

{"tool": "code-editor", "defaultValue": "import torch import torch.nn as nn import torch.distributed as dist from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import ( CPUOffload, BackwardPrefetch, ) from torch.distributed.fsdp.wrap import ( size_based_auto_wrap_policy, enable_wrap, wrap, )

Initialize distributed environment

dist.init_process_group(backend="nccl") local_rank = int(os.environ["LOCAL_RANK"]) torch.cuda.set_device(local_rank)

Define your model

model = YourLargeTransformerModel()

Configure FSDP wrapping policy

auto_wrap_policy = size_based_auto_wrap_policy( min_num_params=100_000, # Layers with >100K params will be wrapped separately excluded_wrap_modules={nn.Embedding} # Don't wrap embedding layers separately )

Wrap your model with FSDP

model = FSDP( model, auto_wrap_policy=auto_wrap_policy, device_id=torch.cuda.current_device(), cpu_offload=CPUOffload(offload_params=True), # Offload parameters to CPU when not in use backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Prefetch next layer's params during backward pass mixed_precision=True, # Enable mixed precision training )

Standard training loop

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

for epoch in range(num_epochs): for batch in dataloader: inputs, labels = batch inputs, labels = inputs.cuda(), labels.cuda()

1
outputs = model(inputs)
2
loss = criterion(outputs, labels)
3
4
optimizer.zero_grad()
5
loss.backward()
6
optimizer.step()"}

DeepSpeed

Microsoft's comprehensive library for efficient distributed training:

  1. Key Features:

    • ZeRO optimizer for memory efficiency
    • 3D parallelism support
    • 1-bit Adam/LAMB for communication efficiency
    • Curriculum learning and other training optimizations
    • Inference optimizations
  2. ZeRO-Infinity:

    • Extends memory capacity to NVMe storage
    • Allows training models that exceed total GPU memory
    • Smart offloading of tensors to CPU and NVMe

{"tool": "code-editor", "defaultValue": "# DeepSpeed configuration ds_config = { "train_batch_size": 32 * torch.cuda.device_count(), "gradient_accumulation_steps": 4, "optimizer": { "type": "Adam", "params": { "lr": 1e-5, "weight_decay": 0.01, "bias_correction": True } }, "fp16": { "enabled": True, "loss_scale": 0, "loss_scale_window": 1000, "initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "zero_optimization": { "stage": 3, "offload_optimizer": { "device": "cpu", "pin_memory": True }, "offload_param": { "device": "cpu", "pin_memory": True }, "overlap_comm": True, "contiguous_gradients": True, "reduce_bucket_size": 5e8, "stage3_prefetch_bucket_size": 5e8, "stage3_param_persistence_threshold": 1e6 }, "gradient_clipping": 1.0, "steps_per_print": 50, "wall_clock_breakdown": False }

Import and initialize DeepSpeed

import deepspeed from transformers import AutoModelForCausalLM, AutoTokenizer from deepspeed.ops.adam import FusedAdam

Initialize model and tokenizer

model = AutoModelForCausalLM.from_pretrained('gpt2-large') tokenizer = AutoTokenizer.from_pretrained('gpt2-large')

Prepare for DeepSpeed

parameters = filter(lambda p: p.requires_grad, model.parameters()) model_engine, optimizer, _, _ = deepspeed.initialize( model=model, model_parameters=parameters, config=ds_config )

Training loop

for epoch in range(num_epochs): for batch in train_dataloader: inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True).to(model_engine.device) outputs = model_engine(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], labels=inputs['input_ids']) loss = outputs.loss

1
# DeepSpeed handles loss scaling and backward
2
model_engine.backward(loss)
3
model_engine.step()"}

Comparison of Frameworks

Rough positioning of the major frameworks, scored from 0 (low) to 10 (high) on ease of use and feature richness:

  • PyTorch DDP: ease of use 8, feature richness 4 (simple but limited features)
  • PyTorch FSDP: ease of use 6, feature richness 7 (good balance, native to PyTorch)
  • DeepSpeed: ease of use 5, feature richness 9 (comprehensive but complex)
  • Megatron-LM: ease of use 3, feature richness 8 (specialized for very large models)

Practical Implementation

Hardware Considerations

Choosing the right hardware configuration is crucial:

  1. GPU Interconnect:

    • NVLink (600GB/s): Best for tensor parallelism
    • PCIe 4.0 (32GB/s): Sufficient for pipeline parallelism
    • InfiniBand (200Gb/s, roughly 25GB/s): Good for cross-node communication
  2. Memory Hierarchy:

    • HBM (GPU Memory): Fastest, most limited
    • CPU RAM: 10x slower, much larger
    • NVMe: 100x slower, immense capacity
  3. GPU Selection:

    • A100s/H100s: Best for large-scale training
    • Consumer GPUs: Viable with memory-efficient techniques

Optimizing Communication

Communication often becomes the bottleneck in distributed training:

  1. Gradient Accumulation:

    • Reduces communication frequency
    • Trades off training time for efficiency
  2. Gradient Compression:

    • 1-bit Adam: Quantize gradient updates
    • PowerSGD: Low-rank approximation of gradients
  3. Optimized Collectives:

    • NCCL: Optimized for GPU communication
    • AllReduce, AllGather, ReduceScatter operations
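
For example, DDP's no_sync() context manager skips the gradient all-reduce on intermediate accumulation steps, so communication happens only once per effective batch. This sketch assumes ddp_model, dataloader, criterion, optimizer, and accumulation_steps are already defined, as in the earlier DDP example.

import contextlib

for step, (inputs, labels) in enumerate(dataloader):
    is_sync_step = (step + 1) % accumulation_steps == 0
    context = contextlib.nullcontext() if is_sync_step else ddp_model.no_sync()

    with context:
        outputs = ddp_model(inputs)
        loss = criterion(outputs, labels) / accumulation_steps
        loss.backward()   # gradients are all-reduced only when is_sync_step is True

    if is_sync_step:
        optimizer.step()
        optimizer.zero_grad()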

Monitoring Distributed Training

{"tool": "code-editor", "defaultValue": "import torch import time import wandb from torch.utils.tensorboard import SummaryWriter from torch.distributed import get_rank

class DistributedTrainingMonitor: def init(self, model, log_dir, project_name, config=None): self.start_time = time.time() self.rank = get_rank() if torch.distributed.is_initialized() else 0 self.is_main_process = self.rank == 0

1
# Only initialize loggers on the main process
2
if self.is_main_process:
3
self.tb_writer = SummaryWriter(log_dir)
4
wandb.init(project=project_name, config=config)
5
6
# For monitoring GPU metrics
7
self.model = model
8
self.step_times = []
9
self.comm_times = []
10
self.forward_times = []
11
self.backward_times = []
12
13
def log_step(self, loss, lr, step, comm_time=None, forward_time=None, backward_time=None):
14
# Basic metrics
15
metrics = {
16
'loss': loss,
17
'learning_rate': lr,
18
'step': step,
19
'samples_per_second': self.get_throughput(),
20
}
21
22
# Advanced timing metrics
23
if comm_time:
24
self.comm_times.append(comm_time)
25
metrics['communication_time'] = comm_time
26
metrics['communication_time_avg'] = sum(self.comm_times[-100:]) / len(self.comm_times[-100:])
27
28
if forward_time:
29
self.forward_times.append(forward_time)
30
metrics['forward_time'] = forward_time
31
32
if backward_time:
33
self.backward_times.append(backward_time)
34
metrics['backward_time'] = backward_time
35
36
# GPU memory metrics
37
metrics.update(self.get_gpu_metrics())
38
39
# Log only from main process
40
if self.is_main_process:
41
wandb.log(metrics)
42
for key, value in metrics.items():
43
self.tb_writer.add_scalar(key, value, step)
44
45
def get_gpu_metrics(self):
46
metrics = {}
47
48
# Current GPU memory usage
49
metrics['gpu_memory_allocated'] = torch.cuda.memory_allocated() / 1e9 # GB
50
metrics['gpu_memory_reserved'] = torch.cuda.memory_reserved() / 1e9 # GB
51
52
# Peak memory stats
53
metrics['gpu_max_memory_allocated'] = torch.cuda.max_memory_allocated() / 1e9 # GB
54
55
# If using FSDP or DeepSpeed, add model-specific memory metrics
56
if hasattr(self.model, 'get_memory_footprint'):
57
metrics['model_memory_footprint'] = self.model.get_memory_footprint() / 1e9 # GB
58
59
return metrics
60
61
def get_throughput(self):
62
elapsed = time.time() - self.start_time
63
if not hasattr(self, 'total_samples'):
64
return 0
65
return self.total_samples / elapsed
66
67
def update_samples(self, batch_size):
68
if not hasattr(self, 'total_samples'):
69
self.total_samples = 0
70
self.total_samples += batch_size * torch.distributed.get_world_size()
71
72
def close(self):
73
if self.is_main_process:
74
self.tb_writer.close()
75
wandb.finish()"}

Key Metrics to Monitor

  1. Training Efficiency:

    • Samples per second (throughput)
    • GPU utilization
    • Communication overhead
    • Load imbalance
  2. Memory Usage:

    • GPU memory allocated
    • Memory breakdown (parameters, activations, gradients)
    • Out-of-memory events
    • Memory fragmentation
  3. Scaling Efficiency:

    • Linear scaling factor
    • Communication-to-computation ratio
    • Device idle time
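
Linear scaling factor, for instance, is usually computed as measured throughput relative to perfect scaling from a single-device baseline. The numbers in the example below are made up for illustration.

def scaling_efficiency(throughput_n_gpus, throughput_1_gpu, n_gpus):
    """Fraction of ideal linear speedup actually achieved."""
    return throughput_n_gpus / (n_gpus * throughput_1_gpu)

# e.g. 1 GPU: 20 samples/s; 64 GPUs: 1,100 samples/s measured
print(f"{scaling_efficiency(1100, 20, 64):.0%} of linear scaling")   # ~86%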

Advanced Techniques and Optimizations

Activation Checkpointing

Trades computation for memory by recomputing activations during backward pass:

  1. How it Works:

    • Save activations only at specific checkpoints
    • Recompute intermediate activations during backward pass
    • Reduces memory at the cost of additional computation
  2. Implementation:

    • Available in PyTorch, DeepSpeed, FSDP
    • Often applied to transformer blocks

{"tool": "code-editor", "defaultValue": "import torch from torch.utils.checkpoint import checkpoint

class CheckpointedTransformerLayer(nn.Module): def init(self, hidden_size, num_attention_heads, intermediate_size): super().init() self.attention = SelfAttention(hidden_size, num_attention_heads) self.intermediate = nn.Linear(hidden_size, intermediate_size) self.output = nn.Linear(intermediate_size, hidden_size)

1
def forward(self, hidden_states, attention_mask=None):
2
def custom_forward(hidden_states, attention_mask):
3
attention_output = self.attention(hidden_states, attention_mask)
4
intermediate_output = self.intermediate(attention_output)
5
layer_output = self.output(intermediate_output)
6
return layer_output
7
8
# Use checkpointing to save memory
9
return checkpoint(custom_forward, hidden_states, attention_mask)"}

Mixed Precision Training

Using lower precision (FP16/BF16) for most operations while maintaining stability:

  1. FP16 vs BF16:

    • FP16: More precision near zero, limited range
    • BF16: Same range as FP32, less precision
    • BF16 preferred for LLMs due to range requirements
  2. Numeric Stability Techniques:

    • Loss scaling to prevent gradient underflow
    • Master weights in FP32
    • Selective operations in higher precision
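
In PyTorch this is typically written with autocast plus a gradient scaler for FP16 (BF16 generally does not need loss scaling). A minimal sketch, assuming model, dataloader, criterion, and optimizer already exist:

import torch

scaler = torch.cuda.amp.GradScaler()   # only needed for FP16, not BF16

for inputs, labels in dataloader:
    inputs, labels = inputs.cuda(), labels.cuda()
    optimizer.zero_grad()

    # Run the forward pass in half precision where safe
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs)
        loss = criterion(outputs, labels)

    scaler.scale(loss).backward()   # scale the loss to avoid gradient underflow
    scaler.step(optimizer)          # unscales gradients, then steps
    scaler.update()                 # adjusts the loss scale dynamically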

Optimized Kernels and Operators

Specialized implementations for common operations:

  1. Flash Attention:

    • Memory-efficient attention implementation
    • IO-aware algorithm that minimizes HBM access
    • 2-3x speedup for attention computation
  2. Fused Operators:

    • Combine multiple operations into single kernels
    • Reduce memory traffic and kernel launch overhead
    • Examples: fused GELU, fused LayerNorm, fused AdamW
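
In PyTorch 2.0 and later, torch.nn.functional.scaled_dot_product_attention dispatches to a FlashAttention-style fused kernel when the hardware, dtypes, and mask allow it, so you get the memory savings without writing custom kernels:

import torch
import torch.nn.functional as F

# (batch, heads, sequence length, head dimension)
q = torch.randn(2, 12, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 12, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 12, 1024, 64, device="cuda", dtype=torch.float16)

# Avoids materializing the full (seq_len x seq_len) attention matrix
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)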

Fault Tolerance and Elastic Training

Strategies for resilient training at scale:

  1. Checkpoint-based Recovery:

    • Regular state saving
    • Resume training from last checkpoint
    • Built-in checkpoint save/load support in DeepSpeed and FSDP
  2. Elastic Training:

    • Adapt to changing resource availability
    • Add or remove nodes during training
    • Particularly useful for preemptible instances
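
A minimal save-and-resume pattern is sketched below; the path and the fields stored are illustrative, and at scale the sharded checkpoint APIs provided by FSDP or DeepSpeed should be preferred over a single torch.save call.

import os
import torch

CKPT_PATH = "checkpoints/latest.pt"   # hypothetical location

def save_checkpoint(model, optimizer, step):
    os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step}, CKPT_PATH)

def load_checkpoint(model, optimizer):
    if not os.path.exists(CKPT_PATH):
        return 0                      # no checkpoint yet: start from scratch
    state = torch.load(CKPT_PATH, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["step"]              # resume from the saved step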

Case Studies and Real-world Examples

Training GPT-3 175B

OpenAI's approach used a combination of techniques:

  1. Data Parallelism: 96-way across 96 machines
  2. Model Parallelism: 8-way tensor parallelism
  3. Custom Distributed Communication: Optimized for Azure
  4. Mixed Precision: FP16 with dynamic loss scaling
  5. Memory Optimization: Attention optimizations

Training LLaMA 65B

Meta's approach:

  1. Distributed Training: 2048 A100 GPUs
  2. NCCL Communication: Optimized for GPU clusters
  3. Efficient Architecture: Fewer non-linearities, no biases in dense layers
  4. Data Processing: Tokenizer with larger vocabulary to reduce sequence lengths

Training Mixtral 8x7B

Mistral AI's mixture-of-experts approach:

  1. Sparse MoE Architecture: Only activates specific experts per token
  2. Expert Parallelism: Distribute experts across GPUs
  3. Load Balancing: Ensure even utilization of experts
  4. Memory Optimization: Share parameters across experts where appropriate

Practical Exercises

Exercise 1: Setting Up FSDP

Implement a distributed training setup using PyTorch FSDP:

  1. Set up a multi-GPU environment
  2. Configure FSDP with the appropriate policies
  3. Implement gradient accumulation and mixed precision
  4. Benchmark performance and memory usage

Exercise 2: DeepSpeed Configuration

Create a DeepSpeed configuration for different scenarios:

  1. Single-node, multi-GPU training with ZeRO-3
  2. Multi-node training with pipeline parallelism
  3. Configure offloading for training with limited GPU memory
  4. Measure throughput and compare with baseline PyTorch implementation

Exercise 3: Distributed Training Diagnostics

Build a diagnostic tool to identify bottlenecks in distributed training:

  1. Monitor communication vs. computation time
  2. Analyze GPU memory usage patterns
  3. Detect load imbalances across GPUs
  4. Recommend optimizations based on collected metrics

Conclusion

Distributed training infrastructure is what makes today's largest and most capable language models possible. As models continue to grow, efficient distributed training becomes increasingly important for both research and production deployments.

The field is rapidly evolving, with new techniques and optimizations emerging regularly. Understanding the fundamentals of data, tensor, and pipeline parallelism provides a foundation for adopting these new approaches as they become available.

In our next lesson, we'll explore preference alignment techniques that help ensure these powerful large models behave according to human preferences and values.

Additional Resources

Papers

  • "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism" (Shoeybi et al., 2019)
  • "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models" (Rajbhandari et al., 2020)
  • "Efficiently Scaling Transformer Inference" (Pope et al., 2022)
  • "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (Dao et al., 2022)

Libraries and Tools

  • DeepSpeed (Microsoft)
  • PyTorch DistributedDataParallel and FSDP
  • Megatron-LM (NVIDIA)
  • NCCL for GPU collective communication
  • Weights & Biases and TensorBoard for experiment monitoring

Blogs and Tutorials