Llewyn S

Reputation: 16

TensorFlow 2: How to use *args in tf.function?

Update:

Did a bit more testing and I can't reproduce the behaviour with:

import tensorflow as tf
import numpy as np

@tf.function
def tf_being_unpythonic(an_input, another_input):
    return an_input + another_input

@tf.function
def example(*inputs, other_args=True):
    return tf_being_unpythonic(*inputs)

class TestClass(tf.keras.Model):
    def __init__(self, a, b):
        super().__init__()
        self.a = a
        self.b = b

    @tf.function
    def call(self, *inps, some_kwarg=False):
        if some_kwarg:
            return self.a(*inps)
        return self.b(*inps)

class Model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.inps = tf.keras.layers.Flatten()
        self.hl1 = tf.keras.layers.Dense(5)
        self.hl2 = tf.keras.layers.Dense(4)
        self.out = tf.keras.layers.Dense(1)

    @tf.function
    def call(self,observation):
        x = self.inps(observation)
        x = self.hl1(x)
        x = self.hl2(x)
        return self.out(x)


class Model2(Model):
    def __init__(self):
        super().__init__()
        self.prein = tf.keras.layers.Concatenate()

    @tf.function
    def call(self,b,c):
        x = self.prein([b,c])
        return super().call(x)   

am = Model()
pm = Model2()
test = TestClass(am,pm)

a = np.random.normal(size=(1,2,3))
b = np.random.normal(size=(1,2,4))

test(a,some_kwarg=True)
test(a,b) 

So it's probably a bug somewhere else.

@tf.function
def call(self, *inp, target=False, training=False):
    if not len(inp):
        raise ValueError("Call requires some input")
    if target:
        return self._target_network(*inp, training)
    return self._network(*inp, training)

I get:

ValueError: Input 0 of layer flatten is incompatible with the layer: : expected min_ndim=1, found ndim=0. Full shape received: []

But print(inp) gives:

(<tf.Tensor 'inp_0:0' shape=(1, 3) dtype=float32>,) 

I've since edited the code, and it was just uncommitted toy code, so I can't investigate further. I'll leave the question here so that everyone who doesn't get this issue won't have something to read.

Upvotes: 0

Views: 1353

Answers (2)

Dan Moldovan

Reputation: 975

This may have been a bug that was resolved recently. *args and **kwargs should work fine.
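
A minimal sketch of that, assuming a recent TF 2.x release. The `network` function is a hypothetical stand-in for a model and is not from the question. Flags such as `training` are forwarded by keyword here, so they cannot end up among the positional inputs:

```python
import tensorflow as tf


def network(*tensors, training=False):
    # Hypothetical stand-in for a model: concatenate all inputs on the last axis.
    x = tf.concat(list(tensors), axis=-1)
    # `training` is a plain Python bool here, so this branch is resolved at trace time.
    return x * 0.5 if training else x


@tf.function
def call(*inp, **kwargs):
    if not inp:
        raise ValueError("Call requires some input")
    # Forward positional inputs and keyword flags separately.
    return network(*inp, **kwargs)


a = tf.ones((1, 3))
b = tf.ones((1, 4))

out_one = call(a)                    # one positional input
out_two = call(a, b, training=True)  # two positional inputs plus a keyword flag
```

Each distinct Python value of `training` triggers a separate trace, which is usually fine for a boolean flag.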

Upvotes: 0

AlexisBRENON

Reputation: 3079

I don't think that using a *args construct is good practice for a tf.function. Most TF functions that accept a variable number of inputs take them as a single tuple.

So, you can rewrite your function signature as:

def call(self, inputs, target=False, training=False): ...

and calling it with:

instance.call((i1, i2, i3), [...])
# instead of instance.call(i1, i2, i3, [...])
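
As a rough, self-contained sketch of the tuple-based signature (a free function instead of a method, and summing the inputs is only a placeholder body, both assumptions made for illustration):

```python
import tensorflow as tf


@tf.function
def call(inputs, target=False, training=False):
    # `inputs` is a tuple (or list) of tensors passed as one argument.
    if not len(inputs):
        raise ValueError("Call requires some input")
    # Placeholder computation: element-wise sum of all inputs.
    return tf.add_n(list(inputs))


i1, i2, i3 = tf.constant(1.0), tf.constant(2.0), tf.constant(3.0)
total = call((i1, i2, i3), training=True)
```

Because the inputs travel as one positional argument, extra flags can be passed positionally or by keyword without being mistaken for another input.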

Edit

By the way, I don't see any error while using tf.function with a *args construct:

import tensorflow as tf

@tf.function
def call(*inp, target=False, training=False):
    if not len(inp):
        raise ValueError("Call requires some input")
    return inp[0]

def main():
    print(call(1))
    print(call(2, 2))
    print(call(3, 3, 3))


if __name__ == '__main__':
    main()

Output:

tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)

So you should give us more information about what you are trying to do and where the error occurs.

Upvotes: 1
