Reputation: 8352
Can methods of dataclasses be decorated with @tf.function? A straightforward test
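import tensorflow as tf
from dataclasses import dataclass

@dataclass
class Doubler:
    @tf.function
    def double(self, a):  # methods need an explicit `self` parameter
        return a * 2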
gives an error
d = Doubler()
d.double(2)
saying that Doubler is not hashable (TypeError: unhashable type: 'Doubler'), which I believe is because hashing is disabled by default for dataclasses. Is this a general limitation, or can it be made to work? I found this answer that seems to indicate that it doesn't work.
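The hashing hypothesis can be checked without TensorFlow. With the default eq=True and frozen=False, @dataclass sets __hash__ to None, while eq=False keeps the identity-based object.__hash__ and frozen=True generates a field-based one. The sketch below assumes (this is not taken from TensorFlow's source) that tf.function caches the per-instance bound method in a weak-key mapping keyed by the instance; the hypothetical helper cache_per_instance imitates that with weakref.WeakKeyDictionary and reproduces the same error message. The factor field is also purely illustrative.

```python
import weakref
from dataclasses import dataclass


@dataclass
class Doubler:
    """Default dataclass: eq=True and frozen=False, so __hash__ is set to None."""
    factor: int = 2  # hypothetical field, for illustration only


@dataclass(eq=False)
class IdentityDoubler:
    """eq=False: no __eq__ is generated, so object's identity-based __hash__ is kept."""
    factor: int = 2


@dataclass(frozen=True)
class FrozenDoubler:
    """frozen=True (with eq=True): a field-based __hash__ is generated."""
    factor: int = 2


def cache_per_instance(obj):
    """Hypothetical stand-in for a per-instance cache keyed by the instance itself.

    Returns None if obj can be used as a key, otherwise the TypeError message.
    """
    cache = weakref.WeakKeyDictionary()
    try:
        cache[obj] = "bound method"  # hashing the weak reference hashes obj
    except TypeError as exc:
        return str(exc)
    return None


print(Doubler.__hash__)                       # None
print(cache_per_instance(Doubler()))          # unhashable type: 'Doubler'
print(cache_per_instance(IdentityDoubler()))  # None
print(cache_per_instance(FrozenDoubler()))    # None
```

If the assumption about tf.function's caching holds, declaring the class with @dataclass(eq=False) or @dataclass(frozen=True) should avoid this particular TypeError, though that has not been verified here against any specific TensorFlow version.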
Upvotes: 1
Views: 385
Reputation: 26708
I think the official recommendation from TensorFlow is to use tf.experimental.ExtensionType:
import tensorflow as tf

class Doubler(tf.experimental.ExtensionType):
    @tf.function
    def double(self, a):
        return a * 2

d = Doubler()
d.double(2)
According to the docs:
The tf.experimental.ExtensionType base class works similarly to typing.NamedTuple and @dataclasses.dataclass from the standard Python library. In particular, it automatically adds a constructor and special methods (such as __repr__ and __eq__) based on the field type annotations.
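For comparison, the two standard-library analogues mentioned in that quote derive their generated methods from field annotations as follows. This sketch uses only the standard library and does not reproduce ExtensionType itself; PointDC and PointNT are hypothetical names.

```python
from dataclasses import dataclass
from typing import NamedTuple


@dataclass
class PointDC:
    # Annotated fields drive the generated __init__, __repr__ and __eq__.
    x: float
    y: float


class PointNT(NamedTuple):
    # NamedTuple generates the constructor and __repr__ from the annotations;
    # equality is inherited from tuple.
    x: float
    y: float


p, q = PointDC(1.0, 2.0), PointDC(1.0, 2.0)
print(repr(p))   # PointDC(x=1.0, y=2.0)
print(p == q)    # True: field-wise equality, not identity

n = PointNT(1.0, 2.0)
print(repr(n))          # PointNT(x=1.0, y=2.0)
print(n == (1.0, 2.0))  # True: a NamedTuple compares like a plain tuple
```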
If you read further down in the docs, you will see what features are provided by default.
Upvotes: 2