Reputation: 3301
I have two sets of tensors, A and B, that I want to use in place of all the weights of a pre-trained model. In other words, I want to reuse the forward computation of the pre-trained model, but not its weights.
I want to set the weights of the model to W = A + B, where A is a fixed tensor (not trainable) and B is a trainable parameter. So, in the end, my aim is to train B within the structure of the pre-trained model.
This is my attempt:
import copy

import torch.nn as nn

class Net(nn.Module):
    def __init__(self, pre_model, B):
        super(Net, self).__init__()
        self.B = B
        self.pre_model = copy.deepcopy(pre_model)
        for params in self.pre_model.parameters():
            params.requires_grad = False

    def forward(self, x, A):
        for i, params in enumerate(self.pre_model.parameters()):
            params.copy_(A[i].detach().clone())  # detached because I don't need to train A
            params.add_(self.B[i])               # I need to train B so no detach
            params.retain_grad()
        x = self.pre_model(x)
        return x
And this is how I construct A and B:
b = []
A = []
for params in list(pre_model.parameters()):
    A.append(torch.rand_like(params))
    b_temp = nn.Parameter(torch.rand_like(params))
    b.append(b_temp.detach().clone())
B = nn.ParameterList(b)
I checked during training, and B does get trained. But the problem is that the training keeps getting slower with every iteration:
Epoch 1:
24%|██▍ | 47/196 [00:05<00:23, 6.44it/s]
57%|█████▋ | 111/196 [00:18<00:19, 4.28it/s]
96%|█████████▋| 189/196 [00:41<00:02, 2.90it/s]
Epoch 2:
6%|▌ | 11/196 [00:04<01:14, 2.50it/s]
I think I have detached everything correctly, so I am not sure why this happens.
UPDATED:
Credits to ptrblck from the PyTorch Forum: you can run the minimal example code in Google Colab here, or use the code below for the main loop. You will see that the training iterations keep getting slower and slower.
import time

import torch
import torch.nn as nn
from torchvision import models
from tqdm import tqdm

device = 'cuda'
pre_model = models.resnet18().to(device)

b = []
A = []
for params in list(pre_model.parameters()):
    A.append(torch.rand_like(params))
    b_temp = nn.Parameter(torch.rand_like(params))
    b.append(b_temp.detach().clone())
B = nn.ParameterList(b)

modelwithAB = Net(pre_model, B)
optimizer = torch.optim.Adam(modelwithAB.parameters(), lr=1e-3)
image = torch.randn(2, 3, 224, 224).to(device)
print(torch.cuda.memory_allocated() / 1024**2)

for i in tqdm(range(300)):
    optimizer.zero_grad()
    out = modelwithAB(image, A)
    start = time.time()
    out.mean().backward()
    torch.cuda.synchronize()
    optimizer.step()
    if i % 40 == 0:
        print("-", torch.cuda.memory_allocated() / 1024**2, "-", time.time() - start)
Upvotes: 4
Views: 533
Reputation: 3977
The problem is the backward pass.
From the moment you do params.add_(self.B[i]), the params of pre_model have params.requires_grad = True again.
What happens is that params depends on B, which in the next iteration again depends on params.
That means you get AddBackward0 nodes and, from the second iteration on, also CopyBackwards nodes, and they all get chained together.
With every forward pass the computational graph grows larger and larger, which causes the slowdown of the backward pass you are experiencing.
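To see this chaining in isolation, here is a minimal sketch (my own illustration, not part of the original answer; the names w, a, b and the helper count_nodes are invented for the example). A frozen tensor that is updated in place with a trainable parameter accumulates autograd history across iterations, just like the pre_model weights in the question:

import torch
import torch.nn as nn

def count_nodes(fn, seen=None):
    # Count the autograd nodes reachable from fn via next_functions.
    seen = set() if seen is None else seen
    if fn is None or fn in seen:
        return 0
    seen.add(fn)
    return 1 + sum(count_nodes(nxt, seen) for nxt, _ in fn.next_functions)

w = torch.randn(3)                # stands in for a frozen pre_model weight
a = torch.randn(3)                # stands in for the fixed tensor A[i]
b = nn.Parameter(torch.randn(3))  # stands in for the trainable B[i]
x = torch.randn(3)

for step in range(5):
    w.copy_(a)   # from the 2nd step on, records CopyBackwards chained to the old graph
    w.add_(b)    # records AddBackward0; w now requires grad
    loss = (w * x).sum()
    loss.backward()
    print(step, count_nodes(loss.grad_fn))  # the node count grows every step

The graph keeps the whole copy_/add_ history of w, so every backward pass has more nodes to walk than the previous one.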
Upvotes: 1
Reputation: 1690
ptrblck: Based on your code snippet you are detaching A, which is the fixed tensor, while you are adding B to params, potentially including its entire computation graph. Could you double check this, please?
Well, let's investigate! This code is the main suspect:
def forward(self, x, A):
    for i, params in enumerate(self.pre_model.parameters()):
        params.copy_(A[i].detach().clone())  # detached because I don't need to train A
        params.add_(self.B[i])               # I need to train B so no detach
        params.retain_grad()
    x = self.pre_model(x)
    return x
We will vary the number of iterations in your for i in tqdm(range(300)) loop while watching the graph size.
Check model body with torchinfo.summary
!pip install torchinfo
from torchinfo import summary
...
# in Net.forward, make A optional so summary() can call the model without passing A
def forward(self, x, A=None):
    if A is None:
        A = self.B
    ...
summary(modelwithAB, input_size=image.shape, depth=10)
With for i in tqdm(range(2)):
===============================================================================================
Layer (type:depth-idx) Output Shape Param #
===============================================================================================
Net [2, 1000] 11,689,512
├─ResNet: 1-1 [2, 1000] --
│ └─Conv2d: 2-1 [2, 64, 112, 112] 9,408
│ └─BatchNorm2d: 2-2 [2, 64, 112, 112] 128
│ └─ReLU: 2-3 [2, 64, 112, 112] --
...
===============================================================================================
Total params: 23,379,024
Trainable params: 23,379,024
Non-trainable params: 0
Total mult-adds (G): 3.63
===============================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 79.49
Params size (MB): 46.76
Estimated Total Size (MB): 127.46
===============================================================================================
With for i in tqdm(range(30)):
===============================================================================================
Total params: 23,379,024
Trainable params: 23,379,024
Non-trainable params: 0
Total mult-adds (G): 3.63
===============================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 79.49
Params size (MB): 46.76
Estimated Total Size (MB): 127.46
===============================================================================================
Well, everything looks all right here, so we need to go deeper. Now let's check the gradient computation graph: we will count all the backward nodes of each kind.
xxx = dict()

def add_nodes(var, consumer=None, grad=None):
    # Recursively walk the autograd graph from var and tally backward node types;
    # the absolute counts are rough, but their growth across iterations is what matters.
    if hasattr(var, 'next_functions'):
        try:
            grads = var(grad)
            grads = grads if isinstance(grads, tuple) else [grads]
            if not hasattr(grads, '__iter__'):
                grads = [grads]
        except:
            grads = map(lambda x: None, var.next_functions)
        for i, (u, grad) in enumerate(zip(var.next_functions, grads)):
            # print(i, type(var).__name__)
            xxx[type(var).__name__] = xxx.get(type(var).__name__, 0) + 1
            for uu in u:
                add_nodes(uu, var, grad)

add_nodes(modelwithAB(image, A).grad_fn)
print(xxx)
1:
{'AddmmBackward0': 3, 'AddBackward0': 10214, 'CopyBackwards': 9704, 'ReshapeAliasBackward0': 1, 'MeanBackward1': 1, 'ReluBackward0': 766, 'CudnnBatchNormBackward0': 2424, 'ConvolutionBackward0': 2424, 'MaxPool2DWithIndicesBackward0': 256, 'TBackward0': 1}
2:
{'AddmmBackward0': 3, 'AddBackward0': 15066, 'CopyBackwards': 14556, 'ReshapeAliasBackward0': 1, 'MeanBackward1': 1, 'ReluBackward0': 766, 'CudnnBatchNormBackward0': 2424, 'ConvolutionBackward0': 2424, 'MaxPool2DWithIndicesBackward0': 256, 'TBackward0': 1}
3:
{'AddmmBackward0': 3, 'AddBackward0': 19918, 'CopyBackwards': 19408, 'ReshapeAliasBackward0': 1, 'MeanBackward1': 1, 'ReluBackward0': 766, 'CudnnBatchNormBackward0': 2424, 'ConvolutionBackward0': 2424, 'MaxPool2DWithIndicesBackward0': 256, 'TBackward0': 1}
30:
{'AddmmBackward0': 3, 'AddBackward0': 150922, 'CopyBackwards': 150412, 'ReshapeAliasBackward0': 1, 'MeanBackward1': 1, 'ReluBackward0': 766, 'CudnnBatchNormBackward0': 2424, 'ConvolutionBackward0': 2424, 'MaxPool2DWithIndicesBackward0': 256, 'TBackward0': 1}
Bingo. The count of AddBackward0 and CopyBackwards nodes rises every training step. So, whatever you expected to achieve, you can't do it with this kind of parameter manipulation. I can't suggest a fix because, like ptrblck, I am not sure what you are up to. Why aren't you happy with the standard approach of training a head on top of a fixed, pre-trained backbone?
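For reference, a minimal sketch of that standard approach (my own illustration, not part of the original answer; it assumes a recent torchvision for the weights= argument, and the 10-class head is arbitrary):

import torch
import torch.nn as nn
from torchvision import models

# Freeze the pre-trained backbone and train only a new classification head.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for p in model.parameters():
    p.requires_grad = False                      # backbone stays fixed

model.fc = nn.Linear(model.fc.in_features, 10)   # new trainable head (10 classes is arbitrary)

optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

image = torch.randn(2, 3, 224, 224)
target = torch.randint(0, 10, (2,))

out = model(image)
loss = nn.functional.cross_entropy(out, target)
loss.backward()                                  # gradients only reach model.fc
optimizer.step()

Here the computation graph is rebuilt from scratch on every forward pass, so the iteration time stays constant instead of growing.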
Upvotes: 2