Antoine

Reputation: 587

tensorflow_hub: module spec export with checkpoint path doesn't save all variables

I want to train GANs with TensorFlow and then export the generator and the discriminator as tensorflow_hub modules.
To do so, I:
- define my GAN architecture with TensorFlow
- train it and save checkpoints
- create a module_spec with several tag sets, such as:
    (set(), {'batch_size': 8, 'model': 'gen'})
    ({'bs8', 'gen'}, {'batch_size': 8, 'model': 'gen'})
    ({'bs8', 'disc'}, {'batch_size': 8, 'model': 'disc'})
- export the module_spec to tf_hub_path, using a checkpoint_path saved during training
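
The tag sets above are matched against the `tags` argument of `hub.Module`. The plain-Python sketch below is not tensorflow_hub code; it only illustrates the lookup, assuming (as the usage in this question suggests) that a request selects the variant whose tag set is exactly equal to it, and that calling without tags means the empty set. `select_variant` is a hypothetical helper.

```python
# Plain-Python illustration of tag-set lookup; NOT tensorflow_hub code.
# TAGS_AND_ARGS mirrors the (tags, args) pairs listed in the question.
TAGS_AND_ARGS = [
    (set(), {"batch_size": 8, "model": "gen"}),
    ({"bs8", "gen"}, {"batch_size": 8, "model": "gen"}),
    ({"bs8", "disc"}, {"batch_size": 8, "model": "disc"}),
]


def select_variant(tags_and_args, tags=None):
    """Return the module_fn arguments of the variant whose tags equal `tags`.

    Assumption: exact set equality, with tags=None treated as the empty set.
    """
    wanted = set(tags or ())
    for variant_tags, args in tags_and_args:
        if set(variant_tags) == wanted:
            return args
    raise KeyError("no graph variant with tags %r" % (sorted(wanted),))


print(select_variant(TAGS_AND_ARGS))                   # {'batch_size': 8, 'model': 'gen'}
print(select_variant(TAGS_AND_ARGS, {"disc", "bs8"}))  # {'batch_size': 8, 'model': 'disc'}
```

Order does not matter inside a tag set, so {"disc", "bs8"} and {"bs8", "disc"} select the same variant.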

I can then load my generator with:

hub.Module(tf_hub_path, tags={"gen", "bs8"})

But when I try to load the discriminator with a similar command:

hub.Module(tf_hub_path, tags={"disc", "bs8"})

I get the error:

ValueError: Tensor discriminator/linear/bias is not found in b'/tf_hub/variables/variables' checkpoint {'generator/fc_noise/kernel': [2, 48], 'generator/fc_noise/bias': [48]}

So I concluded that the discriminator's variables weren't saved in the module on disk. I checked the different sources of error I could think of:

Finally, I suspect the error comes from this call:

module_spec.export(tf_hub_path, checkpoint_path=checkpoint_path)

It seems that only the "gen" tag is taken into account when exporting the variables from checkpoint_path. I also checked that the variable names in module.variable_map match those in the list of variables stored at checkpoint_path. Here is the variable map for the module with the "disc" tag:

print(module.variable_map)
{'discriminator/linear/bias': <tf.Variable 'module_8/discriminator/linear/bias:0' shape=(1,) dtype=float32>, 'discriminator/linear/kernel': <tf.Variable 'module_8/discriminator/linear/kernel:0' shape=(3, 1) dtype=float32>}
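
The name check described above can be written as a small comparison between the keys of `module.variable_map` and the names stored in a checkpoint. The helper below is a hypothetical plain-Python sketch: it takes a list of `(name, shape)` pairs, which is the format returned by `tf.train.list_variables(path)`, so in practice it could be fed from that call. The example data is copied from the error message and the variable map above; applied to the exported module's variables, it reports both discriminator variables as missing.

```python
def missing_from_checkpoint(variable_map_names, checkpoint_variables):
    """Return module variable names that the checkpoint does not contain.

    variable_map_names: iterable of keys from module.variable_map.
    checkpoint_variables: list of (name, shape) pairs, the format returned
        by tf.train.list_variables(path) (wiring to TF not shown here).
    """
    stored = {name for name, _shape in checkpoint_variables}
    return sorted(set(variable_map_names) - stored)


# Variables found in /tf_hub/variables/variables, per the error message.
exported = [("generator/fc_noise/kernel", [2, 48]), ("generator/fc_noise/bias", [48])]
# Keys of module.variable_map for the "disc" variant, per the printout above.
disc_names = ["discriminator/linear/bias", "discriminator/linear/kernel"]

print(missing_from_checkpoint(disc_names, exported))
# ['discriminator/linear/bias', 'discriminator/linear/kernel']
```

Running the same check against the training checkpoint instead of the exported module would return an empty list if the names match, which separates a naming problem from an export problem.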

I have

Thanks for your help

Upvotes: 1

Views: 197

Answers (1)

Antoine

Reputation: 587

I found a way to handle this problem, although I don't think it's the cleanest solution:

The following entry defines the default module, i.e. the one used when hub.Module is called without tags:

(set(), {'batch_size': 8, 'model': 'gen'})

In fact, I realized that this set of parameters determines which graph is exported by module_spec.export. That explains why I could access the generator's variables when importing the module, but not the discriminator's.
So I decided to use this set of parameters as the default instead:

(set(), {'batch_size': 8, 'model': 'both'})

Then, in the _module_fn function passed to hub.create_module_spec, when model is 'both' I define the inputs (respectively the outputs) of both the generator and the discriminator as inputs (respectively outputs) of the module. As a result, when exporting the module_spec, all the variables of the graph are saved and I can access them.
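
The behaviour described in this answer can be summarised with a toy model in plain Python. It is not tensorflow_hub code and does not reproduce its internals; it only encodes the finding above, namely that `module_spec.export` saves the variables built by the default (empty-tag) variant and that every variant later restores from that single saved set. `VARIABLES_BY_MODEL`, `export`, `load` and `variants` are hypothetical names; the variable names and shapes come from the question, and `'both'` is modelled as the union of the generator's and discriminator's variables.

```python
# Toy model only (hypothetical, not tensorflow_hub internals).
VARIABLES_BY_MODEL = {
    "gen": {"generator/fc_noise/kernel": [2, 48], "generator/fc_noise/bias": [48]},
    "disc": {"discriminator/linear/bias": [1], "discriminator/linear/kernel": [3, 1]},
}
# The 'both' variant builds the generator and the discriminator.
VARIABLES_BY_MODEL["both"] = {**VARIABLES_BY_MODEL["gen"], **VARIABLES_BY_MODEL["disc"]}


def export(tags_and_args):
    """Save only the variables created by the default (empty-tag) variant."""
    default_args = next(args for tags, args in tags_and_args if not tags)
    return dict(VARIABLES_BY_MODEL[default_args["model"]])


def load(saved, tags_and_args, tags):
    """Restore the requested variant, failing like the error in the question."""
    args = next(a for t, a in tags_and_args if set(t) == set(tags))
    for name in VARIABLES_BY_MODEL[args["model"]]:
        if name not in saved:
            raise ValueError("Tensor %s is not found in checkpoint" % name)
    return args["model"]


def variants(default_model):
    """Build the question's tag sets with a configurable default variant."""
    return [
        (set(), {"batch_size": 8, "model": default_model}),
        ({"bs8", "gen"}, {"batch_size": 8, "model": "gen"}),
        ({"bs8", "disc"}, {"batch_size": 8, "model": "disc"}),
    ]


before, after = variants("gen"), variants("both")
try:
    load(export(before), before, {"disc", "bs8"})
except ValueError as err:
    print("default 'gen':", err)
# default 'gen': Tensor discriminator/linear/bias is not found in checkpoint
print("default 'both':", load(export(after), after, {"disc", "bs8"}))
# default 'both': disc
```

Under this model, a variant can only be loaded if all of its variables are created by the default graph, which is why making the default variant build both sub-models makes every tag set loadable.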

Upvotes: 1
