bob

Reputation: 152

TensorFlow dataset interleave with from_generator throws InvalidArgumentError

I have a generator which I am trying to interleave:

import tensorflow as tf

def hello(i):
  for j in tf.range(i):
    yield j

ds = tf.data.Dataset.range(10).interleave(
    lambda ind: tf.data.Dataset.from_generator(
        lambda: hello(ind), output_types=(tf.int32,)))

for x in ds.take(1):
  print(x)

But I get this error:

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: args_0:0


     [[{{node PyFunc}}]]

TensorFlow version: 2.3.2

Upvotes: 2

Views: 775

Answers (1)

Lescurel

Reputation: 11631

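The problem is in how you pass the generator to from_generator. Your lambda closes over ind, which inside interleave is a symbolic graph tensor (the args_0:0 in the error). The generator body runs later inside a Python op (the PyFunc node in the error), outside the graph that tensor belongs to, so it cannot be evaluated there. Instead of using a lambda, pass the value through the args keyword argument: from_generator evaluates the tensors in args and calls your generator function with them as NumPy values.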

import tensorflow as tf

# hello is the generator defined in the question
ds = tf.data.Dataset.range(10).interleave(
    lambda ind: tf.data.Dataset.from_generator(
        hello, args=(ind,), output_types=tf.int32
    )
)

In TF 2.4 and later, output_types is deprecated in favor of output_signature. For this example that would be output_signature=tf.TensorSpec(shape=(), dtype=tf.int32).
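If you want to reason about the order in which the fixed pipeline emits values, the following is a simplified, sequential pure-Python model of the round-robin behavior of interleave. It is only a sketch: it ignores parallelism and prefetching, and it assumes an explicit cycle_length (when you don't set it, TensorFlow picks one automatically, so the real order can differ). The names interleave, hello and tagged here are hypothetical stand-ins, not TensorFlow APIs; map_func receives a concrete Python value, much as from_generator(hello, args=(ind,)) hands the generator an evaluated value.

```python
from typing import Callable, Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def interleave(inputs: Iterable[T],
               map_func: Callable[[T], Iterable[U]],
               cycle_length: int,
               block_length: int = 1) -> Iterator[U]:
    """Sequential sketch of tf.data-style interleave (no parallelism)."""
    source = iter(inputs)
    slots: List[Optional[Iterator[U]]] = [None] * cycle_length
    cycle_index = 0
    block_index = 0
    num_open = 0
    end_of_input = False

    def next_slot() -> None:
        # Move to the next slot in the cycle and start a fresh block.
        nonlocal cycle_index, block_index
        block_index = 0
        cycle_index = (cycle_index + 1) % cycle_length

    while not end_of_input or num_open > 0:
        current = slots[cycle_index]
        if current is not None:
            try:
                value = next(current)
            except StopIteration:
                # This element's sequence is exhausted: free the slot, move on.
                slots[cycle_index] = None
                num_open -= 1
                next_slot()
                continue
            yield value
            block_index += 1
            if block_index == block_length:
                next_slot()
        elif not end_of_input:
            # Empty slot: open a sequence for the next input element.
            try:
                element = next(source)
            except StopIteration:
                end_of_input = True
                continue
            # map_func gets a concrete value (cf. args=(ind,) in from_generator).
            slots[cycle_index] = iter(map_func(element))
            num_open += 1
        else:
            # No more input and this slot is empty: skip it.
            next_slot()


def hello(i: int) -> Iterator[int]:
    # Pure-Python stand-in for the question's generator (range, not tf.range).
    for j in range(i):
        yield j


def tagged(i: int) -> Iterator[tuple]:
    # Hypothetical helper: pair each value with its input element to show order.
    return ((i, j) for j in hello(i))


print(list(interleave(range(2, 5), tagged, cycle_length=2)))
# → [(2, 0), (3, 0), (2, 1), (3, 1), (3, 2), (4, 0), (4, 1), (4, 2), (4, 3)]
```

With cycle_length=1 the model degenerates to plain sequential flattening, which is a quick sanity check that every generated value is emitted exactly once.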

Upvotes: 2
