hac81acnh

Reputation: 82

How to fine-tune a pruned model in PyTorch?

I'd like to fine-tune a pruned model. I tried the following steps:

  1. train a model
  2. save the trained model # (trained_model.model)
  3. load the trained model and prune it
  4. save the pruned model # (pruned_model.model)
  5. load the pruned model and train it

Steps (1) through (4) work, but after (5) the weights of the trained model no longer contain zeros (i.e., it is not a pruned model anymore).
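
For example, this is the kind of check I mean (counting zero entries in the first layer's weight):

# how many weights in fc[0] are exactly zero after training?
print((net.fc[0].weight == 0).sum().item(), 'zeros out of', net.fc[0].weight.numel())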

Is there a way to do pruning-aware training?

Here is my code, which runs on Google Colab. I guess I have to do something at step (4) or (5), but I couldn't find the way...

Definitions

# import libraries
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.utils.prune as prune

# network
class Net(nn.Module):    
    def __init__(self):
        super(Net, self).__init__()
        
        self.fc = nn.Sequential(
            nn.Linear(10, 12),  nn.BatchNorm1d(12),  nn.ReLU(),   nn.Dropout(p=0.1), 
            nn.Linear(12, 12),  nn.BatchNorm1d(12),  nn.ReLU(),   nn.Dropout(p=0.1), 
            nn.Linear(12, 10), nn.Sigmoid()
        )
    
    def forward(self, input):
        x = self.fc(input)
        return x

net = Net()

# data
x_train = np.random.uniform(-1, 1, [100, 10])
y_train = 2 * x_train
x_train = torch.from_numpy(x_train.astype(np.float32)).float()
y_train = torch.from_numpy(y_train.astype(np.float32)).float()

x_test = np.random.uniform(-1, 1, [100, 10])
y_test = 2 * x_test
x_test = torch.from_numpy(x_test.astype(np.float32)).float()
y_test = torch.from_numpy(y_test.astype(np.float32)).float()

train_set = torch.utils.data.TensorDataset(x_train, y_train)
test_set = torch.utils.data.TensorDataset(x_test, y_test)

train_loader = DataLoader(dataset=train_set, batch_size=50, shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=50, shuffle=True)

For (1)

epochs = 100
train_len = len(x_train)    
test_len = len(x_test)

net_optimizer = torch.optim.Adam(net.parameters(), lr=0.01, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
criterion = nn.MSELoss()

for epoch in range(epochs):
    train_loss = 0.0
    val_loss = 0.0

    # Training the model
    net.train()
    counter = 0

    for data in train_loader:
        inputs, labels = data
        net_optimizer.zero_grad()
        outputs = net(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        net_optimizer.step()
        train_loss += loss.item() * inputs.size(0)
    
    # Evaluating the model
    net.eval()
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data            
            output = net(inputs)            
            valloss = criterion(output, labels)
            val_loss += valloss.item() * inputs.size(0)

    train_loss = train_loss/train_len
    valid_loss = val_loss/test_len
    print('[%d] Training Loss: %.6f, Validation Loss: %.6f'  % (epoch + 1, train_loss, valid_loss))

For (2)

torch.save(net.state_dict(), 'trained_model.model')

For (3)

trained_model = Net()
trained_model.load_state_dict(torch.load('trained_model.model'))
tmp_module = trained_model.fc[0]
prune.ln_structured(tmp_module, name="weight", amount=0.5, n = 1, dim = 1)
prune.remove(tmp_module, "weight")

For (4)

torch.save(trained_model.state_dict(), 'pruned_model.model')

For (5)

del net
net = Net()
net.load_state_dict(torch.load('pruned_model.model'))

###
then do (1)
###

Upvotes: 1

Views: 2539

Answers (1)

user3800305

Reputation: 21

In step (3), prune.ln_structured(tmp_module, name="weight", amount=0.5, n=1, dim=1) reparametrizes the module: it registers a new parameter 'weight_orig' and a buffer 'weight_mask', and 'weight' becomes the element-wise product of the two, recomputed on every forward pass.

prune.remove(tmp_module, "weight") makes the pruning permanent: it deletes 'weight_orig' and 'weight_mask' and re-registers 'weight' as an ordinary parameter whose values are their product. After this step there is no mask associated with the weights anymore, so when you fine-tune (re-train) the model, all the weights, including the ones that were pruned, get updated. That is why you end up with non-zero values everywhere.
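
You can verify this reparametrization yourself (a minimal check, reusing the names from the question):

import torch
import torch.nn.utils.prune as prune

tmp_module = trained_model.fc[0]
prune.ln_structured(tmp_module, name="weight", amount=0.5, n=1, dim=1)

# 'weight' is now an ordinary attribute, recomputed by a forward pre-hook
print([name for name, _ in tmp_module.named_parameters()])  # ['bias', 'weight_orig']
print([name for name, _ in tmp_module.named_buffers()])     # ['weight_mask']
print(torch.equal(tmp_module.weight,
                  tmp_module.weight_orig * tmp_module.weight_mask))  # True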

The solution is to call prune.remove(tmp_module, "weight") only after fine-tuning. The code, starting from the pruning step, would then be:

prune.ln_structured(tmp_module, name="weight", amount=0.5, n = 1, dim = 1)

then do (1)

prune.remove(tmp_module, "weight")
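
Putting it together, a minimal sketch of the corrected order (reusing Net, train_loader, epochs, and the file names from the question; the fine-tuning loop is abbreviated):

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# (3) load the trained model and prune it, but do NOT call prune.remove yet
trained_model = Net()
trained_model.load_state_dict(torch.load('trained_model.model'))
tmp_module = trained_model.fc[0]
prune.ln_structured(tmp_module, name="weight", amount=0.5, n=1, dim=1)

# (5) fine-tune with the mask in place: weight = weight_orig * weight_mask
# is recomputed on every forward pass, so pruned entries stay zero
optimizer = torch.optim.Adam(trained_model.parameters(), lr=0.01)
criterion = nn.MSELoss()
trained_model.train()
for epoch in range(epochs):
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(trained_model(inputs), labels)
        loss.backward()
        optimizer.step()

# only now make the pruning permanent and save
prune.remove(tmp_module, "weight")
torch.save(trained_model.state_dict(), 'pruned_model.model')
print((trained_model.fc[0].weight == 0).any())  # tensor(True): the zeros survive

One caveat: if you save the model between pruning and fine-tuning, the state dict contains 'fc.0.weight_orig' and 'fc.0.weight_mask' instead of 'fc.0.weight', so you would have to apply the same prune.ln_structured call to a fresh Net() before calling load_state_dict.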

Upvotes: 2
