fuenfundachtzig

Reputation: 8352

Can tensorflow's tf.function be used with methods of dataclasses?

Can methods of dataclasses be decorated with @tf.function? A straightforward test

from dataclasses import dataclass
import tensorflow as tf

@dataclass
class Doubler:
    @tf.function
    def double(self, a):
        return a*2

gives an error

d = Doubler()
d.double(2)

saying that Doubler is not hashable (TypeError: unhashable type: 'Doubler'), which I believe is because hashing is disabled by default for dataclasses. Is this a general limitation or can it be made to work? I found this answer that seems to indicate that it doesn't work.
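To illustrate the hashing point above, here is a minimal sketch using only the standard library. It shows which dataclass configurations end up hashable; the class names Plain, IdentityHashed and Frozen and the helper is_hashable are made up for this illustration. It does not involve TensorFlow at all, so whether tf.function accepts the hashable variants is an assumption that would still need to be checked.

```python
from dataclasses import dataclass


@dataclass  # eq=True, frozen=False: __hash__ is set to None
class Plain:
    x: int = 0


@dataclass(eq=False)  # no __eq__ generated: identity-based hash is inherited from object
class IdentityHashed:
    x: int = 0


@dataclass(frozen=True)  # eq=True and frozen=True: __hash__ is generated from the fields
class Frozen:
    x: int = 0


def is_hashable(obj) -> bool:
    """Return True if hash(obj) succeeds, False if it raises TypeError."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


# Hypothetical summary dict, only used to display the outcome.
results = {
    "Plain": is_hashable(Plain()),
    "IdentityHashed": is_hashable(IdentityHashed()),
    "Frozen": is_hashable(Frozen()),
}
print(results)
```

Only the default configuration is unhashable, which matches the TypeError quoted above. Note that eq=False gives up field-based equality and frozen=True makes instances immutable, so neither is a free choice.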

Upvotes: 1

Views: 385

Answers (1)

AloneTogether

Reputation: 26708

I think the official recommendation from TensorFlow is to use tf.experimental.ExtensionType:

import tensorflow as tf

class Doubler(tf.experimental.ExtensionType):
    @tf.function
    def double(self, a):
        return a*2

d = Doubler()
d.double(2)

According to the docs:

The tf.experimental.ExtensionType base class works similarly to typing.NamedTuple and @dataclasses.dataclass from the standard Python library. In particular, it automatically adds a constructor and special methods (such as __repr__ and __eq__) based on the field type annotations.

If you read further down in the docs, you will see what features are provided by default.

Upvotes: 2
