How to Load a Partial Pretrained PyTorch Model


Sometimes you may need to load only specific parts of a pretrained PyTorch model rather than the entire model. This is useful for:

  • Fine-tuning specific layers of a model
  • Transfer learning, where you use pre-trained features from a model for a new task
  • Reducing memory footprint when working with large models

Loading Specific Layers

1. State Dict Method

The simplest method is to load the pretrained model's state dict and copy over only the entries that belong to the layers you want to keep.

Code:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3)
        self.fc = nn.Linear(128 * 4 * 4, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(-1, 128 * 4 * 4)
        x = self.fc(x)
        return x

# Load the pretrained weights (the file is assumed to contain a flat state dict
# with keys such as "conv1.weight" and "conv1.bias")
pretrained_model = torch.load("pretrained_model.pth")

# Create a new model
model = MyModel()

# Copy only the convolutional layers by filtering their keys out of the state dict
conv1_state = {k[len("conv1."):]: v for k, v in pretrained_model.items() if k.startswith("conv1.")}
conv2_state = {k[len("conv2."):]: v for k, v in pretrained_model.items() if k.startswith("conv2.")}
model.conv1.load_state_dict(conv1_state)
model.conv2.load_state_dict(conv2_state)

# Initialize the fully connected layer randomly
torch.nn.init.xavier_uniform_(model.fc.weight)
torch.nn.init.zeros_(model.fc.bias)

# The model can now be trained with the pretrained convolutional layers
# and a newly initialized FC layer

Output:

No output is printed, but the model is initialized with the pretrained convolutional layers.
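An alternative to loading each layer separately is to filter the full state dict and load it with strict=False, which tolerates the keys you leave out. A minimal sketch, assuming the same flat state dict as above:

# Keep only the convolutional parameters from the checkpoint
pretrained_model = torch.load("pretrained_model.pth")
conv_state = {k: v for k, v in pretrained_model.items() if k.startswith(("conv1.", "conv2."))}

# strict=False skips the keys that are not provided (here, the fc layer)
model = MyModel()
missing, unexpected = model.load_state_dict(conv_state, strict=False)
print(missing)     # ['fc.weight', 'fc.bias']
print(unexpected)  # []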

2. Using Named Modules

You can access the parameters of specific layers in the pretrained state dict by name and copy them directly into the corresponding layers of the new model.

Code:

import torch
import torch.nn as nn

# Load the pretrained state dict
pretrained_model = torch.load("pretrained_model.pth")

# Access the parameters of the convolutional layers by name
conv1_weights = pretrained_model["conv1.weight"]
conv1_bias = pretrained_model["conv1.bias"]
conv2_weights = pretrained_model["conv2.weight"]
conv2_bias = pretrained_model["conv2.bias"]

# Create a new model (MyModel as defined in the previous example)
model = MyModel()

# Copy the weights directly into the corresponding layers
with torch.no_grad():
    model.conv1.weight.copy_(conv1_weights)
    model.conv1.bias.copy_(conv1_bias)
    model.conv2.weight.copy_(conv2_weights)
    model.conv2.bias.copy_(conv2_bias)

# Initialize the fully connected layer randomly
torch.nn.init.xavier_uniform_(model.fc.weight)
torch.nn.init.zeros_(model.fc.bias)

# The model can now be trained with the pretrained convolutional layers
# and a newly initialized FC layer

Output:

No output is printed, but the model is initialized with the pretrained convolutional layers.
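If the goal is fine-tuning only the new layer (one of the use cases listed at the top), you can also freeze the pretrained convolutional layers so that only the fully connected layer is updated. A minimal sketch:

# Freeze the pretrained convolutional layers
for param in model.conv1.parameters():
    param.requires_grad = False
for param in model.conv2.parameters():
    param.requires_grad = False

# Give the optimizer only the parameters that should still be trained
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.01
)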

Loading Submodules

If you need to load a whole submodule (such as an encoder or decoder) from a pretrained model, follow these steps:

1. Define a Submodule

Create a class for the submodule that you want to load.

Code:

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(128 * 4 * 4, 10)

    def forward(self, x):
        x = x.view(-1, 128 * 4 * 4)
        x = self.fc(x)
        return x

Output:

No output; the submodule classes are defined.
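The loading step below assumes that "pretrained_model.pth" stores one state dict per submodule. A hypothetical sketch of how such a checkpoint could be saved (trained_encoder and trained_decoder are placeholder names, not part of the original example):

# Hypothetical: stand-ins for submodules that have already been trained
trained_encoder = Encoder()
trained_decoder = Decoder()

# Save one state dict per submodule under its own key
torch.save(
    {"encoder": trained_encoder.state_dict(), "decoder": trained_decoder.state_dict()},
    "pretrained_model.pth",
)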

2. Load Submodule from Pretrained Model

Load the pretrained model and extract the desired submodule.

Code:

# Load the pretrained checkpoint (assumed to hold one state dict per submodule,
# e.g. saved as {"encoder": ..., "decoder": ...} as sketched above)
pretrained_model = torch.load("pretrained_model.pth")

# Create instances of the submodules
encoder = Encoder()
decoder = Decoder()

# Load the pretrained weights into the submodules
encoder.load_state_dict(pretrained_model["encoder"])
decoder.load_state_dict(pretrained_model["decoder"])

Output:

No output; the submodules are initialized with the pretrained weights.
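As a quick sanity check that the weights were actually copied, you can compare one of the loaded parameters against the checkpoint tensor:

# Should print True if the encoder weights were loaded correctly
print(torch.equal(encoder.conv1.weight, pretrained_model["encoder"]["conv1.weight"]))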

3. Utilize the Loaded Submodules

Now you can use the loaded submodules as part of your main model.

Code:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Instantiate the full model
model = MyModel()

# Swap in the pretrained submodules loaded in the previous step;
# without this, the model would keep its randomly initialized encoder and decoder
model.encoder = encoder
model.decoder = decoder

# The model now uses the pretrained encoder and decoder

Output:

No output; the model is initialized with the pretrained encoder and decoder.
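As a usage example, a dummy forward pass confirms that the pieces fit together; the 8x8 input size is an assumption chosen to match the 128 * 4 * 4 flattened size used by the decoder:

# Dummy batch of two 8x8 RGB images: two 3x3 convolutions reduce 8x8 to 4x4,
# which matches the decoder's 128 * 4 * 4 input
x = torch.randn(2, 3, 8, 8)
logits = model(x)
print(logits.shape)  # torch.Size([2, 10])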

Conclusion

This article presented several methods to load partial pretrained models in PyTorch. Choose the method that best fits your specific needs based on the layers or submodules you need to load.
