Saving, Loading, and Compiling Models with PyTorch and Running on NVIDIA DGX
Building on the previous tutorial, this guide will show you how to save, load, and compile models using PyTorch to fully leverage the power of the NVIDIA DGX AI Clusters available at UAlbany.
Saving & Loading Models
After training your model and achieving satisfactory results, it's wise to save it. Saving the model allows you to load it later and use it for inference without retraining. If you skip saving, you'll need to retrain the model from scratch every time you want to use it, which is both time-consuming and inefficient. In PyTorch, there are two main ways to save a model.
Saving only the model's state dictionary (the learned parameters).
Saving the entire model, including its architecture.
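Saving just the state dictionary is the more common approach because it is flexible and portable: the file contains only tensors, and you recreate the architecture in code before loading them. Saving the entire model captures both the architecture and the parameters in one file, which is convenient. However, torch.save() pickles the model object, and pickle stores a reference to the model's class rather than the class code itself. The same class definition must therefore be importable when you load the file, so whole-model files can break after refactoring, or when moved to a different environment or PyTorch version. Let's explore both methods in more detail.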
Saving & Loading the Model’s State Dictionary
The recommended way to save and load a model in PyTorch uses the torch.save() and torch.load() functions, along with model.load_state_dict() to copy the saved parameters into a model. When it comes to saving and loading models, these are the three core functions to be familiar with:
torch.save: Saves a serialized object to disk (PyTorch uses Python's pickle module for serialization).
torch.load: Deserializes a saved file back into memory.
load_state_dict: Loads a model's parameter dictionary from a deserialized state_dict.
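To make the state_dict more concrete, the sketch below builds a small hypothetical network (not the tutorial's model) and prints its state_dict. It shows that the dictionary maps submodule paths to tensors and holds persistent buffers, such as BatchNorm running statistics, alongside the learnable parameters.

```python
import torch
import torch.nn as nn

# Hypothetical network used only for illustration: it has learnable
# parameters (Linear and BatchNorm affine weights) and buffers
# (BatchNorm running statistics).
net = nn.Sequential(
    nn.Linear(4, 8),
    nn.BatchNorm1d(8),
    nn.ReLU(),
    nn.Linear(8, 3),
)

state = net.state_dict()  # an OrderedDict mapping names to tensors

for name, tensor in state.items():
    # Keys encode the submodule index/name plus the tensor name.
    print(f"{name:25s} {tuple(tensor.shape)}")
```

The ReLU at index 2 has no parameters or buffers, so no key starts with "2.". Entries such as 1.running_mean are buffers, not parameters, which is why a state_dict carries more than model.parameters() does.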
Feel free to check the official PyTorch tutorial on Saving and Loading Models.
Please refer to the following example to save the model's learned parameters (and any persistent buffers, such as BatchNorm running statistics) to a file (model_state.pt in this case).
# Save Model State
torch.save(model.state_dict(), 'model_state.pt')
Now, to load the saved model’s weights, you first need to create an instance of the model class and then load the saved parameters into it:
# Create Model Instance
model = Multiclass()
# Load the Saved state_dict Into the Model
model.load_state_dict(torch.load('model_state.pt'))
# Set the Model to Evaluation Mode Before Inference
model.eval()
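The snippet above assumes the Multiclass class from the previous tutorial is already defined. For a self-contained sketch, the block below uses a hypothetical stand-in for Multiclass (assumed here to have 4 input features, 8 hidden units, and 3 classes). It also passes map_location="cpu", which matters when weights saved from a DGX GPU are loaded on a machine without one, and weights_only=True, which restricts unpickling to tensors and plain containers (this argument requires PyTorch 1.13 or newer).

```python
import os
import tempfile

import torch
import torch.nn as nn


class Multiclass(nn.Module):
    """Hypothetical stand-in for the previous tutorial's Multiclass model
    (assumed: 4 input features, 8 hidden units, 3 output classes)."""

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.output = nn.Linear(8, 3)

    def forward(self, x):
        return self.output(self.act(self.hidden(x)))


# "Train" on whatever device is available (a DGX GPU, or CPU as a fallback).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trained = Multiclass().to(device)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_state.pt")
    torch.save(trained.state_dict(), path)

    # Rebuild the architecture first, then load the saved tensors into it.
    # map_location="cpu" remaps GPU tensors so the file loads without CUDA.
    restored = Multiclass()
    restored.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
    restored.eval()

# Compare both models on the same CPU input.
x = torch.randn(5, 4)
with torch.no_grad():
    same = torch.allclose(trained.cpu()(x), restored(x))
print(same)  # True
```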
Saving & Loading the Whole Model (Architecture Included)
You can save the entire model in PyTorch, including its architecture, using torch.save(). This method is convenient if you want to save both the model's structure and the learned parameters in a single step. However, it is generally less flexible than saving only the state dictionary. Please refer to the following example to save the model's architecture, along with the parameters, to a file (model.pt in this case).
# Save the Entire Model
torch.save(model, 'model.pt')
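Now, to load the entire saved model, call torch.load() on the file. The model's class definition must be importable at load time. Also, since PyTorch 2.6, torch.load() defaults to weights_only=True, which refuses to unpickle full model objects, so you need to pass weights_only=False. Only do this for files you trust, because unpickling can execute arbitrary code.
# Load the Entire Model
model = torch.load('model.pt', weights_only=False)
# Set the Model to Evaluation Mode Before Inference
model.eval()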
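The sketch below shows the weights_only behavior end to end with a hypothetical TinyNet class, using an in-memory buffer instead of a file. It assumes PyTorch 1.13 or newer (where the weights_only argument exists): the strict load is rejected because TinyNet is not an allowed type, while the explicit opt-out restores a working model.

```python
import io
import pickle

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)

    def forward(self, x):
        return self.fc(x)


model = TinyNet()
buffer = io.BytesIO()
torch.save(model, buffer)  # pickles the whole module object, class reference included

# 1) The restricted unpickler (weights_only=True, the default since PyTorch 2.6)
#    rejects the TinyNet class, so loading a full model this way fails.
buffer.seek(0)
try:
    torch.load(buffer, weights_only=True)
    strict_load_failed = False
except pickle.UnpicklingError:
    strict_load_failed = True

# 2) Opting out works, but only do this for trusted files:
#    unpickling can execute arbitrary code.
buffer.seek(0)
restored = torch.load(buffer, weights_only=False)
restored.eval()

x = torch.randn(2, 4)
with torch.no_grad():
    print(strict_load_failed, torch.equal(model(x), restored(x)))
```

Because pickle stores only a reference to TinyNet, step 2 succeeds here only because the class is defined in the same script; loading the buffer in a process that cannot import TinyNet would fail.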
Compiling Models
In PyTorch, compiling a model transforms it into an optimized version that can run faster, for both training and inference. This is done with torch.compile(), which captures the model's operations as graphs and, with the default inductor backend, generates fused, optimized kernels for the target hardware.
Compiling a model can speed up inference by optimizing computational graphs (for example, fusing operations) and applying low-level optimizations. The first call is slower, because that is when compilation happens.
It can lead to better resource utilization and reduced latency, which is especially beneficial when deploying models in production environments.
It is meant to work out of the box (or with minor changes) with any type of model.
Feel free to check the official PyTorch documentation on Getting Started With torch.compiler.
Because compiled code is optimized for a specific runtime environment, it has to be regenerated whenever the model moves to different hardware or a different software setup. For this reason, torch.compile() does not give you a built-in way to save a model in a precompiled state: saving stores only the model's weights (and, if you save the whole model, its architecture), not the compiled optimizations. Every time you load the model, you need to call torch.compile() again to benefit from the performance improvements.
By combining torch.compile() with DataParallel, you can leverage both optimizations for performance and multi-GPU support. In other words, the compilation step optimizes the model's computation, while DataParallel distributes the workload across multiple GPUs if available. Nonetheless, bear in mind that DataParallel introduces some overhead from splitting each batch, replicating the model on every forward pass, and gathering the results. Therefore, for small models or batches, this overhead might outweigh the benefits of parallelism. For larger multi-GPU workloads, PyTorch's documentation recommends DistributedDataParallel over DataParallel, even on a single node.
When performing inference, gradients are not needed since you're not updating the model's parameters. Disabling gradient calculation saves memory and computation, leading to faster and more efficient inference. This is why we use torch.no_grad(), a PyTorch context manager that disables gradient computation: autograd stops recording operations, so no intermediate results are kept around for a backward pass. Feel free to check the official PyTorch documentation on torch.no_grad.
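The short sketch below, using a hypothetical single-layer model, shows the effect: outside torch.no_grad() the output is attached to an autograd graph, inside it is not. It also shows torch.inference_mode(), a stricter related context manager in PyTorch that additionally marks its outputs as inference tensors, which cannot be used later in autograd.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 3)  # hypothetical tiny model
x = torch.randn(2, 4)

out_train = model(x)  # autograd records the graph: output requires grad

model.eval()
with torch.no_grad():
    out_infer = model(x)  # nothing is recorded inside this block

with torch.inference_mode():
    out_im = model(x)  # like no_grad, plus outputs are flagged as inference tensors

print(out_train.requires_grad, out_infer.requires_grad)  # True False
print(out_infer.grad_fn is None, out_im.is_inference())  # True True
```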
Although compiling a model might be viewed as an optional optimization step, it can be a required prerequisite depending on the target environment. As mentioned earlier, the compilation process is closely tied to the runtime environment. For instance, when running inference on the IBM AIU Cluster, you can only leverage its hardware by compiling the model with a custom PyTorch backend (sendnn), which enables access to the IBM AIU.
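The sendnn backend ships with IBM's software stack and is not shown here. To illustrate what a custom backend is, the sketch below uses PyTorch's documented custom-backend interface: torch.compile() accepts any callable that receives an FX GraphModule plus example inputs and returns a callable. The inspecting_backend function is a hypothetical toy that only records the graph and runs it unchanged, whereas a real hardware backend would lower the graph to its device.

```python
import torch
import torch.nn as nn
from torch import fx

captured_graphs = []


def inspecting_backend(gm: fx.GraphModule, example_inputs):
    """Toy torch.compile backend (illustrative only, NOT sendnn).

    Dynamo passes the captured FX graph and example inputs; the backend
    must return a callable that executes the graph.
    """
    captured_graphs.append(gm)
    return gm.forward  # run the captured graph as-is


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3)).eval()
compiled = torch.compile(model, backend=inspecting_backend)

x = torch.randn(2, 4)
with torch.no_grad():
    out = compiled(x)  # first call hands the graph to the backend

print(len(captured_graphs))     # number of graphs Dynamo handed to the backend
print(captured_graphs[0].code)  # Python source generated for the captured graph
print(torch.allclose(out, model(x)))
```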