Checkpointing Guide: Best Practices for DGX On-Prem Free Tier Jobs
This guide outlines best practices for running jobs on the University at Albany DGX On-Prem cluster. The system operates under two Quality of Service (QoS) tiers: free and paid. While this document focuses on free tier usage, understanding its limitations and implementing proper job management strategies is crucial for successful execution.
TL;DR: This is a guide for using the University at Albany's DGX On-Prem cluster's free tier. Key points:
Free tier limits: please refer to the Service Level Agreement (SLA).
Jobs can be preempted (stopped and requeued) if paid-tier users need resources. When this happens, all progress is lost unless properly saved.
The solution is checkpointing - regularly saving your work's state (like autosave in a document). The guide provides PyTorch code examples for:
Saving/loading model training progress
Tracking training metrics
Managing inference tasks
The main message is: Always implement checkpointing in your jobs, or you risk losing hours/days of work when preemption occurs.
Free Tier Resource Allocation
The free tier provides substantial computing resources with specific limitations such as:
Maximum number of jobs per user
Access to up to a specific number of GPUs and CPUs
Maximum job duration per user
Automatic job requeuing upon preemption
For current free tier limits, please refer to the Service Level Agreement (SLA).
Understanding Job Preemption
In scenarios where compute resources are needed for paid-tier jobs, free-tier jobs may be preempted (temporarily terminated). When preemption occurs, the following happens:
The running job is immediately stopped
The job is automatically requeued in the system
When resources become available again, the job starts fresh from the beginning
Without proper safeguards, this means:
All in-memory data and computational progress is lost
The job restarts from its initial state, potentially wasting hours or days of previous computation
Any intermediate results that weren't explicitly saved are unrecoverable
While the system automatically handles the requeuing process, it's the user's responsibility to implement mechanisms that allow their jobs to intelligently resume from the last saved state. This is where proper checkpointing strategies become crucial - they enable your job to restart from its last saved point rather than beginning anew, effectively transforming a potentially disruptive preemption into a mere pause in execution.
Implementing Checkpointing
Checkpointing is a critical practice for maintaining training progress in a preemptible environment. Similar to periodically saving a document while writing, checkpointing preserves your model's state throughout the training process. This practice is essential in the cluster environment for several reasons:
Recovery from preemption events when paid-tier jobs take priority
Protection against unexpected system errors or hardware issues
Flexibility to resume training from specific progress points
How you implement checkpointing varies significantly depending on your specific use case and computational workload. While checkpointing is useful for many kinds of tasks, it is particularly vital in machine learning workflows, especially during model training, where losing progress could mean days of wasted computation. The implementation details depend heavily on your chosen deep learning framework.
PyTorch, TensorFlow, and other frameworks each provide their own mechanisms and best practices for saving and loading model states. This guide focuses specifically on PyTorch implementations, though the general principles can be adapted for other frameworks. With PyTorch, you can save not only the model's parameters but also the optimizer state, learning rate scheduler, and custom training metrics, ensuring a complete snapshot of your training state.
PyTorch Checkpointing Implementation for Model Training
The most basic PyTorch checkpoint needs to save two essential components:
The model state (weights and biases)
The optimizer state (crucial for proper training resumption)
Here's a simple example showing how to save and load these components:
import os
import torch

# Saving a checkpoint
def save_checkpoint(model, optimizer, epoch, path='checkpoint.pt'):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    torch.save(checkpoint, path)
    print(f"Checkpoint saved at epoch {epoch}")

# Loading a checkpoint
def load_checkpoint(model, optimizer, path='checkpoint.pt'):
    if not os.path.exists(path):
        print("No checkpoint found, starting from scratch")
        return 0  # Starting epoch
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print(f"Checkpoint loaded from epoch {epoch}")
    return epoch
The save_checkpoint function saves the current state, while load_checkpoint restores it when needed. These save and load functions are the building blocks for a robust training process that can handle preemption. By integrating them into your training loop, you create a resilient system that can recover from interruptions. Here is a basic training loop with integrated checkpointing.
def train_model(model, optimizer, criterion, train_loader, num_epochs, checkpoint_freq=5):
    # criterion is the loss function (e.g. nn.CrossEntropyLoss()), passed in explicitly
    # Try to load previous checkpoint
    start_epoch = load_checkpoint(model, optimizer)
    # Resume training from the last saved epoch
    for epoch in range(start_epoch, num_epochs):
        running_loss = 0.0
        # Training loop for one epoch
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            # Zero the gradients
            optimizer.zero_grad()
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Print epoch stats
        print(f'Epoch {epoch} loss: {running_loss / len(train_loader)}')
        # Save checkpoint every checkpoint_freq epochs
        if (epoch + 1) % checkpoint_freq == 0:
            save_checkpoint(model, optimizer, epoch)
    print('Training completed')
These are the steps performed by the train_model function (a usage sketch follows after this list):
When training starts, it tries to load a previous checkpoint using load_checkpoint
If a checkpoint exists, training resumes from that epoch
If no checkpoint exists, training starts from epoch 0
During training, a checkpoint is saved every checkpoint_freq epochs. The default is 5, meaning it saves every 5 epochs; you can adjust this frequency based on your needs
If the job gets preempted, when it restarts:
It loads the last saved checkpoint
It continues training from that epoch
All progress up to the last checkpoint is preserved
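For illustration, here is one hypothetical way to wire these pieces together in a job script. The toy model, synthetic dataset, and hyperparameters below are placeholders, not recommendations for real workloads:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy model and synthetic data, purely for illustration
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

inputs = torch.randn(1000, 10)
targets = torch.randint(0, 2, (1000,))
train_loader = DataLoader(TensorDataset(inputs, targets), batch_size=32, shuffle=True)

# If the job was preempted, this call picks up from checkpoint.pt automatically
train_model(model, optimizer, criterion, train_loader, num_epochs=50, checkpoint_freq=5)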
Advanced PyTorch Checkpointing Features for Model Training
While saving the model and optimizer states covers the basics, you might want to track additional information during training. For example, when conducting research, you'll often need to track your model's loss history to analyze performance and create visualizations for publications. This kind of tracking requires saving additional training metrics. Here's an enhanced version of our checkpoint functions that includes these extra features:
# Enhanced checkpoint saving
def save_checkpoint(model, optimizer, epoch, loss_history, scheduler=None, path='checkpoint.pt'):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_history': loss_history,  # Track loss over time
    }
    # Save scheduler state if it exists
    if scheduler is not None:
        checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    torch.save(checkpoint, path)
    print(f"Checkpoint saved at epoch {epoch}")

# Enhanced checkpoint loading
def load_checkpoint(model, optimizer, scheduler=None, path='checkpoint.pt'):
    if not os.path.exists(path):
        print("No checkpoint found, starting from scratch")
        return 0, []  # Return starting epoch and empty loss history
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Load scheduler state if it exists
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    epoch = checkpoint['epoch']
    loss_history = checkpoint['loss_history']
    print(f"Checkpoint loaded from epoch {epoch}")
    return epoch, loss_history
This enhanced version:
Tracks training loss history, allowing you to plot learning curves even after preemption
Optionally saves and loads the learning rate scheduler state (a component that adjusts the learning rate during training)
Returns both the epoch number and loss history when loading
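As a brief illustration (assuming matplotlib is available in your environment), the saved loss history can be turned into a learning curve after any run, including one that was preempted and resumed; the file name checkpoint.pt matches the default path used above:

import torch
import matplotlib.pyplot as plt

# Load the checkpoint on CPU and plot the recorded per-epoch average loss
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
loss_history = checkpoint['loss_history']

plt.plot(range(len(loss_history)), loss_history)
plt.xlabel('Epoch')
plt.ylabel('Average training loss')
plt.savefig('learning_curve.png')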
Now, let's see how to integrate these enhanced checkpoint functions into our training loop. The main differences from our previous implementation are the initialization and tracking of the loss history, along with passing this information to our checkpoint functions. This allows us to maintain a complete record of our model's performance throughout training, even across multiple preemptions. One possible version of the enhanced training loop is sketched below.
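This sketch assumes the same criterion and train_loader arguments as the earlier train_model, plus an optional learning-rate scheduler:

def train_model(model, optimizer, criterion, train_loader, num_epochs,
                scheduler=None, checkpoint_freq=5):
    # Try to load a previous checkpoint (epoch and loss history)
    start_epoch, loss_history = load_checkpoint(model, optimizer, scheduler)
    for epoch in range(start_epoch, num_epochs):
        running_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Record the average loss for this epoch
        epoch_loss = running_loss / len(train_loader)
        loss_history.append(epoch_loss)
        print(f'Epoch {epoch} loss: {epoch_loss}')
        # Step the scheduler once per epoch, if one is used
        if scheduler is not None:
            scheduler.step()
        # Save checkpoint every checkpoint_freq epochs, including the loss history
        if (epoch + 1) % checkpoint_freq == 0:
            save_checkpoint(model, optimizer, epoch, loss_history, scheduler)
    print('Training completed')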
The key changes from our previous training loop are:
The load_checkpoint function now returns both the epoch number and the loss history
We initialize an empty loss history if starting fresh
We track the average loss for each epoch in the loss_history list
The save_checkpoint function now includes this loss history
This enhanced version ensures that you maintain a complete record of your training progress, even if the job gets preempted multiple times. Keep in mind that this implementation is just a suggestion and can be adapted based on your specific project requirements. While we've focused on PyTorch here, the core concept of checkpointing - saving and loading training states - remains consistent across different deep learning frameworks like TensorFlow, Keras, etc. You can customize these checkpoint functions to track any metrics relevant to your research, such as accuracy, F1 score, precision, recall, or custom evaluation metrics. The key is to ensure that whatever information you need to analyze and validate your model's performance is properly saved and restored during the training process.
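For instance, if you also compute validation metrics each epoch, they can be stored alongside the loss history. The function below is a hypothetical variant, not part of the code above; the accuracy_history parameter stands in for whatever metric lists you maintain in your own training loop:

# Hypothetical sketch: a checkpoint that also stores extra evaluation metrics
def save_checkpoint_with_metrics(model, optimizer, epoch, loss_history,
                                 accuracy_history, path='checkpoint.pt'):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_history': loss_history,
        'accuracy_history': accuracy_history,  # any metric you need for later analysis
    }
    torch.save(checkpoint, path)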
PyTorch Checkpointing for Inference
While checkpointing is commonly associated with model training, it's equally important for inference tasks. When running predictions on extensive datasets in our free-tier environment, your job might get preempted before processing all the data. Without proper checkpointing, you'd have to restart predictions from the beginning, wasting valuable compute time. One approach to saving your inference progress is sketched below.
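This is a minimal sketch, assuming the DataLoader is not shuffled (so batch order is stable across restarts) and yields input batches only; the file name inference_checkpoint.pt and the save frequency are illustrative:

import os
import torch

def run_inference(model, data_loader, checkpoint_path='inference_checkpoint.pt', save_freq=50):
    model.eval()
    # Resume from a previous run if a checkpoint exists
    if os.path.exists(checkpoint_path):
        state = torch.load(checkpoint_path)
        predictions = state['predictions']
        next_batch = state['next_batch']
        print(f"Resuming inference from batch {next_batch}")
    else:
        predictions = []
        next_batch = 0

    with torch.no_grad():
        for batch_idx, inputs in enumerate(data_loader):
            # Skip batches whose predictions were already saved before preemption
            if batch_idx < next_batch:
                continue
            outputs = model(inputs)
            predictions.append(outputs.cpu())
            # Periodically save the results and the position in the dataset
            if (batch_idx + 1) % save_freq == 0:
                torch.save({'predictions': predictions, 'next_batch': batch_idx + 1},
                           checkpoint_path)
                print(f"Inference checkpoint saved after batch {batch_idx}")

    # Save the final, complete set of results before returning them
    torch.save({'predictions': predictions, 'next_batch': len(data_loader)}, checkpoint_path)
    return torch.cat(predictions)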
This implementation:
Tracks which samples have already been processed
Periodically saves prediction results
Resumes from the last saved point if preempted
Avoids recomputing predictions for already processed data