Reputation: 587
I want to train a GAN with TensorFlow and then export the generator and the discriminator as tensorflow_hub modules.
For that:
- I define my GAN architecture with tensorflow
- train it and save checkpoints
- create a module_spec with different tags like:
(set(), {'batch_size': 8, 'model': 'gen'})
({'bs8', 'gen'}, {'batch_size': 8, 'model': 'gen'})
({'bs8', 'disc'}, {'batch_size': 8, 'model': 'disc'})
- export with module_spec at tf_hub_path using a checkpoint_path that I saved during training
Then I can load my generator with:
hub.Module(tf_hub_path, tags={"gen", "bs8"})
But when I try to load the discriminator with the analogous command:
hub.Module(tf_hub_path, tags={"disc", "bs8"})
I get the error:
ValueError: Tensor discriminator/linear/bias is not found in b'/tf_hub/variables/variables' checkpoint {'generator/fc_noise/kernel': [2, 48], 'generator/fc_noise/bias': [48]}
So I concluded that the discriminator's variables were not saved in the module on disk. I went through the possible sources of error I could think of.
First, I checked whether the checkpoint actually contained all the variables in my graph:
# Inside my model class: list everything the latest checkpoint contains
checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
inspect_list = tf.train.list_variables(checkpoint_path)
print(inspect_list)
[('disc_step_1/beta1_power', []),
('disc_step_1/beta2_power', []),
('discriminator/linear/bias', [1]),
('discriminator/linear/bias/d_opt', [1]),
('discriminator/linear/bias/d_opt_1', [1]),
('discriminator/linear/kernel', [3, 1]),
('discriminator/linear/kernel/d_opt', [3, 1]),
('discriminator/linear/kernel/d_opt_1', [3, 1]),
('gen_step/beta1_power', []),
('gen_step/beta2_power', []),
('generator/fc_noise/bias', [48]),
('generator/fc_noise/bias/g_opt', [48]),
('generator/fc_noise/bias/g_opt_1', [48]),
('generator/fc_noise/kernel', [2, 48]),
('generator/fc_noise/kernel/g_opt', [2, 48]),
('generator/fc_noise/kernel/g_opt_1', [2, 48]),
('global_step', []),
('global_step_disc', [])]
So all the variables are correctly saved in the checkpoint, yet only the two variables related to the generator were exported into the TF Hub module on disk.
Finally, I suppose that my error comes from:
module_spec.export(tf_hub_path, checkpoint_path=checkpoint_path)
It seems that only the "gen" variant is taken into account when exporting the variables from checkpoint_path. I also checked that the variable names in module.variable_map matched the names listed in the checkpoint. Here is the variable map for the module with tag "disc":
print(module.variable_map)
{'discriminator/linear/bias': <tf.Variable 'module_8/discriminator/linear/bias:0' shape=(1,) dtype=float32>, 'discriminator/linear/kernel': <tf.Variable 'module_8/discriminator/linear/kernel:0' shape=(3, 1) dtype=float32>}
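That correspondence can also be checked mechanically. Here is a minimal sketch using plain sets, with the names copied from the listings above (the real check would build the sets from module.variable_map and tf.train.list_variables):

```python
# Names the 'disc' module variant needs (keys of module.variable_map):
needed = {'discriminator/linear/bias', 'discriminator/linear/kernel'}
# Names actually present in the exported module on disk (from the ValueError):
exported = {'generator/fc_noise/kernel', 'generator/fc_noise/bias'}

missing = needed - exported
print(sorted(missing))  # both discriminator variables are absent
```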
Thanks for your help
Upvotes: 1
Views: 197
Reputation: 587
I found a way to handle this problem, even though I think it's not the cleanest way to do this:
The following line defines the default module, i.e. the graph used when hub.Module is called with no tags:
(set(), {'batch_size': 8, 'model': 'gen'})
In fact, I realized that this set of parameters determines which graph is exported by module_spec.export. That explains why I was able to access the generator's variables when importing the module, but not the discriminator's.
Thus, I decided to make this set of parameters the default:
(set(), {'batch_size': 8, 'model': 'both'})
And in the _module_fn passed to hub.create_module_spec, I defined the inputs (and, respectively, the outputs) of both the generator and the discriminator as inputs (respectively, outputs) of the module. Thus, when exporting the module_spec, all the variables of the graph are exported and accessible.
Upvotes: 1