groupstudent

Reputation: 4317

In PyTorch data parallel mode, how do I use a global tensor?

In this example, I would like z_proto to be shared globally across the different GPUs. However, in data parallel mode, it gets split across the GPUs as well. How can I solve this problem? Thank you.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SequencePrototypeTokenClassification(nn.Module):
    def __init__(self, seq_model, label_num):
        super(SequencePrototypeTokenClassification, self).__init__()
        self.seq_model = seq_model
        self.label_num = label_num

    def forward(self, input_ids, token_type_ids, attention_mask, labels, z_proto, n_query, target_inds):
        # Encode the batch; z has shape (batch, seq_len, hidden_dim)
        z, _ = self.seq_model(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        z_dim = z.size(-1)
        zq = z.squeeze().view(-1, z_dim)
        # euclidean_dist (defined elsewhere) computes pairwise distances
        # between the query embeddings and the class prototypes
        dists = euclidean_dist(zq, z_proto)
        log_p_y = F.log_softmax(-dists, dim=1).view(-1, self.label_num)
        loss_val = -log_p_y.gather(1, target_inds).squeeze().view(-1).mean()
        _, y_hat = log_p_y.max(1)

        return loss_val, y_hat
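
For context, nn.DataParallel chunks every tensor argument of forward along dim 0 and sends a different chunk to each GPU, which is why z_proto gets split. A hypothetical sketch of the call that triggers this, reusing the variables from the setup above:

import torch.nn as nn

model = nn.DataParallel(SequencePrototypeTokenClassification(seq_model, label_num)).cuda()
# Every tensor passed to forward, including z_proto, is chunked along
# dim 0, and each GPU receives only its own chunk.
loss, y_hat = model(input_ids, token_type_ids, attention_mask,
                    labels, z_proto, n_query, target_inds)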

Upvotes: 1

Views: 1116

Answers (2)

groupstudent

Reputation: 4317

It turns out that DataParallel only replicates the nn.Parameters of the nn.Module to every device. So I randomly initialized an nn.Parameter named z_proto in the module and copied the value of the tensor z_proto into that parameter. The parameter is then replicated across all 4 GPUs.
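
A minimal sketch of this workaround; the prototype shape (label_num, z_dim) and the copy step are assumptions about how it would be wired up:

import torch
import torch.nn as nn

class SequencePrototypeTokenClassification(nn.Module):
    def __init__(self, seq_model, label_num, z_dim):
        super(SequencePrototypeTokenClassification, self).__init__()
        self.seq_model = seq_model
        self.label_num = label_num
        # Registered as an nn.Parameter so that DataParallel replicates a
        # full copy to every GPU instead of splitting it along dim 0.
        self.z_proto = nn.Parameter(torch.randn(label_num, z_dim), requires_grad=False)

# Given an instance `model` of the module above, copy the current
# prototype values into the parameter before each forward pass
# (use model.module if the model is wrapped in nn.DataParallel):
with torch.no_grad():
    model.z_proto.copy_(z_proto)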

Upvotes: 0

Haran Rajkumar

Reputation: 2395

Based on the code above, z_proto seems to be one of the arguments of the forward function rather than part of the model. Therefore, simply storing it as a tensor on the main GPU would enable it to have the same value across GPUs.

Edit

Based on the documentation, DataParallel splits all of the inputs to the forward function across the GPUs. One way to circumvent this is to store the tensor as an attribute of the model object itself. You can update its value before calling the forward function if it is not static.

class SequencePrototypeTokenClassification(nn.Module):
    def __init__(self, seq_model, label_num):
        ...
        self.z_proto = None
        ...

# Training loop
...
model.z_proto = value
loss, y_hat = model(...)  # call the model, not model.forward() directly,
                          # so DataParallel's scatter/gather still runs
...
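
A hypothetical end-to-end sketch of this approach; the DataParallel wrapping, the compute_prototypes helper, and the adjusted forward signature (reading self.z_proto instead of taking it as an argument) are all assumptions:

import torch.nn as nn

model = SequencePrototypeTokenClassification(seq_model, label_num)
model = nn.DataParallel(model).cuda()

for input_ids, token_type_ids, attention_mask, labels, n_query, target_inds in loader:
    # The wrapper scatters forward arguments but not module attributes,
    # so set z_proto on the underlying module (model.module) instead.
    model.module.z_proto = compute_prototypes(...)  # hypothetical helper
    loss, y_hat = model(input_ids, token_type_ids, attention_mask,
                        labels, n_query, target_inds)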


Upvotes: 1
