Lars Francke

Reputation: 726

How can I find out whether a field in a dataclass has the default value or whether it's explicitly set?

I have a dataclass for which I'd like to find out whether each field was explicitly set or whether it was populated by either default or default_factory.

I know that I can get all fields using dataclasses.fields(...) and that'll probably work for fields that use default but not easily for fields using default_factory.
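
For reference, the per-field default information that dataclasses.fields(...) exposes can be inspected roughly as follows. This is only a minimal sketch: the Example class and the describe_defaults helper are hypothetical names used for illustration. It shows what is declared on the class, not whether a given instance was constructed with an explicit value.

```python
from dataclasses import MISSING, dataclass, field, fields


@dataclass
class Example:  # hypothetical class, for illustration only
    required: str
    plain: int = 5
    factory: list = field(default_factory=list)


def describe_defaults(cls):
    """Map each field name to the way its default is declared."""
    kinds = {}
    for f in fields(cls):
        if f.default is not MISSING:
            kinds[f.name] = "default"
        elif f.default_factory is not MISSING:
            kinds[f.name] = "default_factory"
        else:
            kinds[f.name] = "required"
    return kinds


print(describe_defaults(Example))
# {'required': 'required', 'plain': 'default', 'factory': 'default_factory'}
```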

My end goal is to merge two dataclass instances A and B, where B should only override those fields of A for which A is using the default value.

The use case is a configuration object that can be specified in multiple locations and some have a higher priority than others.

Edit: An example

from dataclasses import dataclass, field

def bar():
    return "bar"


@dataclass
class Configuration:
    foo: str = field(default_factory=bar)


conf1 = Configuration()

conf2 = Configuration(foo="foo")

conf3 = Configuration(foo="bar")

I'd like to detect that conf1.foo is using the default value and conf2.foo & conf3.foo were explicitly set.

Upvotes: 3

Views: 7738

Answers (1)

Arne

Reputation: 20147

As a start, a merge function like the following is probably what you could write with what you already know about fields. Instance z in the example shows its shortcoming: a value that was passed explicitly but happens to equal the default is treated as if it had not been set. Because this implementation uses the dataclass tools exactly as they are intended, it is rather stable, so if at all possible you'd want to use this:

from dataclasses import asdict, dataclass, field, fields, MISSING


@dataclass
class A:
    a: str
    b: float = 5
    c: list = field(default_factory=list)


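def merge(base, add_on):
    # Keep only the fields of base whose value differs from the field's default.
    retain = {}
    for f in fields(base):
        val = getattr(base, f.name)
        # A value equal to the plain default counts as "not set".
        if val == f.default:
            continue
        # Likewise for a value equal to what the default factory produces.
        if f.default_factory is not MISSING and val == f.default_factory():
            continue
        retain[f.name] = val
    # Values retained from base take priority over those of add_on.
    kwargs = {**asdict(add_on), **retain}
    return type(base)(**kwargs)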


fill = A('1', 1, [1])

x = A('a')
y = A('a', 2, [3])
z = A('a', 5, [])
print(merge(x, fill))  # good: A(a='a', b=1, c=[1])
print(merge(y, fill))  # good: A(a='a', b=2, c=[3])
print(merge(z, fill))  # bad:  A(a='a', b=1, c=[1])

Getting the z case right is going to involve some kind of hack. Personally, I'd just decorate the dataclass again:

from dataclasses import asdict, dataclass, field, fields


def mergeable(cls):
    old_init = cls.__init__

    def new_init(self, *args, **kwargs):
        # Remember which fields were passed explicitly, positionally or by keyword.
        self.__customs = {f.name for f, _ in zip(fields(self), args)}
        self.__customs |= kwargs.keys()
        old_init(self, *args, **kwargs)

    def merge(self, other):
        # Explicitly set fields of self win; everything else comes from other.
        retain = {n: v for n, v in asdict(self).items() if n in self.__customs}
        kwargs = {**asdict(other), **retain}
        return type(self)(**kwargs)

    cls.__init__ = new_init
    cls.merge = merge
    return cls


@mergeable
@dataclass
class A:
    a: str
    b: float = 5
    c: list = field(default_factory=list)


fill = A('1', 1, [1])

x = A('a')
y = A('a', 2, [3])
z = A('a', 5, [])

print(x.merge(fill))  # good: A(a='a', b=1, c=[1])
print(y.merge(fill))  # good: A(a='a', b=2, c=[3])
print(z.merge(fill))  # good: A(a='a', b=5, c=[])
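
Applied to the Configuration example from the question, the same decorator tells conf1 (default used) apart from conf3 (explicitly set to a value that equals the default). The following is a self-contained sketch that repeats the decorator so it can be run on its own:

```python
from dataclasses import asdict, dataclass, field, fields


def mergeable(cls):
    old_init = cls.__init__

    def new_init(self, *args, **kwargs):
        # Record every field that was passed positionally or by keyword.
        self.__customs = {f.name for f, _ in zip(fields(self), args)}
        self.__customs |= kwargs.keys()
        old_init(self, *args, **kwargs)

    def merge(self, other):
        retain = {n: v for n, v in asdict(self).items() if n in self.__customs}
        return type(self)(**{**asdict(other), **retain})

    cls.__init__ = new_init
    cls.merge = merge
    return cls


def bar():
    return "bar"


@mergeable
@dataclass
class Configuration:
    foo: str = field(default_factory=bar)


conf1 = Configuration()
conf2 = Configuration(foo="foo")
conf3 = Configuration(foo="bar")

print(conf1.merge(conf2))  # Configuration(foo='foo'): conf1.foo was a default
print(conf3.merge(conf2))  # Configuration(foo='bar'): conf3.foo was set explicitly
```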

This decorator approach is very likely to have some hard-to-guess side effects, though, so use it at your own risk.
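
One such side effect can be demonstrated directly. merge() builds its result by passing every field as a keyword argument, so the merged instance considers all of its fields explicitly set, and a second merge no longer picks up anything. dataclasses.replace() constructs the copy the same way and has the same effect. A sketch, again repeating the decorator so that it runs standalone (the instance named other is introduced here for illustration):

```python
from dataclasses import asdict, dataclass, field, fields, replace


def mergeable(cls):
    old_init = cls.__init__

    def new_init(self, *args, **kwargs):
        # Record every field that was passed positionally or by keyword.
        self.__customs = {f.name for f, _ in zip(fields(self), args)}
        self.__customs |= kwargs.keys()
        old_init(self, *args, **kwargs)

    def merge(self, other):
        retain = {n: v for n, v in asdict(self).items() if n in self.__customs}
        return type(self)(**{**asdict(other), **retain})

    cls.__init__ = new_init
    cls.merge = merge
    return cls


@mergeable
@dataclass
class A:
    a: str
    b: float = 5
    c: list = field(default_factory=list)


x = A('a')
fill = A('1', 1, [1])
other = A('2', 2, [2])  # hypothetical third configuration source

merged = x.merge(fill)
# The merged instance was built from keyword arguments for every field,
# so nothing from `other` is taken over in a chained merge.
chained = merged.merge(other)
print(chained)  # A(a='a', b=1, c=[1])

# replace() also passes all fields as keywords, so the defaults of x
# now count as explicitly set on the copy.
copied = replace(x, a='b')
print(copied.merge(fill))  # A(a='b', b=5, c=[])
```

If chained merges matter, apply the sources pairwise from the original instances rather than from an already merged result.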

Upvotes: 4
