Reputation: 82
I'd like to fine-tune a pruned model. I tried the following steps: (1) train a model, (2) save it, (3) prune it, (4) save the pruned model, and (5) load the pruned model and fine-tune it.
Steps (1) through (4) work, but after (5) the weights of the trained model no longer contain zeros (i.e., it is not a pruned model anymore). I want to do pruning-aware training.
Here is my code, running on Google Colab. I guess I have to do something differently at (4) or (5), but I couldn't find the way...
Definitions
# import libraries
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.utils.prune as prune
# network
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(10, 12), nn.BatchNorm1d(12), nn.ReLU(), nn.Dropout(p=0.1),
            nn.Linear(12, 12), nn.BatchNorm1d(12), nn.ReLU(), nn.Dropout(p=0.1),
            nn.Linear(12, 10), nn.Sigmoid()
        )

    def forward(self, input):
        x = self.fc(input)
        return x
net = Net()
# data
x_train = np.random.uniform(-1, 1, [100, 10])
y_train = 2 * x_train
x_train = torch.from_numpy(x_train.astype(np.float32))
y_train = torch.from_numpy(y_train.astype(np.float32))
x_test = np.random.uniform(-1, 1, [100, 10])
y_test = 2 * x_test
x_test = torch.from_numpy(x_test.astype(np.float32))
y_test = torch.from_numpy(y_test.astype(np.float32))
train_set = torch.utils.data.TensorDataset(x_train, y_train)
test_set = torch.utils.data.TensorDataset(x_test, y_test)
train_loader = DataLoader(dataset=train_set, batch_size=50, shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=50, shuffle=False)
For (1)
epochs = 100
train_len = len(x_train)
test_len = len(x_test)
net_optimizer = torch.optim.Adam(net.parameters(), lr=0.01, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
criterion = nn.MSELoss()

for epoch in range(epochs):
    train_loss = 0.0
    val_loss = 0.0

    # Training the model
    net.train()
    for data in train_loader:
        inputs, labels = data
        net_optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        net_optimizer.step()
        train_loss += loss.item() * inputs.size(0)

    # Evaluating the model
    net.eval()
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            output = net(inputs)
            valloss = criterion(output, labels)
            val_loss += valloss.item() * inputs.size(0)

    train_loss = train_loss / train_len
    valid_loss = val_loss / test_len
    print('[%d] Training Loss: %.6f, Validation Loss: %.6f' % (epoch + 1, train_loss, valid_loss))
For (2)
torch.save(net.state_dict(), 'trained_model.model')
For (3)
trained_model = Net()
trained_model.load_state_dict(torch.load('trained_model.model'))
tmp_module = trained_model.fc[0]
prune.ln_structured(tmp_module, name="weight", amount=0.5, n=1, dim=1)
prune.remove(tmp_module, "weight")
For (4)
torch.save(trained_model.state_dict(), 'pruned_model.model')
For (5)
del net
net = Net()
net.load_state_dict(torch.load('pruned_model.model'))
###
then do (1)
###
Upvotes: 1
Views: 2539
Reputation: 21
In step (3),
prune.ln_structured(tmp_module, name="weight", amount=0.5, n=1, dim=1)
introduces two new tensors on the module: a parameter 'weight_orig' and a buffer 'weight_mask'. The elementwise product of these two is recomputed as 'weight' on every forward pass.
prune.remove(tmp_module, "weight")
removes both of them and replaces them with a single plain parameter 'weight' whose values are the product of 'weight_orig' and 'weight_mask'. Essentially, after this step there is no longer any mask associated with the weights. So when you fine-tune (or re-train) the model, all the weights, including the ones that were pruned, get updated, giving non-zero values everywhere.
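To see this concretely, here is a minimal sketch (a single nn.Linear stands in for trained_model.fc[0]) that inspects the module before and after prune.remove; the printed names are the tensors that torch.nn.utils.prune registers:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# a single linear layer stands in for trained_model.fc[0]
m = nn.Linear(10, 12)
prune.ln_structured(m, name="weight", amount=0.5, n=1, dim=1)

# while pruning is active, 'weight' is an ordinary attribute that a forward
# pre-hook recomputes as weight_orig * weight_mask on every forward pass
print([name for name, _ in m.named_parameters()])  # ['bias', 'weight_orig']
print([name for name, _ in m.named_buffers()])     # ['weight_mask']
print((m.weight == 0).sum())                       # half of the input columns are zero

prune.remove(m, "weight")

# after remove(), 'weight' is a plain parameter again and the mask is gone,
# so nothing keeps the pruned entries at zero during subsequent training
print([name for name, _ in m.named_parameters()])  # ['bias', 'weight']
print([name for name, _ in m.named_buffers()])     # []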
The solution is to call prune.remove(tmp_module, "weight") only after fine-tuning, so that the mask stays attached (and keeps zeroing the pruned weights) while you train. The code after the initial training would then be:
prune.ln_structured(tmp_module, name="weight", amount=0.5, n=1, dim=1)
then do (1)
prune.remove(tmp_module, "weight")
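Putting it all together with the code from the question, a sketch of the corrected steps (3) to (5) could look like the following (the file name 'pruned_finetuned_model.model' is just illustrative; Net, epochs, and train_loader are taken from the question):

trained_model = Net()
trained_model.load_state_dict(torch.load('trained_model.model'))
tmp_module = trained_model.fc[0]
prune.ln_structured(tmp_module, name="weight", amount=0.5, n=1, dim=1)

# fine-tune WITH the mask still attached: the forward pre-hook recomputes
# weight = weight_orig * weight_mask, so the pruned entries of the effective
# weight stay at zero no matter how weight_orig is updated
optimizer = torch.optim.Adam(trained_model.parameters(), lr=0.01)
criterion = nn.MSELoss()
trained_model.train()
for epoch in range(epochs):
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(trained_model(inputs), labels)
        loss.backward()
        optimizer.step()

# only now bake the mask into a plain 'weight' parameter and save
prune.remove(tmp_module, "weight")
torch.save(trained_model.state_dict(), 'pruned_finetuned_model.model')

# sanity check: the saved model still contains the pruned zeros
assert (trained_model.fc[0].weight == 0).any()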
Upvotes: 2