Reputation: 469
I have a recursive nested defaultdict class defined as
from collections import defaultdict
class NestedDict(defaultdict):
    def __init__(self):
        super().__init__(self.__class__)
sitting in a nested_dict.py file.
When I try to pickle it, e.g.
import pickle
from nested_dict import NestedDict
d = NestedDict()
pickle.loads(pickle.dumps(d))
I get TypeError: __init__() takes 1 positional argument but 2 were given.
What exactly is happening here?
Upvotes: 2
Views: 1245
Reputation: 1122312
The defaultdict class implements an object.__reduce__() method where the second element of the returned tuple (the arguments for the constructor) is always going to be the factory object:
>>> d = NestedDict()
>>> d.__reduce__()
(<class '__main__.NestedDict'>, (<class '__main__.NestedDict'>,), None, None, <dict_itemiterator object at 0x110df59a8>)
That argument is then passed to the NestedDict() call to rebuild the object. The exception is raised because NestedDict.__init__() doesn't accept any argument besides self.
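The failure can be reproduced without pickle by replaying the reconstruction step by hand. This is only a minimal sketch: it unpacks the tuple returned by __reduce__() and calls the first element with the second, which is roughly the first thing unpickling does. The variable names (factory, args, error) are illustrative.

```python
from collections import defaultdict


class NestedDict(defaultdict):
    def __init__(self):
        super().__init__(self.__class__)


d = NestedDict()

# The first two elements are the callable and its positional arguments.
factory, args, *_ = d.__reduce__()

try:
    # Equivalent to NestedDict(NestedDict), which __init__ does not accept.
    factory(*args)
except TypeError as exc:
    error = exc
else:
    error = None

print(error)
```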
You can override the __reduce__ method in your subclass:
class NestedDict(defaultdict):
    def __init__(self):
        super().__init__(self.__class__)
    def __reduce__(self):
        return (type(self), (), None, None, iter(self.items()))
The above returns exactly the same elements as defaultdict.__reduce__(), except that the second element is now an empty tuple.
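As a rough check of the override, the sketch below replays the reduce tuple by hand and, when run as a script, also does a real pickle round trip. The manual loop only approximates what pickle does with the fifth element, and the names rebuilt and restored are illustrative.

```python
import pickle
from collections import defaultdict


class NestedDict(defaultdict):
    def __init__(self):
        super().__init__(self.__class__)

    def __reduce__(self):
        # Empty argument tuple: the object is rebuilt as NestedDict().
        return (type(self), (), None, None, iter(self.items()))


d = NestedDict()
d["foo"]["bar"] = 1

# Replay the reduce tuple manually: call the class, then re-add the items.
factory, args, state, list_items, dict_items = d.__reduce__()
rebuilt = factory(*args)
for key, value in dict_items:
    rebuilt[key] = value

if __name__ == "__main__":
    # The real round trip no longer raises TypeError.
    restored = pickle.loads(pickle.dumps(d))
    print(restored == d)
```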
You could also just accept and ignore a single argument:
class NestedDict(defaultdict):
    def __init__(self, _=None):  # accept a factory and ignore it
        super().__init__(self.__class__)
The _ name is commonly used to mean "I am ignoring this value".
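With this variant __reduce__() is inherited unchanged, so the factory is still passed in on reconstruction and simply discarded. A small sketch of that, with illustrative variable names:

```python
import pickle
from collections import defaultdict


class NestedDict(defaultdict):
    def __init__(self, _=None):  # the factory passed on unpickling is ignored
        super().__init__(self.__class__)


d = NestedDict()
d["a"]["b"] = 42

# The inherited __reduce__ still puts the factory in the argument tuple.
factory, args, *_ = d.__reduce__()

# NestedDict(NestedDict) is now accepted and yields an empty instance.
empty = factory(*args)

if __name__ == "__main__":
    restored = pickle.loads(pickle.dumps(d))
    print(restored["a"]["b"])
```

One trade-off of this design: any single positional argument is silently dropped, so NestedDict(some_mapping) does not initialise the dictionary from that mapping.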
An alternative implementation could just subclass dict and provide a custom __missing__ method; this method is called for keys not in the dictionary:
class NestedDict(dict):
    def __missing__(self, key):
        nested = self[key] = type(self)()
        return nested
    def __repr__(self):
        return f'{type(self).__name__}({super().__repr__()})'
This works exactly like your version, but doesn't need additional pickle support methods:
>>> d = NestedDict()
>>> d['foo']
NestedDict({})
>>> d['foo']['bar']
NestedDict({})
>>> d
NestedDict({'foo': NestedDict({'bar': NestedDict({})})})
>>> pickle.loads(pickle.dumps(d))
NestedDict({'foo': NestedDict({'bar': NestedDict({})})})
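To see why the dict subclass needs no extra pickle support, it can help to look at the default reduction it inherits. The sketch below assumes pickle protocol 2 or later and replays the tuple by hand. The instance is created without any constructor arguments and the items are added afterwards, so no factory is ever passed in. Variable names are illustrative.

```python
class NestedDict(dict):
    def __missing__(self, key):
        nested = self[key] = type(self)()
        return nested


d = NestedDict()
d["foo"]["bar"] = 1

# Default reduction for a dict subclass: callable, arguments, then
# state, list items and dict items.
factory, args, *rest = d.__reduce_ex__(2)
dict_items = rest[-1]

# The callable only receives the class itself and returns an empty instance.
rebuilt = factory(*args)
for key, value in dict_items:
    rebuilt[key] = value

print(args)
print(rebuilt == d)
```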
Upvotes: 2