Rainb

Reputation: 2465

How can I "see" the model/network when loading a model from tfhub?

I'm new to this topic, so forgive my lack of knowledge. There is a very good model, Faster R-CNN with an Inception-ResNet-v2 backbone, that basically works like this: the input is an image, and the output is a list of predicted objects with their bounding boxes. I find this very useful, and I thought of starting from the already trained model in order to recognize things it currently can't (for example, whether a person is wearing a mask or not). In other words, I want to add a new recognition class to the model.

```python
import tensorflow as tf
import tensorflow_hub as hub
mod = hub.load("https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1")
```

mod is an object of type tensorflow.python.training.tracking.tracking.AutoTrackable. The documentation (which was only available in the source code) was a bit hard to understand without context, so I tried to inspect some of its properties to see if I could figure it out by myself.

I couldn't. How can I see the network, its layers, and its weights? Where are the fit methods, or is all of that abstracted away? Can I convert it to Keras? I want to experiment with it, see if I can modify it, and see if I could export the model to another representation, for example PyTorch.
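There is no Keras-style summary() for a generic loaded SavedModel, but its signatures, stored weights, and graph ops can still be listed. The following is a minimal sketch using standard TF2 APIs (tf.saved_model.load, tf.train.list_variables, and the .graph of a signature function). So that it runs offline, it saves a tiny hypothetical ToyDetector, with a made-up kernel variable and a 'default' signature, instead of downloading the real detector. For the Hub model, hub.resolve(handle) should return the local directory you would pass in place of export_dir.

```python
import os
import tempfile

import tensorflow as tf


class ToyDetector(tf.Module):
    """Hypothetical stand-in for the real detector (NOT Faster R-CNN)."""

    def __init__(self):
        super().__init__()
        self.kernel = tf.Variable(tf.random.normal([3, 4]), name="kernel")

    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def detect(self, images):
        scores = tf.nn.softmax(tf.matmul(images, self.kernel))
        return {"detection_scores": scores}


# Save the toy model so it can be inspected like a downloaded SavedModel.
export_dir = tempfile.mkdtemp()
toy = ToyDetector()
tf.saved_model.save(toy, export_dir, signatures={"default": toy.detect})

# hub.load() is essentially a wrapper around this call.
loaded = tf.saved_model.load(export_dir)

# 1. Which signatures (entry points) does the model expose?
signature_names = list(loaded.signatures.keys())
print("signatures:", signature_names)

# 2. Which weights are stored? Read names and shapes from the SavedModel's checkpoint.
ckpt_prefix = os.path.join(export_dir, "variables", "variables")
weights = tf.train.list_variables(ckpt_prefix)
for name, shape in weights:
    print("weight:", name, shape)

# 3. Which op types appear in the signature's graph and in any nested function defs?
sig = loaded.signatures["default"]
graph_def = sig.graph.as_graph_def()
op_types = {node.op for node in graph_def.node}
for fdef in graph_def.library.function:
    op_types.update(node.op for node in fdef.node_def)
print("op types:", sorted(op_types))
```

For a large detector both lists will be long, so filtering by name prefix is probably more useful than printing everything.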

I wanted to do this because I thought it would be better to modify an already working model than to create one from scratch, and because I'm not good at training models myself.

Upvotes: 0

Views: 457

Answers (2)

kempy

Reputation: 616

You can call dir() on the loaded model asset to see what's defined on it:

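```python
import tensorflow_hub as hub
handle = "https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1"
m = hub.load(handle)
print(dir(m))
```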
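dir() also lists many private and dunder names. A small hypothetical helper, public_attributes, can filter those out. The sketch saves and reloads a hypothetical toy model so it runs without a download; the attribute names on the real Hub model will differ.

```python
import tempfile

import tensorflow as tf


class ToyDetector(tf.Module):
    """Hypothetical stand-in model used only for this demo."""

    def __init__(self):
        super().__init__()
        self.kernel = tf.Variable(tf.ones([3, 4]), name="kernel")

    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def detect(self, images):
        return {"detection_scores": tf.matmul(images, self.kernel)}


def public_attributes(obj):
    """Like dir(obj), but without names that start with an underscore."""
    return [name for name in dir(obj) if not name.startswith("_")]


export_dir = tempfile.mkdtemp()
toy = ToyDetector()
tf.saved_model.save(toy, export_dir, signatures={"default": toy.detect})
m = tf.saved_model.load(export_dir)  # for the real model: hub.load(handle)
attrs = public_attributes(m)
print(attrs)
```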

As mentioned in the other answer, you can also look at the signatures with print(m.signatures).
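To go one step beyond printing m.signatures, you can list each signature's input and output names, which is what you need in order to call it. This is a sketch with a hypothetical helper, describe_signatures, exercised on a locally saved toy model rather than the real detector; it assumes the usual TF2 behavior where signature inputs appear as keyword arguments in structured_input_signature.

```python
import tempfile

import tensorflow as tf


class ToyDetector(tf.Module):
    """Hypothetical stand-in model; NOT the real Faster R-CNN detector."""

    def __init__(self):
        super().__init__()
        self.kernel = tf.Variable(tf.ones([3, 4]), name="kernel")

    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def detect(self, images):
        return {"detection_scores": tf.matmul(images, self.kernel)}


def describe_signatures(loaded):
    """Map each signature name to (sorted input names, sorted output names)."""
    summary = {}
    for name, fn in loaded.signatures.items():
        # Signature functions are keyword-only, so the inputs sit in the kwargs dict.
        _, input_specs = fn.structured_input_signature
        summary[name] = (sorted(input_specs), sorted(fn.structured_outputs))
        print(name, "inputs:", input_specs)
        print(name, "outputs:", fn.structured_outputs)
    return summary


export_dir = tempfile.mkdtemp()
toy = ToyDetector()
tf.saved_model.save(toy, export_dir, signatures={"default": toy.detect})
loaded = tf.saved_model.load(export_dir)
summary = describe_signatures(loaded)
```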

Hub models are SavedModel assets and do not have a Keras .fit method. If you want to train the model from scratch, you'll need to go to the original source code.

Some models have more extensive exported interfaces including access to individual layers, but this model does not.
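One way to check whether a given SavedModel offers that richer interface is to look for the attributes described by TF Hub's Reusable SavedModel convention (__call__, variables, trainable_variables, regularization_losses). The sketch below assumes that convention; the helper reusable_interface_report and the CallableToy model are hypothetical, and the toy exports only __call__. It also confirms that a restored SavedModel has no Keras fit method.

```python
import tempfile

import tensorflow as tf

# Attribute names from TF Hub's "Reusable SavedModel" convention.
REUSABLE_ATTRS = ("__call__", "variables", "trainable_variables", "regularization_losses")


class CallableToy(tf.Module):
    """Hypothetical stand-in that exports __call__ and nothing else from the convention."""

    def __init__(self):
        super().__init__()
        self.kernel = tf.Variable(tf.ones([3, 4]), name="kernel")

    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def __call__(self, images):
        return tf.matmul(images, self.kernel)


def reusable_interface_report(obj):
    """Report which reusable-SavedModel attributes are present on obj."""
    return {attr: hasattr(obj, attr) for attr in REUSABLE_ATTRS}


export_dir = tempfile.mkdtemp()
tf.saved_model.save(CallableToy(), export_dir)
restored = tf.saved_model.load(export_dir)
report = reusable_interface_report(restored)
print(report)
print("has fit():", hasattr(restored, "fit"))
```

If trainable_variables is present, fine-tuning would typically go through a custom tf.GradientTape loop over those variables rather than through .fit.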

Upvotes: 1

Maria Belyalova

Reputation: 11

I've run into this issue too. The TensorFlow Hub guide says:

This error frequently arises when loading models in TF1 Hub format with the hub.load() API in TF2. Adding the correct signature should fix this problem.

```python
mod = hub.load(handle).signatures['default']
```
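Once you have the signature, you call it with keyword arguments and get back a dict of output tensors. The sketch below uses a hypothetical toy model saved locally with a 'default' signature. Rather than hard-coding the input name (the toy's happens to be images), it reads it from structured_input_signature, which is also how you could discover the real detector's input name. It assumes a single-input signature.

```python
import tempfile

import tensorflow as tf


class ToyDetector(tf.Module):
    """Hypothetical stand-in with a 'default' signature; NOT the real detector."""

    def __init__(self):
        super().__init__()
        self.kernel = tf.Variable(tf.random.normal([3, 4]), name="kernel")

    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def detect(self, images):
        return {"detection_scores": tf.nn.softmax(tf.matmul(images, self.kernel))}


export_dir = tempfile.mkdtemp()
toy = ToyDetector()
tf.saved_model.save(toy, export_dir, signatures={"default": toy.detect})

# Analogous to: mod = hub.load(handle).signatures['default']
# Keeping `loaded` referenced is a cautious choice so its restored variables stay alive.
loaded = tf.saved_model.load(export_dir)
mod = loaded.signatures["default"]

# Signature functions take keyword arguments only, so look up the input name.
_, input_specs = mod.structured_input_signature
(input_name,) = input_specs  # assumes exactly one input
batch = tf.random.uniform([2, 3])
outputs = mod(**{input_name: batch})  # returns a dict of tensors
scores = outputs["detection_scores"].numpy()
print(input_name, scores.shape)
```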

As an example, you can see this notebook.

Upvotes: 1
