Multi-GPU Implementation for PyTorch on Olivia

This is part 2 of the PyTorch on Olivia guide. See PyTorch on Olivia for the single-GPU setup.

To scale training across multiple GPUs, we use PyTorch's DistributedDataParallel (DDP). DDP runs one process per GPU and averages gradients across processes during the backward pass so the model replicas stay in sync, while a DistributedSampler gives each process its own shard of the training data. The code below works unchanged for both single-node multi-GPU and multi-node configurations.

  1# train_ddp.py
  2import os
  3import time
  4import argparse
  5import torch
  6import torch.nn as nn
  7import torchvision
  8from torch.utils.data.distributed import DistributedSampler
  9from torch.nn.parallel import DistributedDataParallel as DDP
 10from torch.distributed import init_process_group, destroy_process_group
 11from dataset_utils import load_cifar100
 12from model import WideResNet
 13from train_utils import test
 14
 15# Configuration
 16DATA_DIR = "./datasets"  # Dataset downloads automatically here
 17
 18# Parse input arguments
 19parser = argparse.ArgumentParser(description='CIFAR-100 DDP example with Mixed Precision',
 20                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 21parser.add_argument('--batch-size', type=int, default=512, help='Input batch size for training')
 22parser.add_argument('--epochs', type=int, default=5, help='Number of epochs to train')
 23parser.add_argument('--base-lr', type=float, default=0.01, help='Learning rate for single GPU')
 24parser.add_argument('--target-accuracy', type=float, default=0.85, help='Target accuracy to stop training')
 25parser.add_argument('--patience', type=int, default=2, help='Number of epochs that meet target before stopping')
 26args = parser.parse_args()
 27
 28def ddp_setup():
 29    """Set up the distributed environment."""
 30    init_process_group(backend="nccl")
 31    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
 32
 33def main_worker():
 34    ddp_setup()
 35
 36   # Get the local rank and device
 37    local_rank = int(os.environ["LOCAL_RANK"])
 38    global_rank = int(os.environ["RANK"])
 39    world_size = int(os.environ["WORLD_SIZE"])
 40    device = torch.device(f"cuda:{local_rank}")  # Note: we don't use get_device from device_utils here
 41
 42   # Log initialization info
 43    if global_rank == 0:
 44        print(f"Training started with {world_size} processes across {world_size // torch.cuda.device_count()} nodes.")
 45        print(f"Using {torch.cuda.device_count()} GPUs per node.")
 46
 47    # Load the CIFAR-100 dataset with DistributedSampler
 48    per_gpu_batch_size = args.batch_size // world_size  # Divide global batch size across GPUs
 49    train_sampler = DistributedSampler(
 50        torchvision.datasets.CIFAR100(
 51            root=DATA_DIR,
 52            train=True,
 53            download=True
 54        )
 55    )
 56    train_loader, test_loader = load_cifar100(
 57        batch_size=per_gpu_batch_size,
 58        num_workers=8,
 59        sampler=train_sampler
 60    )
 61
 62    # Create the model and wrap it with DDP
 63    num_classes = 100  # CIFAR-100 has 100 classes
 64    model = WideResNet(num_classes).to(device)
 65    model = DDP(model, device_ids=[local_rank])
 66
 67    # Define loss function and optimizer
 68    loss_fn = nn.CrossEntropyLoss()
 69    optimizer = torch.optim.SGD(model.parameters(), lr=args.base_lr, momentum=0.9, weight_decay=5e-4)
 70
 71    # Initialize gradient scaler for mixed precision
 72    scaler = torch.amp.GradScaler('cuda')
 73    val_accuracy = []
 74    total_time = 0
 75    total_images = 0  # Total images processed globally
 76
 77    # Training loop
 78    for epoch in range(args.epochs):
 79        train_sampler.set_epoch(epoch)  # Set the sampler epoch for shuffling
 80        model.train()
 81        t0 = time.time()
 82
 83       # Train the model for one epoch
 84        for images, labels in train_loader:
 85            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)
 86
 87           # Zero the gradients
 88            optimizer.zero_grad()
 89
 90           # Forward pass with mixed precision
 91            with torch.amp.autocast('cuda'):
 92                outputs = model(images)
 93                loss = loss_fn(outputs, labels)
 94
 95           # Backward pass and optimization with scaled gradients
 96            scaler.scale(loss).backward()
 97            scaler.step(optimizer)
 98            scaler.update()
 99
100        # Synchronize all processes
101        torch.distributed.barrier()
102        epoch_time = time.time() - t0
103        total_time += epoch_time
104
105        # Compute throughput (images per second for this epoch)
106        images_per_sec = len(train_loader) * args.batch_size / epoch_time
107        total_images += len(train_loader) * args.batch_size
108
109        # Compute validation accuracy and loss
110        v_accuracy, v_loss = test(model, test_loader, loss_fn, device)
111
112        # Average validation metrics across all GPUs
113        v_accuracy_tensor = torch.tensor(v_accuracy).to(device)
114        v_loss_tensor = torch.tensor(v_loss).to(device)
115        torch.distributed.all_reduce(v_accuracy_tensor, op=torch.distributed.ReduceOp.AVG)
116        torch.distributed.all_reduce(v_loss_tensor, op=torch.distributed.ReduceOp.AVG)
117
118        # Print metrics only from the main process
119        if global_rank == 0:
120            print(f"Epoch {epoch + 1}/{args.epochs} completed in {epoch_time:.3f} seconds")
121            print(f"Validation Loss: {v_loss_tensor.item():.4f}, Validation Accuracy: {v_accuracy_tensor.item():.4f}")
122            print(f"Epoch Throughput: {images_per_sec:.3f} images/second")
123
124        # Early stopping
125        val_accuracy.append(v_accuracy_tensor.item())
126        if len(val_accuracy) >= args.patience and all(acc >= args.target_accuracy for acc in val_accuracy[-args.patience:]):
127            if global_rank == 0:
128                print(f"Target accuracy reached. Early stopping after epoch {epoch + 1}.")
129            break
130
131    # Log total training time and summary
132    if global_rank == 0:
133        throughput = total_images / total_time
134        print("\nTraining Summary:")
135        print(f"Total training time: {total_time:.3f} seconds")
136        print(f"Throughput: {throughput:.3f} images/second")
137        print(f"Number of nodes: {world_size // torch.cuda.device_count()}")
138        print(f"Number of GPUs per node: {torch.cuda.device_count()}")
139        print(f"Total GPUs used: {world_size}")
140        print("Training completed successfully.")
141
142    # Clean up the distributed environment
143    destroy_process_group()
144if __name__ == '__main__':
145    main_worker()
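
The script relies on the load_cifar100 helper from dataset_utils, introduced in part 1 and not reproduced here. The one requirement DDP adds is that the helper accepts a sampler and hands it to the training DataLoader with shuffling disabled. The following is a minimal sketch based only on how train_ddp.py calls it; the transforms and defaults are assumptions and the real helper may differ:

# dataset_utils.py -- sketch only; the actual helper from part 1 may differ
import torch
import torchvision
import torchvision.transforms as transforms

def load_cifar100(batch_size, num_workers=8, sampler=None, data_dir="./datasets"):
    """Return (train_loader, test_loader) for CIFAR-100."""
    transform = transforms.ToTensor()  # normalization/augmentation omitted for brevity
    train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True,
                                              download=True, transform=transform)
    test_set = torchvision.datasets.CIFAR100(root=data_dir, train=False,
                                             download=True, transform=transform)
    # A DistributedSampler takes over shuffling and sharding, so shuffle must be
    # False on the training DataLoader whenever a sampler is supplied.
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, sampler=sampler, shuffle=(sampler is None),
        num_workers=num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader

Note that train_ddp.py builds the DistributedSampler around a separate CIFAR-100 instance. That works because the sampler only needs the dataset length to shard and shuffle indices; the DataLoader then applies those indices to its own dataset.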

Key Changes from Single-GPU to Multi-GPU

The line numbers below refer to the numbered listing of train_ddp.py above and mark the DDP-specific additions:

Lines     Change                     Purpose
8-10      DDP imports                DistributedSampler, DDP, init_process_group
28-31     ddp_setup()                Initialize NCCL backend and set local GPU
34        Call ddp_setup()           Start distributed environment
48-49     Batch size division        Split global batch across GPUs
64-65     Wrap model with DDP        Enable synchronized gradient updates
71-72     Mixed precision setup      GradScaler for FP16 training
79        set_epoch()                Ensure proper shuffling across epochs
91-93     autocast() context         Run forward pass in FP16
101       barrier()                  Synchronize all processes after epoch
113-116   all_reduce()               Average metrics across GPUs
143       destroy_process_group()    Clean up distributed environment
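
The validation metrics averaged with all_reduce() come from the test() helper in train_utils, also carried over from part 1 and not shown here. All that train_ddp.py assumes is its interface: evaluate the model on test_loader and return the per-process accuracy and average loss. A rough sketch of such a helper:

# train_utils.py -- sketch of the interface train_ddp.py relies on
import torch

def test(model, test_loader, loss_fn, device):
    """Evaluate the model and return (accuracy, average_loss) for this process."""
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss_sum += loss_fn(outputs, labels).item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return correct / total, loss_sum / total

If the test loader is not sharded with its own DistributedSampler, every rank evaluates the full test set and the all_reduce() average simply confirms identical values; sharding the test set would shorten validation while keeping the same averaging step.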

Job Script for Multi-GPU Training

For single-node multi-GPU training, launch the script with torchrun in --standalone mode, which sets up the rendezvous locally so no master address is needed. The job below requests 4 GPUs and increases the global batch size from 512 to 1024 and the base learning rate from 0.01 to 0.04 relative to the single-GPU defaults, to account for the additional GPUs.

#!/bin/bash
#SBATCH --job-name=resnet_multigpu
#SBATCH --account=<project_number>
#SBATCH --output=multigpu_%j.out
#SBATCH --error=multigpu_%j.err
#SBATCH --time=01:00:00
#SBATCH --partition=accel
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=72
#SBATCH --mem=440G
#SBATCH --gpus=4

CONTAINER_PATH="/cluster/work/support/container/pytorch_nvidia_25.06_arm64.sif"

# Run training with 4 GPUs
apptainer exec --nv $CONTAINER_PATH torchrun \
    --standalone \
    --nnodes=1 \
    --nproc_per_node=4 \
    train_ddp.py --batch-size 1024 --epochs 100 --base-lr 0.04 --target-accuracy 0.95 --patience 2

Example output:

Epoch 95/100 completed in 1.330 seconds
Validation Loss: 1.0529, Validation Accuracy: 0.7362
Epoch Throughput: 36967.964 images/second
Epoch 96/100 completed in 1.312 seconds
Validation Loss: 1.1540, Validation Accuracy: 0.7197
Epoch Throughput: 37460.949 images/second
Epoch 97/100 completed in 1.326 seconds
Validation Loss: 1.1519, Validation Accuracy: 0.7183
Epoch Throughput: 37075.227 images/second
Epoch 98/100 completed in 1.360 seconds
Validation Loss: 1.1684, Validation Accuracy: 0.7097
Epoch Throughput: 36138.810 images/second
Epoch 99/100 completed in 1.303 seconds
Validation Loss: 1.1415, Validation Accuracy: 0.7190
Epoch Throughput: 37718.226 images/second
Epoch 100/100 completed in 1.332 seconds
Validation Loss: 1.1560, Validation Accuracy: 0.7153
Epoch Throughput: 36913.593 images/second

Training Summary:
Total training time: 131.972 seconds
Throughput: 37244.369 images/second
Number of nodes: 1
Number of GPUs per node: 4
Total GPUs used: 4
Training completed successfully.

With 4 GPUs and FP16 mixed precision, throughput increased from roughly 5,100 images/second on a single GPU to roughly 37,000 images/second, about a 7x speedup. The scaling is better than the 4x expected from the GPU count alone because this version also adds mixed-precision (FP16) training and a larger effective batch size, which keep the GPUs better utilized.
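
As a quick back-of-the-envelope check on that claim, using the throughput reported in the summary above and the approximate single-GPU figure from part 1:

# Rough scaling arithmetic from the reported numbers
single_gpu_ips = 5_100   # images/s, single-GPU run from part 1 (approximate)
multi_gpu_ips = 37_244   # images/s, 4-GPU mixed-precision run above
n_gpus = 4

speedup = multi_gpu_ips / single_gpu_ips   # ~7.3x overall
per_gpu_ips = multi_gpu_ips / n_gpus       # ~9,300 images/s per GPU
print(f"speedup: {speedup:.1f}x, per GPU: {per_gpu_ips:.0f} images/s")
# Each GPU now processes images faster than the single-GPU baseline did,
# which is where the extra factor beyond 4x comes from (FP16 plus the
# larger batch, as explained above).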

To scale beyond a single node, see the Multi-Node Guide.