Reputation: 4317
In this example, I wish the z_proto could be global across the different GPUs. However, in data parallel mode, it gets split across the GPUs as well. How can I solve this problem? Thank you.
import torch
import torch.nn as nn
import torch.nn.functional as F

# euclidean_dist(a, b) is defined elsewhere: pairwise distances between
# query embeddings and class prototypes, as in prototypical networks.

class SequencePrototypeTokenClassification(nn.Module):
    def __init__(self, seq_model, label_num):
        super(SequencePrototypeTokenClassification, self).__init__()
        self.seq_model = seq_model
        self.label_num = label_num

    def forward(self, input_ids, token_type_ids, attention_mask, labels, z_proto, n_query, target_inds):
        # Encode the tokens; z has shape (batch, seq_len, hidden)
        z, _ = self.seq_model(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        z_dim = z.size(-1)
        zq = z.squeeze().view(-1, z_dim)
        # Distance from every query embedding to every class prototype
        dists = euclidean_dist(zq, z_proto)
        log_p_y = F.log_softmax(-dists, dim=1).view(-1, self.label_num)
        loss_val = -log_p_y.gather(1, target_inds).squeeze().view(-1).mean()
        _, y_hat = log_p_y.max(1)
        return loss_val, y_hat
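For context, here is a toy sketch (not my real training code) of the behaviour I mean: once the module is wrapped in nn.DataParallel, every tensor passed to forward(), including z_proto, gets scattered along dim 0 across the GPUs.

import torch
import torch.nn as nn

class ShowShapes(nn.Module):
    def forward(self, x, z_proto):
        # Print what each replica actually receives
        print(x.device, x.shape, z_proto.shape)
        return x.sum(dim=1) + z_proto.sum()

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(ShowShapes()).cuda()
    x = torch.randn(8, 16).cuda()        # batch of query embeddings
    z_proto = torch.randn(4, 16).cuda()  # class prototypes: also get split
    model(x, z_proto)                    # with 2 GPUs each replica sees only 2 prototypes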
Upvotes: 1
Views: 1116
Reputation: 4317
It turns out that DataParallel only replicates the nn.Parameters of the nn.Module. So I randomly initialized an nn.Parameter named z_proto in the module and copied the value of the tensor z_proto into that parameter. The parameter is then replicated across the 4 GPUs.
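Roughly, a sketch of what that looks like (the shapes, the requires_grad=False flag, and the set_prototypes helper are illustrative assumptions, not my exact code):

import torch
import torch.nn as nn

class SequencePrototypeTokenClassification(nn.Module):
    def __init__(self, seq_model, label_num, z_dim):
        super(SequencePrototypeTokenClassification, self).__init__()
        self.seq_model = seq_model
        self.label_num = label_num
        # Randomly initialised placeholder; nn.Parameters are replicated to
        # every GPU by DataParallel instead of being scattered like inputs.
        self.z_proto = nn.Parameter(torch.randn(label_num, z_dim), requires_grad=False)

    def set_prototypes(self, z_proto):
        # Copy the externally computed prototypes into the parameter so each
        # replica sees the full (label_num, z_dim) matrix in forward().
        with torch.no_grad():
            self.z_proto.copy_(z_proto)

When the model is wrapped in nn.DataParallel, the copy should go through model.module.set_prototypes(...) so it updates the underlying module before the replicas are created on the next call.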
Upvotes: 0
Reputation: 2395
Based on your code above, z_proto seems to be one of the arguments of the forward function rather than part of the model. Storing it instead as a tensor on the model, which lives on the main GPU, would let it keep the same value across GPUs.
Based on the documentation, DataParallel splits every input to the forward pass across the GPUs. You can circumvent this by storing z_proto as a class variable inside the model object itself, and updating its value before calling the forward function if it is not static.
class SequencePrototypeTokenClassification(nn.Module):
    def __init__(self, seq_model, label_num):
        ...
        self.z_proto = None
        ...

...
# Training loop
...
model.z_proto = value
model.forward()
...
Upvotes: 1