
DDP vs. FSDP in PyTorch: Unlocking Efficient Multi-GPU Training

As deep learning models scale to billions of parameters, standard data parallel techniques like DistributedDataParallel (DDP) often hit GPU memory limits. Enter Fully Sharded Data Parallel (FSDP) — a powerful tool in PyTorch that enables efficient large-scale training by sharding model parameters, gradients, and optimizer states across GPUs.

Here’s how DDP and FSDP each handle multi-GPU training:

DDP (Distributed Data Parallel)

  • Each GPU gets a full copy of the model.
  • Each GPU processes a different batch of the data.
  • After computing gradients, the GPUs synchronize them with an All-Reduce and then update the weights (a minimal sketch of this step follows the list).
  • Problem: Every GPU holds the entire model, which wastes memory, especially for huge models.
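
Under the hood, that All-Reduce step boils down to summing each gradient tensor across ranks and averaging it. The sketch below only illustrates the idea (DDP actually buckets gradients and overlaps this communication with the backward pass); it assumes the default process group is already initialized, e.g. via torchrun.

import torch
import torch.distributed as dist

def average_gradients(model: torch.nn.Module) -> None:
    # Illustrative version of the gradient sync DDP performs after backward().
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # Sum this gradient across all ranks, then divide to get the mean.
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size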

FSDP (Fully Sharded Data Parallel)

  • Instead of copying the full model to each GPU, FSDP splits the model weights across GPUs.
  • Each GPU only stores part of the model, part of the gradients, and part of the optimizer states.
  • During training, GPUs coordinate to compute forward/backward passes and share data when needed.
  • Result: Less memory usage per GPU → You can train bigger models or use bigger batch sizes!

FSDP Internals

Key Idea: Shard model parameters, gradients, and optimizer states across GPUs.

Training Workflow:

  1. Sharding (Preprocessing):
    • Model parameters are sharded (partitioned) across ranks before training begins.
    • E.g., with 4 GPUs, each GPU permanently stores roughly one quarter of every wrapped layer's weights.
  2. Forward Pass:
    • Before a layer runs, each rank (GPU) holds only its own shard of that layer's parameters.
    • The ranks gather the full weights just-in-time using All-Gather.
    • Compute the forward pass, then discard the gathered (non-local) parameters to free memory.
  3. Backward Pass:
    • Gradients are sharded across GPUs using Reduce-Scatter.
    • Each GPU keeps a portion of the gradients for only the weights it owns.
  4. Optimizer Step:
    • Optimizer states (e.g., momentum in Adam) are also sharded.
    • Each GPU updates only the part of the model it owns.
  5. Communication Optimization:
    • FSDP overlaps communication and computation to hide communication delays.
    • Uses efficient collectives (All-Gather, Reduce-Scatter) under the hood (a conceptual sketch follows this list).
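
The workflow above can be sketched very roughly as follows. This is not FSDP's actual implementation (which operates on flattened, padded parameter shards and overlaps communication with compute); the helper names are hypothetical and only show where All-Gather and Reduce-Scatter fit in. It assumes the default process group is initialized and that each rank holds a 1/world_size slice of the layer's flattened weight, padded so it splits evenly.

import torch
import torch.distributed as dist

def sharded_linear_forward(x, weight_shard, out_features, in_features):
    # All-Gather: reassemble the full weight just-in-time for the computation.
    world_size = dist.get_world_size()
    full_flat = torch.empty(weight_shard.numel() * world_size,
                            dtype=weight_shard.dtype, device=weight_shard.device)
    dist.all_gather_into_tensor(full_flat, weight_shard)
    weight = full_flat[: out_features * in_features].view(out_features, in_features)
    out = x @ weight.t()
    # Drop the gathered copy; only the local shard stays resident.
    del full_flat, weight
    return out

def shard_gradient(full_grad_flat):
    # Reduce-Scatter: sum gradients across ranks, keep only this rank's slice.
    world_size = dist.get_world_size()
    grad_shard = torch.empty(full_grad_flat.numel() // world_size,
                             dtype=full_grad_flat.dtype, device=full_grad_flat.device)
    dist.reduce_scatter_tensor(grad_shard, full_grad_flat, op=dist.ReduceOp.SUM)
    return grad_shard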

Configuration Options:

  • mixed_precision: Use bf16 or fp16 to save memory.
  • sharding_strategy: Choose between FULL_SHARD, SHARD_GRAD_OP, HYBRID_SHARD, or NO_SHARD.
  • auto_wrap_policy: Automatically wrap transformer blocks (or size-based units) for sharding; see the example after this list.
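
As a sketch of how these options fit together (assuming PyTorch 2.x, where MixedPrecision and ShardingStrategy live in torch.distributed.fsdp, and that model is an nn.Module already moved to this rank's GPU):

import functools
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Keep parameters, gradient reductions, and buffers in bf16 to save memory.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = FSDP(
    model,  # an nn.Module already on this rank's GPU
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # shard params, grads, and optimizer states
    mixed_precision=bf16_policy,
    auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=20000),
)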

Comparison Summary

Feature              | DDP                    | FSDP
Model Replication    | Full model on each GPU | Model is sharded
Memory Efficiency    | Low                    | High
Communication Volume | Lower                  | Higher (but optimized)
Ideal Use Case       | Small to medium models | Large models (7B, 13B, etc.)

So, in DistributedDataParallel (DDP) training, each process/worker owns a replica of the model and processes a batch of data; it then uses All-Reduce to sum gradients across workers. In DDP, the model weights and optimizer states are replicated on every worker. FSDP is a type of data parallelism that shards model parameters, optimizer states, and gradients across DDP ranks.

When training with FSDP, the GPU memory footprint on every worker is smaller than with DDP. This makes training some very large models feasible, by allowing larger models or batch sizes to fit on each device; a back-of-the-envelope comparison follows.
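
As a rough illustration (fp32 weights, Adam, activations ignored, and an assumed 7B-parameter model on 8 GPUs):

# Per-GPU memory for parameters + gradients + Adam states (fp32, activations ignored).
params = 7e9                 # assumed 7B-parameter model
bytes_per_param = 4 + 4 + 8  # fp32 weight + fp32 grad + two fp32 Adam moments

ddp_per_gpu = params * bytes_per_param   # ~112 GB replicated on EVERY GPU
fsdp_per_gpu = ddp_per_gpu / 8           # ~14 GB per GPU with 8-way full sharding
                                         # (plus activations and the temporarily
                                         #  gathered parameters of the current layer)
print(f"DDP: {ddp_per_gpu / 1e9:.0f} GB, FSDP on 8 GPUs: {fsdp_per_gpu / 1e9:.0f} GB")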

Let’s dive into the code:

  • Importing necessary packages related to FSDP:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
  • Define auto_wrap_policy

This tells FSDP to wrap only sufficiently large submodules (e.g., big Linear layers or transformer blocks) into their own sharded units.

import functools

my_auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy, min_num_params=20000  # You can tune this value
)
  • Set CUDA Device for this Rank
torch.cuda.set_device(rank)
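
The snippets in this walkthrough assume the process group is already initialized and that rank is this process's GPU index. A typical setup helper (the localhost address and port are placeholders to adapt to your cluster) looks like:

import os
import torch.distributed as dist

def setup(rank, world_size):
    # Rendezvous information; adjust for your environment.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # NCCL is the usual backend for multi-GPU training.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)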
  • Move Model to GPU
model = Net().to(rank)
  • Wrap Model with FSDP (with or without auto_wrap_policy)

🔹 Without auto wrapping:

model = FSDP(model)

🔹 With auto_wrap_policy:

model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
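
For transformer models you can also wrap by layer class rather than parameter count. A sketch, where TransformerBlock is a placeholder for your model's real layer class (e.g. a GPT or Llama decoder layer):

import functools
import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

class TransformerBlock(nn.Module):
    # Placeholder for your model's actual transformer layer class.
    ...

transformer_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)

model = FSDP(model, auto_wrap_policy=transformer_wrap_policy)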
  • Save Model (only rank 0)
if args.save_model:
    dist.barrier()  # Ensure all ranks finish training
    states = model.state_dict()
    if rank == 0:
        torch.save(states, "model.pt")
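
On recent PyTorch versions you can also be explicit that the full (unsharded) state dict should be gathered on CPU and materialized only on rank 0, so the checkpoint never has to fit on every GPU at once. A sketch:

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

if args.save_model:
    dist.barrier()  # Ensure all ranks finish training
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        states = model.state_dict()  # full state dict exists only on rank 0
    if rank == 0:
        torch.save(states, "model.pt")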
  • FSDP Cleanup
dist.destroy_process_group()
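
Putting the pieces together, here is a minimal self-contained sketch of an FSDP training run. The Net architecture, random data, and hyperparameters are placeholders; swap in your own model and dataloader.

import functools
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

class Net(nn.Module):
    # Toy model; layer sizes are arbitrary placeholders.
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(1024, 4096), nn.ReLU(),
            nn.Linear(4096, 4096), nn.ReLU(),
            nn.Linear(4096, 10),
        )

    def forward(self, x):
        return self.layers(x)

def train(rank, world_size):
    # 1. Process-group setup (adjust the rendezvous info for your cluster).
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # 2. Build the model on this rank's GPU and wrap it with FSDP.
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=20000
    )
    model = FSDP(Net().to(rank), auto_wrap_policy=my_auto_wrap_policy)

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    # 3. Training loop on random data; each rank works on its own batch.
    for step in range(10):
        x = torch.randn(32, 1024, device=rank)
        y = torch.randint(0, 10, (32,), device=rank)
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()   # gradients are reduce-scattered across ranks
        optimizer.step()  # each rank updates only the shard it owns
        if rank == 0:
            print(f"step {step}: loss {loss.item():.4f}")

    # 4. Save on rank 0, then clean up.
    dist.barrier()
    states = model.state_dict()
    if rank == 0:
        torch.save(states, "model.pt")
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)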

DDP Version:

from torch.nn.parallel import DistributedDataParallel as DDP

model = Net().to(rank)
model = DDP(model, device_ids=[rank])

Final Verdict

Feature           | DDP             | FSDP + Auto Wrap
Setup Simplicity  | Easier          | Slightly more setup
Memory Efficiency | Poor            | Excellent
Big Model Support | No              | Yes (7B+ models)
Training Speed    | Slightly faster | Slightly slower (but more scalable)