Single-GPU Implementation for PyTorch on Olivia

This is part 1 of the PyTorch on Olivia guide. See PyTorch on Olivia for the overview, software choice, storage recommendations, and the full guide structure.

The goal of this part is to run the reference training workflow on a single GH200 GPU before scaling to multiple GPUs and multiple nodes.

Learning Outcomes

By the end of this part, you will be able to:

  1. Run a PyTorch training job on 1 GPU on Olivia.

  2. Submit and monitor the job with Slurm.

  3. Confirm success from expected log output.

Note

Key considerations for Olivia:

  • The login node (x86_64) and the compute nodes (aarch64) have different architectures. Software and containers must be built for ARM (aarch64) to run on the compute nodes.

  • Compute nodes use CUDA 12.7. Make sure that any container you run is compatible with this CUDA version.
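One way to confirm which architecture and CUDA build a Python environment actually sees is to print them from inside the job or container. The following is a minimal sketch, not part of the guide's files: the helper name describe_runtime is hypothetical, and the torch import is treated as optional so the check also runs where PyTorch is not installed (for example on the login node).

```python
import platform


def describe_runtime():
    """Collect the CPU architecture and, if PyTorch is importable, its CUDA build."""
    info = {"machine": platform.machine()}  # e.g. "aarch64" or "x86_64"
    try:
        import torch  # optional: only present once a PyTorch stack is loaded
    except ImportError:
        info["torch"] = None
        info["torch_cuda_build"] = None
    else:
        info["torch"] = torch.__version__
        # CUDA version this PyTorch build was compiled against (None for CPU-only builds)
        info["torch_cuda_build"] = torch.version.cuda
    return info


if __name__ == "__main__":
    for key, value in describe_runtime().items():
        print(f"{key}: {value}")
```

Running this on the login node and again inside a job should show the architecture difference described above.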

We provide several ways to use PyTorch on Olivia. They are described in detail in PyTorch Software Options on Olivia.

PyTorch Runtime Setup

You can run this guide in three ways: through the PyTorch module path, by launching the container directly, or through EESSI modules.

Note

If you use Hugging Face models or datasets, see Models, Datasets, Caches, and Overlays on Olivia for the required environment cache flags.

Before submitting jobs, set this variable in your job script:

SCRIPT_DIR="/cluster/work/projects/<project_number>/<username>/pytorch_olivia"

Then load the software stack in this order:

ml reset
ml load NRIS/GPU
ml load NCCL/2.26.6-GCCcore-14.2.0-CUDA-12.8.0
ml use /cluster/work/support/pytorch_module
ml load PyTorch/2.8.0
export PYTORCH_OVERLAY_MODE=ro

Note

/cluster/work/support/pytorch_module is a temporary module root while the PyTorch module rollout is in progress. When the service is fully live, you can load the PyTorch module directly without this ml use line.

Project Setup

Warning

Due to limited space in your home directory, set up your project in your work or project area (e.g., /cluster/work/projects/nnXXXXk/username/pytorch_olivia/). The CIFAR-100 dataset (~500 MB) will be downloaded automatically on first run.

Before you begin:

  1. Create an empty directory in your work or project area

  2. cd into that directory

  3. Copy the code blocks from this page into files with the same names

The guide will provide all the files needed. When complete, your directory will have this structure:

your_project_directory/
├── train.py
├── train_ddp.py          # for multi-GPU/multi-node
├── dataset_utils.py
├── train_utils.py
├── device_utils.py
├── model.py
├── singlegpu_job.sh
├── multigpu_job.sh       # for multi-GPU
├── multinode_job.sh      # for multi-node
├── hf_cache/             # created automatically by job scripts
└── datasets/             # created automatically on first run
    └── cifar-100-python/

Single GPU Implementation

To train the Wide ResNet model on a single GPU, we use the following files. The train.py file is the main training script.

"""
Single-GPU training script for Wide ResNet on CIFAR-100.
"""

import argparse
import time
import torch
import torch.nn as nn
import torch.optim as optim

from dataset_utils import load_cifar100
from model import WideResNet
from train_utils import train as train_one_epoch, test as evaluate
from device_utils import get_device

# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=256, help='Batch size for training')
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
parser.add_argument('--base-lr', type=float, default=0.01, help='Learning rate')
parser.add_argument('--target-accuracy', type=float, default=0.95, help='Target accuracy to stop training')
parser.add_argument('--patience', type=int, default=2, help='Number of epochs that meet target before stopping')
args = parser.parse_args()


def main():
    device = get_device()
    print(f"Training WideResNet on CIFAR-100 with Batch Size: {args.batch_size}")

    # Training variables
    val_accuracy = []
    total_time = 0
    total_images = 0

    # Load the dataset
    train_loader, test_loader = load_cifar100(batch_size=args.batch_size)

    # Initialize the model
    model = WideResNet(num_classes=100).to(device)

    # Define loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.base_lr)

    for epoch in range(args.epochs):
        t0 = time.time()

        # Train for one epoch
        train_one_epoch(model, optimizer, train_loader, loss_fn, device)

        epoch_time = time.time() - t0
        total_time += epoch_time

        # Compute throughput
        images_per_sec = len(train_loader) * args.batch_size / epoch_time
        total_images += len(train_loader) * args.batch_size

        # Evaluate
        v_accuracy, v_loss = evaluate(model, test_loader, loss_fn, device)
        val_accuracy.append(v_accuracy)

        print(f"Epoch {epoch + 1}/{args.epochs}: Time={epoch_time:.3f}s, "
              f"Loss={v_loss:.4f}, Accuracy={v_accuracy:.4f}, "
              f"Throughput={images_per_sec:.1f} img/s")

        # Early stopping
        if len(val_accuracy) >= args.patience and all(
            acc >= args.target_accuracy for acc in val_accuracy[-args.patience:]
        ):
            print(f"Target accuracy reached. Early stopping after epoch {epoch + 1}.")
            break

    # Final summary
    throughput = total_images / total_time
    print(f"\nTraining complete. Final Accuracy: {val_accuracy[-1]:.4f}")
    print(f"Total Time: {total_time:.1f}s, Throughput: {throughput:.1f} img/s")


if __name__ == "__main__":
    main()
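The early-stopping rule in train.py stops training once the most recent --patience epochs all meet --target-accuracy. The sketch below isolates that condition in a standalone function so it can be tried without a GPU. The function name should_stop and the accuracy history are illustrative, not part of the guide's files or measured results.

```python
def should_stop(val_accuracy, target_accuracy, patience):
    """Return True when the last `patience` accuracies all meet the target."""
    if len(val_accuracy) < patience:
        return False  # not enough epochs recorded yet
    return all(acc >= target_accuracy for acc in val_accuracy[-patience:])


# Made-up accuracy history, used only to exercise the rule.
history = [0.93, 0.96, 0.94, 0.95, 0.97]
for epoch in range(1, len(history) + 1):
    if should_stop(history[:epoch], target_accuracy=0.95, patience=2):
        print(f"would stop after epoch {epoch}")
        break
```

A single epoch above the target (epoch 2 in this history) does not trigger a stop; the condition needs consecutive qualifying epochs.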

The dataset_utils.py file contains the data utility functions used for preparing and managing the dataset. Note that this script downloads the CIFAR-100 dataset and places it in the datasets folder.

import torchvision
import torchvision.transforms as transforms
import torch
from pathlib import Path


def _data_dir_default():
    # Default dataset location: a "datasets" folder next to this file
    repo_root = Path(__file__).resolve().parent
    data_dir = repo_root / "datasets"
    data_dir.mkdir(parents=True, exist_ok=True)
    return data_dir


def load_cifar100(batch_size, num_workers=0, sampler=None, data_dir=None):
    """
    Loads the CIFAR-100 dataset, downloading it if it is not already present.
    If data_dir is not given, the default datasets directory is created at
    runtime. There is no environment variable support.
    """
    root = Path(data_dir).expanduser().resolve() if data_dir else _data_dir_default()

    # Define transformations
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    ])

    # Load full datasets (the same transform is applied to both splits)
    train_set = torchvision.datasets.CIFAR100(
        root=str(root), download=True, train=True, transform=transform)
    test_set = torchvision.datasets.CIFAR100(
        root=str(root), download=True, train=False, transform=transform)

    # Create the data loaders (drop_last=True discards the final partial batch)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, drop_last=True,
        shuffle=(sampler is None), sampler=sampler,
        num_workers=num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, drop_last=True,
        shuffle=False, num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader
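Both data loaders are created with drop_last=True, so the final partial batch of each split is discarded. The sketch below works out how many batches and samples that leaves per pass, assuming the standard CIFAR-100 split sizes (50,000 training and 10,000 test images). The helper name batches_and_samples is hypothetical.

```python
def batches_and_samples(dataset_size, batch_size, drop_last=True):
    """Batches a DataLoader yields per pass, and the samples those batches cover."""
    if drop_last:
        batches = dataset_size // batch_size
        samples = batches * batch_size
    else:
        batches = -(-dataset_size // batch_size)  # ceiling division
        samples = dataset_size
    return batches, samples


# Standard CIFAR-100 split sizes with the guide's default batch size of 256.
print(batches_and_samples(50_000, 256))  # (195, 49920)
print(batches_and_samples(10_000, 256))  # (39, 9984)
```

With the default batch size, each epoch therefore trains on 49,920 images and evaluates on 9,984 images.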

The device_utils.py file contains the device utility function get_device, which selects the first CUDA GPU when one is available and falls back to the CPU otherwise.

import torch

def get_device():
    """
    Determine the compute device (GPU or CPU).
    Returns:
        torch.device: The device to use for the computations.
    """

    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

The model.py file contains the implementation of the Wide ResNet model architecture.

import torch.nn as nn

# Standard convolution block followed by batch normalization and ReLU
class cbrblock(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(cbrblock, self).__init__()
        self.cbr = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=3,
                      stride=(1, 1), padding='same', bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU())

    def forward(self, x):
        return self.cbr(x)


# Basic residual block
class conv_block(nn.Module):
    def __init__(self, input_channels, output_channels, scale_input):
        super(conv_block, self).__init__()
        self.scale_input = scale_input
        if self.scale_input:
            # 1x1 convolution so the residual matches the output channel count
            self.scale = nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=(1, 1), padding='same')
        self.layer1 = cbrblock(input_channels, output_channels)
        self.dropout = nn.Dropout(p=0.01)
        self.layer2 = cbrblock(output_channels, output_channels)

    def forward(self, x):
        residual = x
        out = self.layer1(x)
        out = self.dropout(out)
        out = self.layer2(out)
        if self.scale_input:
            residual = self.scale(residual)
        return out + residual

# WideResNet model
class WideResNet(nn.Module):
    def __init__(self, num_classes):
        super(WideResNet, self).__init__()
        # RGB images (3 channels) input for CIFAR-100 dataset
        nChannels = [3, 16, 160, 320, 640]
        # Grayscale images (1 channel) for Fashion MNIST dataset
        # nChannels = [1, 16, 160, 320, 640]
        self.input_block = cbrblock(nChannels[0], nChannels[1])
        self.block1 = conv_block(nChannels[1], nChannels[2], scale_input=True)
        self.block2 = conv_block(nChannels[2], nChannels[2], scale_input=False)
        self.pool1 = nn.MaxPool2d(2)
        self.block3 = conv_block(nChannels[2], nChannels[3], scale_input=True)
        self.block4 = conv_block(nChannels[3], nChannels[3], scale_input=False)
        self.pool2 = nn.MaxPool2d(2)
        self.block5 = conv_block(nChannels[3], nChannels[4], scale_input=True)
        self.block6 = conv_block(nChannels[4], nChannels[4], scale_input=False)
        # Average pooling (reduces the 8x8 map from 32x32 inputs to 1x1)
        self.pool = nn.AvgPool2d(7)
        # Fully connected layer
        self.flat = nn.Flatten()
        self.fc = nn.Linear(nChannels[4], num_classes)

    def forward(self, x):
        out = self.input_block(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.pool1(out)
        out = self.block3(out)
        out = self.block4(out)
        out = self.pool2(out)
        out = self.block5(out)
        out = self.block6(out)
        out = self.pool(out)
        out = self.flat(out)
        out = self.fc(out)
        return out
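To see why the final nn.Linear layer takes 640 input features, it helps to trace the feature-map side length through the pooling layers. The sketch below uses the standard output-size formula for an unpadded pooling layer, floor((size - kernel) / stride) + 1, and assumes 32x32 CIFAR-100 inputs. The convolutions use padding='same' and leave the side length unchanged. The helper names are hypothetical.

```python
def pool_out(size, kernel, stride=None):
    """Output side length of an unpadded pooling layer (stride defaults to kernel)."""
    stride = kernel if stride is None else stride
    return (size - kernel) // stride + 1


def trace_spatial_size(size=32):
    """Follow the side length through pool1, pool2 and the final average pool."""
    trace = [("input", size)]
    for name, kernel in (("pool1", 2), ("pool2", 2), ("avgpool", 7)):
        size = pool_out(size, kernel)
        trace.append((name, size))
    return trace


for name, size in trace_spatial_size():
    print(f"{name}: {size}x{size}")

final_size = trace_spatial_size()[-1][1]
print("features into fc:", 640 * final_size * final_size)
```

The side length goes 32, 16, 8, 1, so flattening yields 640 features. With a 7x7 window on an 8x8 map, the single pooling window covers the top-left 7x7 region rather than the whole map.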

Finally, the train_utils.py file is a utility module that provides the train and test functions used for training the model and evaluating it.

import torch

def train(model, optimizer, train_loader, loss_fn, device):
    """
    Trains the model for one epoch. This function is used only by the
    single-GPU implementation. The multi-GPU implementation defines its own
    train function in train_ddp.py.
    Args:
        model (torch.nn.Module): The model to train.
        optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
        train_loader (torch.utils.data.DataLoader): DataLoader for training data.
        loss_fn (torch.nn.Module): Loss function.
        device (torch.device): Device to run training on (CPU or GPU).
    """
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        # Forward pass
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def test(model, test_loader, loss_fn, device):
    """
    Evaluates the model on the validation dataset. This function is also used
    by the multi-GPU implementation.
    Args:
        model (torch.nn.Module): The model to evaluate.
        test_loader (torch.utils.data.DataLoader): DataLoader for validation data.
        loss_fn (torch.nn.Module): Loss function.
        device (torch.device): Device to run evaluation on (CPU or GPU).
    Returns:
        tuple: Validation accuracy and validation loss.
    """

    model.eval()
    total_labels = 0
    correct_labels = 0
    loss_total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # Forward pass
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            # Compute accuracy and loss
            predictions = torch.max(outputs, 1)[1]
            total_labels += len(labels)
            correct_labels += (predictions == labels).sum().item()
            loss_total += loss.item()

    v_accuracy = correct_labels / total_labels
    v_loss = loss_total / len(test_loader)
    return v_accuracy, v_loss
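The test function aggregates its two metrics differently: accuracy is computed over individual samples, while the reported loss is the mean of the per-batch mean losses. The following plain-Python sketch mirrors that bookkeeping without PyTorch. The logits, labels and loss values are made up for illustration, and the helper names are hypothetical.

```python
def argmax(row):
    """Index of the largest value in a list, the plain-Python analogue of torch.max(..., 1)[1]."""
    return max(range(len(row)), key=row.__getitem__)


def summarize(batches):
    """batches: iterable of (logits, labels, batch_mean_loss) tuples."""
    total_labels = 0
    correct_labels = 0
    loss_total = 0.0
    n_batches = 0
    for logits, labels, batch_loss in batches:
        predictions = [argmax(row) for row in logits]
        total_labels += len(labels)
        correct_labels += sum(p == y for p, y in zip(predictions, labels))
        loss_total += batch_loss
        n_batches += 1
    # Accuracy is per sample; loss is averaged over batches.
    return correct_labels / total_labels, loss_total / n_batches


# Two made-up batches of two samples each, with two classes.
toy_batches = [
    ([[0.1, 0.9], [0.8, 0.2]], [1, 0], 0.5),
    ([[0.3, 0.7], [0.6, 0.4]], [0, 0], 1.0),
]
accuracy, loss = summarize(toy_batches)
print(accuracy, loss)  # 0.75 0.75
```

Because the data loaders drop the last partial batch, every batch has the same size here, so the mean of batch means equals the per-sample mean loss.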

Job Script for Single GPU Training

Use whichever launch model matches your workflow. The script below uses the PyTorch module path.

#!/bin/bash
#SBATCH --job-name=resnet_singleGpu_mod
#SBATCH --account=<project_number>
#SBATCH --output=singlegpu_module_%j.out
#SBATCH --error=singlegpu_module_%j.err
#SBATCH --time=01:00:00
#SBATCH --partition=accel
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=72
#SBATCH --mem=110G
#SBATCH --gpus-per-node=1

set -euo pipefail

SCRIPT_DIR="/cluster/work/projects/<project_number>/<username>/pytorch_olivia"

ml reset
ml load NRIS/GPU
ml load NCCL/2.26.6-GCCcore-14.2.0-CUDA-12.8.0
ml use /cluster/work/support/pytorch_module
ml load PyTorch/2.8.0

export PYTORCH_OVERLAY_MODE=ro

cd "${SCRIPT_DIR}"

python -c 'import torch; print(f"CUDA available: {torch.cuda.is_available()}, GPUs: {torch.cuda.device_count()}")'

torchrun --standalone --nnodes=1 --nproc_per_node=1 \
    train.py --batch-size 256 --epochs 100 --base-lr 0.01 --target-accuracy 0.95 --patience 2

The submit and monitor commands are identical for both launch modes.

sbatch singlegpu_job.sh
squeue -u $USER
tail -f singlegpu_module_<jobid>.out

Example output showing training progress:

Epoch 95/100: Time=9.819s, Loss=1.6997, Accuracy=0.6386, Throughput=5084.2 img/s
Epoch 96/100: Time=9.789s, Loss=1.5348, Accuracy=0.6581, Throughput=5099.8 img/s
Epoch 97/100: Time=9.818s, Loss=1.5620, Accuracy=0.6507, Throughput=5084.4 img/s
Epoch 98/100: Time=9.805s, Loss=1.5820, Accuracy=0.6562, Throughput=5091.3 img/s
Epoch 99/100: Time=9.773s, Loss=1.5247, Accuracy=0.6635, Throughput=5107.8 img/s
Epoch 100/100: Time=9.608s, Loss=1.6100, Accuracy=0.6419, Throughput=5195.4 img/s

Training complete. Final Accuracy: 0.6419
Total Time: 973.8s, Throughput: 5126.4 img/s

The output shows a throughput of approximately 5,100 images/second on a single GH200 GPU. In the next parts of this guide, we’ll scale this up to multiple GPUs and see significant speedups.
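The throughput figures in the log can be reproduced from the epoch times. The sketch below assumes the standard CIFAR-100 training split of 50,000 images and the default batch size of 256. Since train.py counts len(train_loader) * batch_size images per epoch and the loader drops the last partial batch, each epoch counts 195 batches.

```python
TRAIN_IMAGES = 50_000  # size of the standard CIFAR-100 training split
BATCH_SIZE = 256       # default --batch-size in train.py

# drop_last=True discards the final partial batch.
batches_per_epoch = TRAIN_IMAGES // BATCH_SIZE
images_per_epoch = batches_per_epoch * BATCH_SIZE


def throughput(images, seconds):
    """Images processed per second."""
    return images / seconds


# Times taken from the example output (epoch 100 and the final summary).
last_epoch = throughput(images_per_epoch, 9.608)
full_run = throughput(100 * images_per_epoch, 973.8)
print(f"images per epoch: {images_per_epoch}")
print(f"epoch 100: {last_epoch:.1f} img/s, full run: {full_run:.1f} img/s")
```

The results land within a fraction of an img/s of the logged 5195.4 and 5126.4. The small difference arises because the log prints rounded times. Note that the total time covers training only, since evaluation runs after the epoch timer stops.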

Success criteria for Part 1:

  • The job output reaches the "Training complete" line

  • The final summary prints a non-zero throughput in img/s

  • There are no CUDA initialization errors in the .err log
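These criteria can also be checked programmatically. The sketch below is one possible approach, not a tool provided on Olivia: the function name check_run is hypothetical, and treating any .err line containing both "CUDA" and "error" as a failure is an assumed heuristic. The summary pattern matches the final lines printed by train.py.

```python
import re

# Matches the final summary line printed by train.py.
SUMMARY_RE = re.compile(r"Total Time: ([\d.]+)s, Throughput: ([\d.]+) img/s")


def check_run(out_text, err_text=""):
    """Apply the three success criteria to the contents of the .out and .err logs."""
    reached_end = "Training complete" in out_text
    match = SUMMARY_RE.search(out_text)
    throughput = float(match.group(2)) if match else 0.0
    # Assumed heuristic: flag stderr lines that mention both CUDA and an error.
    cuda_errors = [
        line for line in err_text.splitlines()
        if "CUDA" in line and "error" in line.lower()
    ]
    return {
        "reached_end": reached_end,
        "throughput": throughput,
        "cuda_errors": cuda_errors,
        "ok": reached_end and throughput > 0 and not cuda_errors,
    }


# Final lines of the example output shown earlier.
sample_out = (
    "Training complete. Final Accuracy: 0.6419\n"
    "Total Time: 973.8s, Throughput: 5126.4 img/s\n"
)
print(check_run(sample_out)["ok"])  # True
```

In practice, the two arguments would be the contents of singlegpu_module_<jobid>.out and singlegpu_module_<jobid>.err read from the project directory.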

The next step is to scale this workflow up to multiple GPUs. For this, see the Multi GPU Guide.