Reputation: 831

Get shape using tf.function()

I have this function

import tensorflow as tf

train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int32)
]

@tf.function(input_signature=train_step_signature)
def train_step(inp):
    ...  # do stuff

I need to use the first dimension of inp in one operation (a loop whose range is inp.shape[0]), but when I try, an error pops up:

TypeError: 'NoneType' object cannot be interpreted as an integer

That is obviously because of train_step_signature. I've seen that it works if I drop train_step_signature from the arguments, but then my code takes a lot more time to run. My question is: is there any way to get this first dimension without losing the train_step_signature argument?

Upvotes: 0

Views: 408

Answers (1)

Alexey Tochin

Reputation: 683

You are probably using a plain Python loop such as for i in range(inp.shape[0]). That cannot work because inp.shape[0] is the static shape, which the input signature declares as None, so it is None inside tf.function. Do not be afraid to use tf.while_loop inside tf.function.

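Alternatively, use tf.shape(inp)[0] instead of inp.shape[0]. It returns the dimension as a tensor that is evaluated at run time, so loop over it with tf.range rather than Python's range.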
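
Both suggestions can be sketched roughly as follows, keeping the input signature from the question. The loop body (summing each row) and the name train_step_while are made up for illustration, since the question only says "do stuff".

```python
import tensorflow as tf

train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int32)
]

@tf.function(input_signature=train_step_signature)
def train_step(inp):
    # inp.shape[0] is None here; tf.shape(inp)[0] is a scalar int32
    # tensor whose value is only known when the function runs.
    n = tf.shape(inp)[0]
    total = tf.constant(0, dtype=tf.int32)
    # AutoGraph turns a for loop over tf.range into a graph-level loop.
    for i in tf.range(n):
        total += tf.reduce_sum(inp[i])  # placeholder for the real work
    return total

@tf.function(input_signature=train_step_signature)
def train_step_while(inp):
    # Hypothetical variant: the same loop written with tf.while_loop.
    n = tf.shape(inp)[0]

    def cond(i, total):
        return i < n

    def body(i, total):
        return i + 1, total + tf.reduce_sum(inp[i])

    _, total = tf.while_loop(cond, body, (tf.constant(0), tf.constant(0)))
    return total

batch = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.int32)
result = train_step(batch)
result_while = train_step_while(batch)
```

Both functions accept batches of any size under the one signature, so the None in the TensorSpec can stay.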

Upvotes: 1
