Edward Khachatryan


Can't pickle recursive nested defaultdict

I have a recursive nested defaultdict class defined as

from collections import defaultdict

class NestedDict(defaultdict):
    def __init__(self):
        super().__init__(self.__class__)

sitting in a nested_dict.py file.

When I try to pickle it, e.g.

import pickle
from nested_dict import NestedDict

d = NestedDict()
pickle.loads(pickle.dumps(d))

I get TypeError: __init__() takes 1 positional argument but 2 were given.

What exactly is happening here?

Answers (1)

Martijn Pieters

The defaultdict class implements an object.__reduce__() method where the second element of the returned tuple (the arguments for the constructor) is always the factory object:

>>> d = NestedDict()
>>> d.__reduce__()
(<class '__main__.NestedDict'>, (<class '__main__.NestedDict'>,), None, None, <dict_itemiterator object at 0x110df59a8>)

That argument is then passed to the NestedDict() call to rebuild the object. The exception is raised because NestedDict.__init__() doesn't accept an argument.
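
A minimal sketch, assuming the same NestedDict definition as in the question: replaying the first two elements of the reduce tuple by hand (roughly what pickle.loads does when it rebuilds the object) triggers the same TypeError without pickle being involved at all. The names rebuild, args and error are only illustrative.

```python
from collections import defaultdict

class NestedDict(defaultdict):
    def __init__(self):
        super().__init__(self.__class__)

d = NestedDict()
# First two elements of the reduce tuple: the callable and its arguments
rebuild, args = d.__reduce__()[:2]
print(rebuild, args)  # args holds the default_factory, i.e. NestedDict itself

error = None
try:
    rebuild(*args)  # roughly what unpickling does: call the class with the args
except TypeError as exc:
    error = exc
print(type(error).__name__, error)
```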

You can override the __reduce__ method in your subclass:

class NestedDict(defaultdict):
    def __init__(self):
        super().__init__(self.__class__)
    def __reduce__(self):
        return (type(self), (), None, None, iter(self.items()))

The above returns exactly the same elements that defaultdict.__reduce__() returns, except that the second element is now an empty tuple, so NestedDict() is called without arguments when unpickling.
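
A quick round-trip check of the __reduce__ override, as a sketch; the keys 'a', 'b' and 'c' are arbitrary example data. Every nested level is itself a NestedDict, so each one is pickled through the same __reduce__ method.

```python
import pickle
from collections import defaultdict

class NestedDict(defaultdict):
    def __init__(self):
        super().__init__(self.__class__)

    def __reduce__(self):
        # Same tuple defaultdict builds, but with no constructor arguments
        return (type(self), (), None, None, iter(self.items()))

d = NestedDict()
d['a']['b']['c'] = 1  # intermediate levels are created on demand
restored = pickle.loads(pickle.dumps(d))
print(restored)
```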

You could also just accept and ignore a single argument:

class NestedDict(defaultdict):
    def __init__(self, _=None):  # accept a factory and ignore it
        super().__init__(self.__class__)

The _ name is commonly used to mean "I am ignoring this value".
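
The same kind of round-trip check for the variant that accepts and ignores the factory argument, again as a sketch with made-up example keys. On load, pickle calls NestedDict(NestedDict); the argument is discarded and __init__ installs the class itself as the factory.

```python
import pickle
from collections import defaultdict

class NestedDict(defaultdict):
    def __init__(self, _=None):  # pickle passes the old factory; ignore it
        super().__init__(self.__class__)

d = NestedDict()
d['x']['y'] = 'value'
restored = pickle.loads(pickle.dumps(d))
print(restored)
```

One consequence of this design: a caller writing NestedDict(list) silently gets NestedDict as the factory, not list, which the __reduce__ override avoids by keeping the constructor argument-free.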

An alternative implementation could just subclass dict and provide a custom __missing__ method; this method is called for keys not in the dictionary:

class NestedDict(dict):
    def __missing__(self, key):
        nested = self[key] = type(self)()
        return nested
    def __repr__(self):
        return f'{type(self).__name__}({super().__repr__()})'

This works exactly like your version, but doesn't need additional pickle support methods:

>>> d = NestedDict()
>>> d['foo']
NestedDict({})
>>> d['foo']['bar']
NestedDict({})
>>> d
NestedDict({'foo': NestedDict({'bar': NestedDict({})})})
>>> pickle.loads(pickle.dumps(d))
NestedDict({'foo': NestedDict({'bar': NestedDict({})})})
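
As for why no extra support is needed: dict doesn't define its own __reduce__, so with pickle protocol 2 and later a dict subclass falls back to object.__reduce_ex__(), which rebuilds the instance via copyreg.__newobj__ (that is, NestedDict.__new__(NestedDict)) and then restores the items. __init__ is never called with a factory argument. The sketch below inspects that tuple; protocol 2 is chosen only as an example.

```python
import copyreg
import pickle

class NestedDict(dict):
    def __missing__(self, key):
        nested = self[key] = type(self)()
        return nested

d = NestedDict()
d['foo']['bar']  # creates both nested levels
reduced = d.__reduce_ex__(2)
# First element is the rebuild callable, second its args: only the class
print(reduced[0] is copyreg.__newobj__, reduced[1])

restored = pickle.loads(pickle.dumps(d))
```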
