I'm implementing a vehicle re-identification model in TensorFlow.
In a separate file, I have another subclassed model, which I saved using model.save(path). In this file, I load it using keras.models.load_model(path) and use that model as part of my subclassed model. I can confirm that this external model trains without issues.
I also create a subclassed layer within this file (ConvexCombination).
I use a custom training step, defined in train_step(), and a custom forward pass, defined in call().
I have read that this error is commonly caused by incorrectly shaped output from the ImageDataGenerator class, which I am using, but I can't figure out where the issue arises.
One thing that would help me understand this issue is knowing exactly what the data object passed into train_step(self, data) is. Is it just a single batch from the ImageDataGenerator? If so, I'm unsure where the problem with the shapes comes from.
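To see what train_step(self, data) actually receives, a minimal sketch can record the shape of data on every step. It assumes TensorFlow 2.x and uses a small hypothetical model (InspectingModel) plus a tf.data dataset of random images as a stand-in for the ImageDataGenerator iterator. As far as I know, fit() hands train_step() each element the input yields, one element per step, and the iterators returned by ImageDataGenerator.flow() and flow_from_directory() yield one (x_batch, y_batch) tuple per step, so the same unpacking should apply there. run_eagerly=True is used only so the Python-side logging runs on every step instead of once during tracing.

```python
import numpy as np
import tensorflow as tf

# Global record of the (x, y) batch shapes that train_step receives.
SEEN_BATCH_SHAPES = []


class InspectingModel(tf.keras.Model):
    """Hypothetical subclassed model whose train_step logs what `data` is."""

    def __init__(self, num_classes=3):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(num_classes)
        self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    def call(self, inputs, training=False):
        return self.dense(self.flatten(inputs))

    def train_step(self, data):
        # `data` is ONE element produced by the input pipeline: for (x, y)
        # inputs, a tuple holding one batch of images and one batch of labels.
        x, y = data
        SEEN_BATCH_SHAPES.append((tuple(x.shape), tuple(y.shape)))

        with tf.GradientTape() as tape:
            logits = self(x, training=True)
            loss = self.loss_fn(y, logits)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}


# 10 random 8x8 RGB "images" with integer labels, batched 4 at a time.
images = np.random.rand(10, 8, 8, 3).astype("float32")
labels = np.random.randint(0, 3, size=(10,)).astype("int32")
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

model = InspectingModel()
model(images[:1])  # build the weights before fit()
# run_eagerly=True so the Python append runs on every step, not only at tracing.
model.compile(optimizer="sgd", run_eagerly=True)
model.fit(dataset, epochs=1, verbose=0)

for x_shape, y_shape in SEEN_BATCH_SHAPES:
    print(x_shape, y_shape)
```

With 10 samples and a batch size of 4, train_step() should run three times, and the last batch holds only two samples, so any shape-dependent code inside train_step() or call() should not assume a fixed batch size.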
The full error and my code can be seen here: https://vehiclereidjupyternotebook.s3.eu-west-2.amazonaws.com/Full_pipeline-2.html
I was calling model.fit() on the model I had loaded with load_model() instead of on the actual outer model.
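The fix can be sketched as follows, assuming TensorFlow 2.x. The loaded model is replaced here by a small Sequential stand-in (backbone), and the outer model name ReIDModel is hypothetical. The backbone is stored as an attribute of the outer subclassed model, and compile() and fit() are called on the outer model, so its custom train_step() and call() are the ones that run, and the backbone's weights are updated as part of the outer model's trainable_variables.

```python
import numpy as np
import tensorflow as tf

# Stand-in for the backbone loaded with keras.models.load_model(path); a small
# Sequential model is used so the sketch runs without a saved model on disk.
backbone = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(16, activation="relu"),
])


class ReIDModel(tf.keras.Model):
    """Hypothetical outer model that wraps the loaded backbone."""

    def __init__(self, backbone, num_classes=3):
        super().__init__()
        self.backbone = backbone  # the loaded model becomes a sub-layer
        self.head = tf.keras.layers.Dense(num_classes)
        self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.custom_steps = 0  # how many times the custom train_step ran

    def call(self, inputs, training=False):
        return self.head(self.backbone(inputs, training=training))

    def train_step(self, data):
        x, y = data
        self.custom_steps += 1
        with tf.GradientTape() as tape:
            loss = self.loss_fn(y, self(x, training=True))
        variables = self.trainable_variables  # includes the backbone's weights
        grads = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(grads, variables))
        return {"loss": loss}


images = np.random.rand(10, 8, 8, 3).astype("float32")
labels = np.random.randint(0, 3, size=(10,)).astype("int32")
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

model = ReIDModel(backbone)
model(images[:1])  # build all weights before taking a snapshot
backbone_before = [w.numpy().copy() for w in backbone.weights]

# run_eagerly=True only so the Python counter increments on every step.
model.compile(optimizer="sgd", run_eagerly=True)
# Correct: fit() is called on the outer model, so ReIDModel.train_step runs.
# Calling backbone.fit(...) instead would bypass ReIDModel.train_step entirely.
model.fit(dataset, epochs=1, verbose=0)

print(model.custom_steps)
```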