PyTorch on Olivia

This guide demonstrates how to run PyTorch on Olivia using NVIDIA’s optimized PyTorch container. We train a Wide ResNet model on the CIFAR-100 dataset across three scenarios:

  1. Single GPU (this page)

  2. Multi-GPU - 4 GPUs on a single node (Multi-GPU Implementation for PyTorch on Olivia)

  3. Multi-Node - Multiple nodes (Multi-Node Implementation for PyTorch on Olivia)

Performance Summary

This 3-part guide walks you through scaling PyTorch training on Olivia’s GH200 GPUs:

Configuration               Throughput       Speedup
Single GPU (Part 1)         ~5,100 img/s     1x
4 GPUs on 1 node (Part 2)   ~37,000 img/s    7x
8 GPUs on 2 nodes (Part 3)  ~63,000 img/s    12x

The multi-GPU guides use FP16 mixed precision for improved performance.
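As a reference for what mixed precision involves, here is a minimal sketch of an FP16 training step using torch.cuda.amp; this is illustrative only, and the multi-GPU guides may wire it up slightly differently:

# Minimal sketch of an FP16 mixed-precision training step (illustrative only)
import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(model, optimizer, loss_fn, images, labels, device):
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    # Run the forward pass in reduced precision where it is numerically safe
    with torch.cuda.amp.autocast():
        outputs = model(images)
        loss = loss_fn(outputs, labels)
    # Scale the loss to avoid FP16 gradient underflow, then unscale and step
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()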

Note

Key considerations for Olivia:

  • The login node (x86_64) and compute nodes (AArch64) have different architectures. Software and containers must be built for ARM (AArch64) to run on the compute nodes.

  • Compute nodes use CUDA 12.7. Make sure the container's CUDA version is compatible (a quick check is shown below).
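A quick way to verify both points is to run a short check inside the container (for example with apptainer exec --nv, as in the job script later in this guide). This is an optional sanity check, not a required step:

# Sanity check: run inside the container on a compute node
import platform
import torch

print("Architecture:", platform.machine())          # expect 'aarch64' on Olivia compute nodes
print("PyTorch CUDA build:", torch.version.cuda)    # should be compatible with CUDA 12.7 on the nodes
print("CUDA available:", torch.cuda.is_available())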

Getting the Container

The PyTorch container is available pre-pulled at:

/cluster/work/support/container/pytorch_nvidia_25.06_arm64.sif

To pull a different version yourself, use the --arch arm64 flag since you’re pulling from the login node (x86_64) for use on compute nodes (AArch64):

apptainer pull --arch arm64 docker://nvcr.io/nvidia/pytorch:25.06-py3

Project Setup

Warning

Due to limited space in your home directory, set up your project in your work or project area (e.g., /cluster/work/projects/nnXXXXk/username/pytorch_olivia/). The CIFAR-100 dataset (~500 MB) will be downloaded automatically on first run.
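If you prefer to fetch the dataset ahead of time rather than relying on the automatic download during the first run, a one-off download with torchvision (for example from inside the container, assuming torchvision is available there) could look like the sketch below; the datasets directory matches the default location used by dataset_utils.py later in this guide:

# Optional: pre-download CIFAR-100 into ./datasets (assumes torchvision is available)
import torchvision

torchvision.datasets.CIFAR100(root="datasets", train=True, download=True)
torchvision.datasets.CIFAR100(root="datasets", train=False, download=True)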

Before you begin:

  1. Create an empty directory in your work or project area

  2. cd into that directory

The guide will provide all the files needed. When complete, your directory will have this structure:

your_project_directory/
├── train.py
├── train_ddp.py          # for multi-GPU/multi-node
├── dataset_utils.py
├── train_utils.py
├── device_utils.py
├── model.py
├── singlegpu_job.sh
├── multigpu_job.sh       # for multi-GPU
├── multinode_job.sh      # for multi-node
└── datasets/             # created automatically on first run
    └── cifar-100-python/

Single GPU Implementation

To train the Wide ResNet model on a single GPU, we use the following files. The train.py file is the main training script.

# train.py

"""
Single-GPU training script for Wide ResNet on CIFAR-100.
"""

import argparse
import time
import torch
import torch.nn as nn
import torch.optim as optim

from dataset_utils import load_cifar100
from model import WideResNet
from train_utils import train as train_one_epoch, test as evaluate
from device_utils import get_device

# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=256, help='Batch size for training')
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
parser.add_argument('--base-lr', type=float, default=0.01, help='Learning rate')
parser.add_argument('--target-accuracy', type=float, default=0.95, help='Target accuracy to stop training')
parser.add_argument('--patience', type=int, default=2, help='Number of epochs that meet target before stopping')
args = parser.parse_args()

def main():
    device = get_device()
    print(f"Training WideResNet on CIFAR-100 with Batch Size: {args.batch_size}")

    # Training variables
    val_accuracy = []
    total_time = 0
    total_images = 0

    # Load the dataset
    train_loader, test_loader = load_cifar100(batch_size=args.batch_size)

    # Initialize the model
    model = WideResNet(num_classes=100).to(device)

    # Define loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.base_lr)

    for epoch in range(args.epochs):
        t0 = time.time()

        # Train for one epoch
        train_one_epoch(model, optimizer, train_loader, loss_fn, device)

        epoch_time = time.time() - t0
        total_time += epoch_time

        # Compute throughput
        images_per_sec = len(train_loader) * args.batch_size / epoch_time
        total_images += len(train_loader) * args.batch_size

        # Evaluate
        v_accuracy, v_loss = evaluate(model, test_loader, loss_fn, device)
        val_accuracy.append(v_accuracy)

        print(f"Epoch {epoch + 1}/{args.epochs}: Time={epoch_time:.3f}s, "
              f"Loss={v_loss:.4f}, Accuracy={v_accuracy:.4f}, "
              f"Throughput={images_per_sec:.1f} img/s")

        # Early stopping
        if len(val_accuracy) >= args.patience and all(
            acc >= args.target_accuracy for acc in val_accuracy[-args.patience:]
        ):
            print(f"Target accuracy reached. Early stopping after epoch {epoch + 1}.")
            break

    # Final summary
    throughput = total_images / total_time
    print(f"\nTraining complete. Final Accuracy: {val_accuracy[-1]:.4f}")
    print(f"Total Time: {total_time:.1f}s, Throughput: {throughput:.1f} img/s")

if __name__ == "__main__":
    main()

The dataset_utils.py file contains the data utility functions for preparing and loading the dataset. Note that this script downloads the CIFAR-100 dataset and places it in the datasets folder on the first run.

# dataset_utils.py

import torchvision
import torchvision.transforms as transforms
import torch
from pathlib import Path

def _data_dir_default():
    repo_root = Path(__file__).resolve().parent
    data_dir = repo_root / "datasets"
    data_dir.mkdir(parents=True, exist_ok=True)
    return data_dir


def load_cifar100(batch_size, num_workers=0, sampler=None, data_dir=None):
    """
    Loads the CIFAR-100 dataset.
    The dataset directory is created at runtime; environment variables are not supported.
    """
    root = Path(data_dir).expanduser().resolve() if data_dir else _data_dir_default()

    # Define transformations
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))  # CIFAR-100 mean and std
    ])

    # Load full datasets
    train_set = torchvision.datasets.CIFAR100(
        root=str(root), download=True, train=True, transform=transform)
    test_set = torchvision.datasets.CIFAR100(
        root=str(root), download=True, train=False, transform=transform)

    # Create the data loaders
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, drop_last=True, shuffle=(sampler is None),
        sampler=sampler, num_workers=num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, drop_last=True, shuffle=False,
        num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader
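As a quick, optional check (not one of the project files), the loader can be exercised directly to confirm the batch shapes:

# Example usage of load_cifar100 (illustrative only)
from dataset_utils import load_cifar100

train_loader, test_loader = load_cifar100(batch_size=256)
images, labels = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([256, 3, 32, 32])
print(labels.shape)  # expected: torch.Size([256])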

The device_utils.py file includes the device utility functions, which handle device selection and management for training (e.g., selecting the appropriate GPU).

# device_utils.py
import torch

def get_device():
    """
    Determine the compute device (GPU or CPU).
    Returns:
        torch.device: The device to use for the computations.
    """
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

The model.py file contains the implementation of the Wide ResNet model architecture.

# model.py

import torch.nn as nn

# Standard convolution block followed by batch normalization and ReLU
class cbrblock(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(cbrblock, self).__init__()
        self.cbr = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=(1, 1),
                      padding='same', bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU())

    def forward(self, x):
        return self.cbr(x)


# Basic residual block
class conv_block(nn.Module):
    def __init__(self, input_channels, output_channels, scale_input):
        super(conv_block, self).__init__()
        self.scale_input = scale_input
        if self.scale_input:
            self.scale = nn.Conv2d(input_channels, output_channels, kernel_size=1,
                                   stride=(1, 1), padding='same')
        self.layer1 = cbrblock(input_channels, output_channels)
        self.dropout = nn.Dropout(p=0.01)
        self.layer2 = cbrblock(output_channels, output_channels)

    def forward(self, x):
        residual = x
        out = self.layer1(x)
        out = self.dropout(out)
        out = self.layer2(out)
        if self.scale_input:
            residual = self.scale(residual)
        return out + residual


# WideResNet model
class WideResNet(nn.Module):
    def __init__(self, num_classes):
        super(WideResNet, self).__init__()
        # RGB images (3 channels) input for the CIFAR-100 dataset
        nChannels = [3, 16, 160, 320, 640]
        # Grayscale images (1 channel) for the Fashion-MNIST dataset
        # nChannels = [1, 16, 160, 320, 640]
        self.input_block = cbrblock(nChannels[0], nChannels[1])
        self.block1 = conv_block(nChannels[1], nChannels[2], scale_input=True)
        self.block2 = conv_block(nChannels[2], nChannels[2], scale_input=False)
        self.pool1 = nn.MaxPool2d(2)
        self.block3 = conv_block(nChannels[2], nChannels[3], scale_input=True)
        self.block4 = conv_block(nChannels[3], nChannels[3], scale_input=False)
        self.pool2 = nn.MaxPool2d(2)
        self.block5 = conv_block(nChannels[3], nChannels[4], scale_input=True)
        self.block6 = conv_block(nChannels[4], nChannels[4], scale_input=False)
        # Global average pooling
        self.pool = nn.AvgPool2d(7)
        # Fully connected layer
        self.flat = nn.Flatten()
        self.fc = nn.Linear(nChannels[4], num_classes)

    def forward(self, x):
        out = self.input_block(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.pool1(out)
        out = self.block3(out)
        out = self.block4(out)
        out = self.pool2(out)
        out = self.block5(out)
        out = self.block6(out)
        out = self.pool(out)
        out = self.flat(out)
        out = self.fc(out)
        return out
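A quick, optional sanity check (not one of the project files) is to push a random CIFAR-sized batch through the model and confirm the output shape:

# Shape check for WideResNet (illustrative only)
import torch
from model import WideResNet

model = WideResNet(num_classes=100)
dummy = torch.randn(8, 3, 32, 32)   # a batch of 8 CIFAR-sized RGB images
out = model(dummy)
print(out.shape)                    # expected: torch.Size([8, 100])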

Finally, the train_utils.py file provides the training and evaluation functions used by the training script.

# train_utils.py
import torch

def train(model, optimizer, train_loader, loss_fn, device):
    """
    Trains the model for one epoch.
    Note that this function is only used for the single-GPU implementation. For the
    multi-GPU implementation, the train function is defined in train_ddp.py itself.
    Args:
        model (torch.nn.Module): The model to train.
        optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
        train_loader (torch.utils.data.DataLoader): DataLoader for training data.
        loss_fn (torch.nn.Module): Loss function.
        device (torch.device): Device to run training on (CPU or GPU).
    """
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        # Forward pass
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def test(model, test_loader, loss_fn, device):
    """
    Evaluates the model on the validation dataset.
    Note that this function is used in the multi-GPU implementation as well.
    Args:
        model (torch.nn.Module): The model to evaluate.
        test_loader (torch.utils.data.DataLoader): DataLoader for validation data.
        loss_fn (torch.nn.Module): Loss function.
        device (torch.device): Device to run evaluation on (CPU or GPU).
    Returns:
        tuple: Validation accuracy and validation loss.
    """
    model.eval()
    total_labels = 0
    correct_labels = 0
    loss_total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # Forward pass
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            # Compute accuracy and loss
            predictions = torch.max(outputs, 1)[1]
            total_labels += len(labels)
            correct_labels += (predictions == labels).sum().item()
            loss_total += loss.item()

    v_accuracy = correct_labels / total_labels
    v_loss = loss_total / len(test_loader)
    return v_accuracy, v_loss

Job Script for Single GPU Training

The --nv flag gives the container access to GPU resources. We use torchrun to launch the training script.

#!/bin/bash
#SBATCH --job-name=resnet_singleGpu
#SBATCH --account=<project_number>
#SBATCH --output=singlegpu_%j.out
#SBATCH --error=singlegpu_%j.err
#SBATCH --time=01:00:00
#SBATCH --partition=accel
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=72
#SBATCH --mem=110G
#SBATCH --gpus-per-node=1

CONTAINER_PATH="/cluster/work/support/container/pytorch_nvidia_25.06_arm64.sif"

# Check GPU availability
apptainer exec --nv $CONTAINER_PATH python -c 'import torch; print(f"CUDA available: {torch.cuda.is_available()}, GPUs: {torch.cuda.device_count()}")'

# Run training
apptainer exec --nv $CONTAINER_PATH torchrun --standalone --nnodes=1 --nproc_per_node=1 \
    train.py --batch-size 256 --epochs 100 --base-lr 0.01 --target-accuracy 0.95 --patience 2

Save the script as singlegpu_job.sh and submit it with sbatch singlegpu_job.sh. Example output showing training progress:

Epoch 95/100: Time=9.819s, Loss=1.6997, Accuracy=0.6386, Throughput=5084.2 img/s
Epoch 96/100: Time=9.789s, Loss=1.5348, Accuracy=0.6581, Throughput=5099.8 img/s
Epoch 97/100: Time=9.818s, Loss=1.5620, Accuracy=0.6507, Throughput=5084.4 img/s
Epoch 98/100: Time=9.805s, Loss=1.5820, Accuracy=0.6562, Throughput=5091.3 img/s
Epoch 99/100: Time=9.773s, Loss=1.5247, Accuracy=0.6635, Throughput=5107.8 img/s
Epoch 100/100: Time=9.608s, Loss=1.6100, Accuracy=0.6419, Throughput=5195.4 img/s

Training complete. Final Accuracy: 0.6419
Total Time: 973.8s, Throughput: 5126.4 img/s

The output shows a throughput of approximately 5,100 images/second on a single GH200 GPU. In the next parts of this guide, we’ll scale this up to multiple GPUs and see significant speedups.

To continue scaling to multiple GPUs, proceed to Part 2: Multi-GPU Implementation for PyTorch on Olivia.