Checkpointing Guide: Best Practices for DGX On-Prem Free Tier Jobs

This guide outlines best practices for running jobs on the University at Albany DGX On-Prem cluster. The system operates under two Quality of Service (QoS) tiers: free and paid. While this document focuses on free tier usage, understanding its limitations and implementing proper job management strategies is crucial for successful execution.

TL;DR: This is a guide to using the free tier of the University at Albany DGX On-Prem cluster. Key points:

  1. Free tier limits: please refer to the Service Level Agreement (SLA).

  2. Jobs can be preempted (stopped and requeued) if paid-tier users need resources. When this happens, all progress is lost unless properly saved.

  3. The solution is checkpointing - regularly saving your work's state (like autosave in a document). The guide provides PyTorch code examples for:

    • Saving/loading model training progress

    • Tracking training metrics

    • Managing inference tasks

The main message is: Always implement checkpointing in your jobs, or you risk losing hours/days of work when preemption occurs.

Free Tier Resource Allocation

The free tier provides substantial computing resources, subject to limitations such as:

  • Maximum number of jobs per user

  • Access to up to a specific number of GPUs and CPUs

  • Maximum job duration per user

  • Automatic job requeuing upon preemption

For current free tier limits, please refer to the Service Level Agreement (SLA).

Understanding Job Preemption

In scenarios where compute resources are needed for paid-tier jobs, free-tier jobs may be preempted (temporarily terminated). When preemption occurs, the following happens:

  • The running job is immediately stopped

  • The job is automatically requeued in the system

  • When resources become available again, the job starts fresh from the beginning

Without proper safeguards, this means:

  • All in-memory data and computational progress are lost

  • The job restarts from its initial state, potentially wasting hours or days of previous computation

  • Any intermediate results that weren't explicitly saved are unrecoverable

While the system automatically handles the requeuing process, it's the user's responsibility to implement mechanisms that allow their jobs to intelligently resume from the last saved state. This is where proper checkpointing strategies become crucial - they enable your job to restart from its last saved point rather than beginning anew, effectively transforming a potentially disruptive preemption into a mere pause in execution.

Implementing Checkpointing

Checkpointing is a critical practice for maintaining training progress in a preemptible environment. Similar to periodically saving a document while writing, checkpointing preserves your model's state throughout the training process. This practice is essential in the cluster environment for several reasons:

  • Recovery from preemption events when paid-tier jobs take priority

  • Protection against unexpected system errors or hardware issues

  • Flexibility to resume training from specific progress points

Implementing checkpointing strategies varies significantly depending on your specific use case and computational workload. While checkpointing is crucial for various tasks, it's particularly vital in machine learning workflows, especially during model training processes where losing progress could mean days of wasted computation. The implementation details depend heavily on your chosen deep learning framework.

PyTorch, TensorFlow, and other frameworks each provide their own mechanisms and best practices for saving and loading model states. This guide focuses specifically on PyTorch implementations, though the general principles can be adapted for other frameworks. With PyTorch, you can save not only the model's parameters but also the optimizer state, learning rate scheduler, and custom training metrics, ensuring a complete snapshot of your training state.

PyTorch Checkpointing Implementation for Model Training

The most basic PyTorch checkpoint needs to save two essential components:

  • The model state (weights and biases)

  • The optimizer state (crucial for proper training resumption)

Here's a simple example showing how to save and load these components:

import os
import torch

# Saving a checkpoint
def save_checkpoint(model, optimizer, epoch, path='checkpoint.pt'):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    torch.save(checkpoint, path)
    print(f"Checkpoint saved at epoch {epoch}")

# Loading a checkpoint
def load_checkpoint(model, optimizer, path='checkpoint.pt'):
    if not os.path.exists(path):
        print("No checkpoint found, starting from scratch")
        return 0  # Starting epoch

    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print(f"Checkpoint loaded from epoch {epoch}")
    return epoch + 1  # Resume from the epoch after the last saved one

The save_checkpoint function saves the current state, while load_checkpoint restores it when needed. These functions are the building blocks for a training process that can survive preemption: by integrating them into your training loop, you create a resilient system that can recover from interruptions. Here is a basic training loop with integrated checkpointing.

def train_model(model, optimizer, criterion, train_loader, num_epochs, checkpoint_freq=5):
    # Try to load a previous checkpoint and resume from where it left off
    start_epoch = load_checkpoint(model, optimizer)

    for epoch in range(start_epoch, num_epochs):
        running_loss = 0.0

        # Training loop for one epoch
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Print epoch stats
        print(f'Epoch {epoch} loss: {running_loss / len(train_loader)}')

        # Save a checkpoint every checkpoint_freq epochs
        if (epoch + 1) % checkpoint_freq == 0:
            save_checkpoint(model, optimizer, epoch)

    print('Training completed')

The train_model function performs the following steps (a minimal usage sketch follows the list). Note that the loss function is passed in as the criterion argument.

  1. When the training starts, we try to load a previous checkpoint using load_checkpoint

    • If a checkpoint exists, training resumes from the epoch after the last saved one

    • If no checkpoint exists, it starts from epoch 0

  2. During training, we save checkpoints every checkpoint_freq epochs

    • The default is set to 5, meaning it saves every 5 epochs

    • You can adjust this frequency based on your needs

  3. If the job gets preempted, when it restarts:

    • It will load the last saved checkpoint

    • Continue training from that epoch

    • All previous progress up to the last checkpoint is preserved
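
To put these pieces together, here is a minimal usage sketch. The linear model, synthetic dataset, and SGD optimizer are placeholders standing in for your own model, data, and training configuration.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Placeholder model, optimizer, and loss function -- substitute your own
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Placeholder dataset: 100 samples with 10 features and binary labels
inputs = torch.randn(100, 10)
targets = torch.randint(0, 2, (100,))
train_loader = DataLoader(TensorDataset(inputs, targets), batch_size=16)

# Assumes save_checkpoint, load_checkpoint, and train_model from above are defined in the same script
train_model(model, optimizer, criterion, train_loader, num_epochs=50, checkpoint_freq=5)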

Advanced PyTorch Checkpointing Features for Model Training

While saving the model and optimizer states covers the basics, you might want to track additional information during training. For example, when conducting research, you'll often need to track your model's loss history to analyze performance and create visualizations for publications. This kind of tracking requires saving additional training metrics. Here's an enhanced version of our checkpointing functions that includes these extra features:

# Enhanced checkpoint saving
def save_checkpoint(model, optimizer, epoch, loss_history, scheduler=None, path='checkpoint.pt'):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_history': loss_history,  # Track loss over time
    }

    # Save scheduler state if it exists
    if scheduler is not None:
        checkpoint['scheduler_state_dict'] = scheduler.state_dict()

    torch.save(checkpoint, path)
    print(f"Checkpoint saved at epoch {epoch}")

# Enhanced checkpoint loading
def load_checkpoint(model, optimizer, scheduler=None, path='checkpoint.pt'):
    if not os.path.exists(path):
        print("No checkpoint found, starting from scratch")
        return 0, []  # Return starting epoch and empty loss history

    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Load scheduler state if it exists
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    epoch = checkpoint['epoch']
    loss_history = checkpoint['loss_history']
    print(f"Checkpoint loaded from epoch {epoch}")
    return epoch + 1, loss_history  # Resume from the epoch after the last saved one

This enhanced version:

  • Tracks training loss history, allowing you to plot learning curves even after preemption

  • Optionally saves and loads the learning rate scheduler state (a component that adjusts the learning rate during training)

  • Returns the epoch to resume from and the loss history when loading

Now, let's see how to integrate these enhanced checkpoint functions into our training loop. The main differences from our previous implementation are the initialization and tracking of the loss history, along with passing this information to our checkpoint functions. This allows us to maintain a complete record of our model's performance throughout training, even across multiple preemptions. Here is a sketch of the enhanced training loop (it assumes the loss function, and optionally a scheduler, are passed in alongside the model and optimizer):
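
def train_model(model, optimizer, criterion, train_loader, num_epochs, checkpoint_freq=5, scheduler=None):
    # Try to load a previous checkpoint: returns the epoch to resume from and the loss history so far
    start_epoch, loss_history = load_checkpoint(model, optimizer, scheduler)

    for epoch in range(start_epoch, num_epochs):
        running_loss = 0.0

        # Training loop for one epoch
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Track the average loss for this epoch
        epoch_loss = running_loss / len(train_loader)
        loss_history.append(epoch_loss)
        print(f'Epoch {epoch} loss: {epoch_loss}')

        # Step the learning rate scheduler if one is being used
        if scheduler is not None:
            scheduler.step()

        # Save a checkpoint every checkpoint_freq epochs, including the loss history
        if (epoch + 1) % checkpoint_freq == 0:
            save_checkpoint(model, optimizer, epoch, loss_history, scheduler)

    print('Training completed')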

The key changes from our previous training loop are:

  1. The load_checkpoint function now returns both the epoch number and loss history

  2. We initialize an empty loss history if starting fresh

  3. We track the average loss for each epoch in the loss_history list

  4. The save_checkpoint function now includes this loss history

This enhanced version ensures that you maintain a complete record of your training progress, even if the job gets preempted multiple times. Keep in mind that this implementation is just a suggestion and can be adapted based on your specific project requirements. While we've focused on PyTorch here, the core concept of checkpointing - saving and loading training states - remains consistent across different deep learning frameworks like TensorFlow, Keras, etc. You can customize these checkpoint functions to track any metrics relevant to your research, such as accuracy, F1 score, precision, recall, or custom evaluation metrics. The key is to ensure that whatever information you need to analyze and validate your model's performance is properly saved and restored during the training process.
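
For instance, a hypothetical variant of the enhanced save_checkpoint could carry a per-epoch validation accuracy history through the checkpoint dictionary in exactly the same way as the loss history (the accuracy_history argument below is illustrative, not part of the functions shown earlier):

import torch

# Hypothetical variant: store an additional metric alongside the loss history
def save_checkpoint(model, optimizer, epoch, loss_history, accuracy_history, path='checkpoint.pt'):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_history': loss_history,
        'accuracy_history': accuracy_history,  # e.g. validation accuracy recorded once per epoch
    }
    torch.save(checkpoint, path)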

PyTorch Checkpointing for Inference

While checkpointing is commonly associated with model training, it's equally important for inference tasks. When running predictions on extensive datasets in our free-tier environment, your job might get preempted before processing all the data. Without proper checkpointing, you'd have to restart predictions from the beginning, wasting valuable compute time. Here's one approach to saving your inference progress.
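
The run_inference helper sketched below is illustrative rather than prescriptive: its name, the save_freq parameter, and the assumption that each batch from the DataLoader is a single input tensor are placeholder choices you can adapt to your own pipeline.

import os
import torch

def run_inference(model, data_loader, checkpoint_path='inference_checkpoint.pt', save_freq=100):
    model.eval()
    predictions = []
    start_batch = 0

    # Resume from a previous run if an inference checkpoint exists
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        predictions = checkpoint['predictions']
        start_batch = checkpoint['next_batch']
        print(f"Resuming inference from batch {start_batch}")

    with torch.no_grad():
        for batch_idx, inputs in enumerate(data_loader):
            # Skip batches whose predictions were already saved before preemption
            if batch_idx < start_batch:
                continue

            outputs = model(inputs)
            predictions.append(outputs.cpu())

            # Periodically save the predictions and the position in the dataset
            if (batch_idx + 1) % save_freq == 0:
                torch.save({'predictions': predictions, 'next_batch': batch_idx + 1}, checkpoint_path)
                print(f"Inference checkpoint saved after batch {batch_idx + 1}")

    # Save the final results once every batch has been processed
    torch.save({'predictions': predictions, 'next_batch': len(data_loader)}, checkpoint_path)
    return torch.cat(predictions)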

This implementation:

  • Tracks which samples have already been processed

  • Periodically saves prediction results

  • Resumes from the last saved point if preempted

  • Avoids recomputing predictions for already processed data
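
As a usage illustration (with the same placeholder dimensions as the earlier training sketch), a call might look like this:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

model = nn.Linear(10, 2)                              # placeholder for your trained model
eval_inputs = torch.randn(10000, 10)                  # placeholder unlabeled data
eval_loader = DataLoader(eval_inputs, batch_size=64)  # each batch is a single input tensor

# Checkpoints every 20 batches; rerunning this after a preemption resumes where it left off
predictions = run_inference(model, eval_loader, save_freq=20)
print(predictions.shape)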