beardybear

Reputation: 159

TensorFlow 2: building an estimator with a custom layer and tf.map_fn

I am trying to build a TensorFlow 2 estimator from a custom Keras model. The model takes as input a tensor of shape [batch_size, n, h, w, c]. I need to apply a CNN to each [n, h, w, c] tensor in the batch. To do that, I am using tf.map_fn:

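import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.activations import softmax

def make_model(params):
    # n, h, w, c are defined elsewhere; the input tensor is [batch_size, n, h, w, c]
    batch = Input(shape=[n, h, w, c], batch_size=params.batch_size, name='inputs')
    feature_extraction = SomeCustomLayer()
    # Apply the CNN to each [n, h, w, c] element of the batch
    x = tf.map_fn(feature_extraction, batch)
    ...
    softmax_score = softmax(x)
    return tf.keras.Model(inputs=batch, outputs=softmax_score, name='custom_model')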
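For reference, here is a minimal eager sketch of what tf.map_fn does with a per-example CNN. The sizes, and the small Conv2D + GlobalAveragePooling2D model standing in for SomeCustomLayer, are hypothetical assumptions chosen only for illustration. map_fn slices the input along axis 0, calls the function on each [n, h, w, c] slice (so n acts as the CNN's own batch axis), and stacks the results.

```python
import numpy as np
import tensorflow as tf

BATCH, N, H, W, C = 2, 3, 8, 8, 1  # hypothetical sizes

# Hypothetical stand-in for SomeCustomLayer: maps [N, H, W, C] -> [N, 4]
cnn = tf.keras.Sequential([
    tf.keras.layers.Conv2D(4, 3, padding="same", activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
])
cnn.build((None, H, W, C))  # create the weights once, outside map_fn

batch = tf.random.normal([BATCH, N, H, W, C])

# tf.map_fn slices `batch` along axis 0, calls cnn on each [N, H, W, C] slice,
# and stacks the per-slice outputs back into [BATCH, N, 4].
features = tf.map_fn(cnn, batch)

# Each slice is processed independently, so folding BATCH into the CNN's
# batch axis and unfolding afterwards gives the same values.
flat = cnn(tf.reshape(batch, [BATCH * N, H, W, C]))
folded = tf.reshape(flat, [BATCH, N, 4]).numpy()

print(features.shape)  # (2, 3, 4)
print(np.allclose(features.numpy(), folded, atol=1e-5))  # True
```

The reshape comparison is only there to show what map_fn computes; when the per-example function is an ordinary batched CNN, folding the batch axis is a common alternative to map_fn.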

When I compile the model and convert it into an estimator, everything runs fine:

model = make_model(params)
model.compile(optimizer=optimizer, loss=loss_function, metrics=metrics_list)
estimator = tf.keras.estimator.model_to_estimator(keras_model=model)

However, when I start the training, it fails miserably:

training_log = estimator.train(input_fn=lambda: training_dataset)
...
WARNING:tensorflow:The graph (<tensorflow.python.framework.ops.Graph object at 0x7fa5ebbdf6d0>) of the iterator is different from the graph (<tensorflow.python.framework.ops.Graph object at 0x7fa618050910>) the dataset: tf.Tensor(<unprintable>, shape=(), dtype=variant) was created in. If you are using the Estimator API, make sure that no part of the dataset returned by the `input_fn` function is defined outside the `input_fn` function. Please ensure that all datasets in the pipeline are created in the same graph as the iterator. NOTE: This warning will become an error in future versions of TensorFlow.
...
Traceback (most recent call last):
  File "/opt/anaconda3/envs/direx/lib/python3.7/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 2104, in make_initializable_iterator
    return dataset._make_initializable_iterator(shared_name)  # pylint: disable=protected-access
AttributeError: 'BatchDataset' object has no attribute '_make_initializable_iterator'

During handling of the above exception, another exception occurred:
...
RuntimeError: Attempting to capture an EagerTensor without building a function.

I am quite confused at this stage. My dataset works perfectly with my model when I use it directly as a Keras model, so I expected it to work with the Estimator interface as well. Is the issue really coming from a misuse of the estimator's input_fn, or from the way I build the estimator or the Keras model?

Upvotes: 0

Views: 558

Answers (1)

beardybear


I figured out the issue. I was creating my dataset before the training call and then returning that existing dataset from the lambda:

training_dataset = input_fn(params)  # dataset is built outside the Estimator's graph
estimator.train(input_fn=lambda: training_dataset)

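Instead, you have to pass a function that builds the dataset, so that the Estimator creates the dataset inside its own graph (as the warning in the question suggests):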

estimator.train(input_fn=lambda: input_fn(params))
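To see why this works, the sketch below imitates what the Estimator does: it makes its own graph the default and calls input_fn there, then builds an iterator with tf.compat.v1.data.make_initializable_iterator (the function in the traceback). The input_fn body, the random data, and the params dict are hypothetical; only the pattern of building the dataset inside the function comes from the answer.

```python
import numpy as np
import tensorflow as tf

def input_fn(params):
    # Hypothetical pipeline: 8 examples of shape [n, h, w, c] = [3, 8, 8, 1]
    features = np.random.rand(8, 3, 8, 8, 1).astype(np.float32)
    labels = np.random.randint(0, 2, size=8).astype(np.int32)
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset.batch(params["batch_size"])

params = {"batch_size": 4}  # hypothetical params

# Imitate the Estimator: make a fresh graph the default, then call input_fn
# so that every op of the dataset is created inside that graph.
graph = tf.Graph()
with graph.as_default():
    dataset = input_fn(params)
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    next_features, next_labels = iterator.get_next()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(iterator.initializer)
    batch_features, batch_labels = sess.run((next_features, next_labels))

print(next_features.graph is graph)  # True
print(batch_features.shape)  # (4, 3, 8, 8, 1)
```

Had the dataset been created eagerly before this block (as in the original `training_dataset` version), its tensors would live outside `graph`, which is the mismatch the warning and the "Attempting to capture an EagerTensor" error describe.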

Upvotes: 2
