Tae-Sung Shin

Reputation: 20620

register_buffer a dict object in PyTorch

I thought this was a simple question but I couldn't find an answer.

I want a member variable of a PyTorch module to be saved and loaded with the model's state_dict. For a tensor, I can do that in __init__ with the following line:

        self.register_buffer('loss_weight', torch.tensor(loss_weight))
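As a quick check of that premise, a tensor registered with register_buffer appears in state_dict() and is restored by load_state_dict(), without becoming a trainable parameter. This is a minimal sketch; the class name WeightedLoss and the value 0.5 are hypothetical:

```python
import torch
import torch.nn as nn


class WeightedLoss(nn.Module):  # hypothetical example module
    def __init__(self, loss_weight):
        super().__init__()
        # A buffer is saved in state_dict() but is not returned by parameters().
        self.register_buffer('loss_weight', torch.tensor(loss_weight))


src = WeightedLoss(0.5)
dst = WeightedLoss(0.0)
dst.load_state_dict(src.state_dict())  # copies the buffer value into dst

print(list(src.state_dict().keys()))  # ['loss_weight']
print(dst.loss_weight.item())         # 0.5
print(len(list(src.parameters())))    # 0
```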

But what if loss_weight is a dict object? Is it allowed? If so, how can I convert it to a tensor?

When I tried, I got the error "Could not infer dtype of dict".
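For reference, the error comes from torch.tensor itself: it accepts numbers and (nested) sequences of numbers, but not a mapping. One possible workaround, assuming every value in the dict is a number, is to fix an order for the keys and store only the values as a 1-D tensor. The key order then has to live in code, because state_dict() stores tensors, not dict keys. A sketch under those assumptions; LOSS_KEYS and the example weights are hypothetical:

```python
import torch

loss_weight = {'ce': 1.0, 'l1': 0.1}  # hypothetical example dict

# Reproduce the problem: torch.tensor cannot build a tensor from a dict.
error_message = None
try:
    torch.tensor(loss_weight)
except Exception as err:
    error_message = str(err)
print(error_message)

# Fix the key order once; only the values end up in the tensor.
LOSS_KEYS = sorted(loss_weight)  # ['ce', 'l1']
packed = torch.tensor([loss_weight[k] for k in LOSS_KEYS])

# Rebuild a dict from the tensor, e.g. after load_state_dict().
# Values come back as float32, so 0.1 is only approximately 0.1.
unpacked = {k: packed[i].item() for i, k in enumerate(LOSS_KEYS)}
```

The packed tensor could then be passed to register_buffer like any other tensor.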

Upvotes: 0

Views: 701

Answers (1)

Karl

Reputation: 5473

Per the docs, the name argument must be a string and the tensor argument must be a PyTorch tensor (or None), so a dict cannot be registered directly as a buffer.
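A small check of those constraints: passing a dict (or any other non-tensor object) as the buffer raises a TypeError, while a tensor or None is accepted; a None buffer is simply left out of state_dict(). Names containing "." are also rejected, with a KeyError. The module and buffer names here are illustrative only:

```python
import torch
import torch.nn as nn

m = nn.Module()

m.register_buffer('ok', torch.zeros(3))  # a tensor is accepted
m.register_buffer('placeholder', None)   # None is accepted but not saved

dict_rejected = False
try:
    m.register_buffer('bad', {'a': 1.0})  # a dict is not a tensor
except TypeError:
    dict_rejected = True

dot_rejected = False
try:
    m.register_buffer('a.b', torch.zeros(1))  # '.' separates submodule names
except KeyError:
    dot_rejected = True

print(dict_rejected, dot_rejected)  # True True
print(list(m.state_dict().keys()))  # ['ok']
```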

If you have a dict of buffers, you could consider using a dedicated nn.Module for that purpose, registering each dict entry as its own buffer. Something like this:

import torch
import torch.nn as nn

class BufferDict(nn.Module):
    def __init__(self, input_dict):
        super().__init__()
        # Register each tensor as its own named buffer so it is
        # included in state_dict() and moved along with .to()/.cuda().
        for k, v in input_dict.items():
            self.register_buffer(k, v)

input_dict = {'a': torch.randn(4), 'b': torch.randn(5)}

bd = BufferDict(input_dict)
bd.state_dict()
> OrderedDict([('a', tensor([ 0.1908,  1.6965, -0.3710,  0.4551])),
               ('b', tensor([-0.6943, -0.0534,  0.1779,  1.3607, -0.2236]))])
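To apply this to the question's loss_weight, the BufferDict can be attached to the model as a submodule, so its buffers are saved under prefixed keys such as loss_weight.ce. The sketch below assumes the dict holds plain Python numbers; it redefines BufferDict with a torch.as_tensor conversion (a variation on the class above, not the answer's exact code) so that it runs on its own. The class Model and the weights are hypothetical:

```python
import torch
import torch.nn as nn


class BufferDict(nn.Module):
    def __init__(self, input_dict):
        super().__init__()
        for k, v in input_dict.items():
            # as_tensor converts numbers and leaves existing tensors as-is.
            self.register_buffer(k, torch.as_tensor(v))


class Model(nn.Module):  # hypothetical parent module
    def __init__(self, loss_weight):
        super().__init__()
        self.linear = nn.Linear(2, 1)
        # Buffers of the submodule are saved as 'loss_weight.<key>'.
        self.loss_weight = BufferDict(loss_weight)


src = Model({'ce': 1.0, 'l1': 0.5})
dst = Model({'ce': 0.0, 'l1': 0.0})
dst.load_state_dict(src.state_dict())

print(sorted(src.state_dict().keys()))
# ['linear.bias', 'linear.weight', 'loss_weight.ce', 'loss_weight.l1']
print(dst.loss_weight.l1.item())  # 0.5
```

Keys must be strings without "." and must not clash with existing nn.Module attributes (such as training), otherwise register_buffer raises a KeyError.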

Upvotes: 3
