I created a neural network in Python with TensorFlow's Keras API that uses the pretrained ResNet50 network to classify 133 different dog breeds.
I now want to deploy this model so that it can be used through TensorFlow.js, but I'm having difficulty getting ResNet50 to work. I can transfer a network I built from scratch to TensorFlow.js without a problem, but transferring one that uses a pretrained network isn't as straightforward.
Here is the Python code that I'm trying to adapt:
```python
import numpy as np
from keras.applications.resnet50 import ResNet50, preprocess_input

ResNet50_model = ResNet50(weights="imagenet")  # download ImageNet challenge weights

def extractResNet50(tensor):  # tensor shape is (1, 224, 224, 3)
    return ResNet50(weights="imagenet", include_top=False, pooling="avg").predict(preprocess_input(tensor))

def dogBreed(img_path):
    tensor = convertToTensor(img_path)  # I can do this in TF.js already with no issue
    resnetTensor = extractResNet50(tensor)  # tensor returned in shape (1, 2048)
    resnetTensor = np.expand_dims(resnetTensor, axis=0)  # repeat this line 2 more times to get shape (1, 1, 1, 1, 2048)
    # code below I can convert to TF.js without issue
    prediction = model.predict(resnetTensor[0])
```
How would I convert everything above, except for the convertToTensor() call and the final model.predict() call in dogBreed(), for use in TensorFlow.js?
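Apart from ResNet50 itself, the two non-model steps in dogBreed() are what have to be reproduced by hand in JavaScript. Below is a minimal NumPy sketch of them, assuming that ResNet50's preprocess_input uses Keras' "caffe" mode (RGB to BGR, then subtract the ImageNet channel means, with no scaling). The helper names preprocess_resnet50 and features_to_head_input are hypothetical.

```python
import numpy as np

# ImageNet channel means in BGR order, as subtracted by Keras' "caffe"-mode preprocessing
IMAGENET_BGR_MEANS = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def preprocess_resnet50(rgb_batch):
    """Hypothetical helper: flip RGB to BGR, then subtract per-channel ImageNet means (no scaling)."""
    bgr = rgb_batch[..., ::-1].astype(np.float32)
    return bgr - IMAGENET_BGR_MEANS

def features_to_head_input(features):
    """Hypothetical helper: turn pooled ResNet50 output (N, 2048) into (N, 1, 1, 2048).

    For N == 1 this matches three np.expand_dims(..., axis=0) calls followed by [0].
    """
    return features.reshape(-1, 1, 1, features.shape[-1])
```

A single reshape also handles batches larger than one, which the expand_dims/[0] pattern does not. In TF.js the same operations correspond to tf.reverse along the last axis, a subtraction, and reshape.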
ResNet is such a big network that it has not been ported to the browser yet, and I doubt it ever will be; at least, that is the case as of the latest version of TensorFlow.js (version 0.14).
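To put "big" into numbers, here is a rough sketch, assuming TensorFlow 2.x with tf.keras. It counts the weights of a headless ResNet50 and, for comparison, a headless MobileNetV2, then estimates the float32 weight size at 4 bytes per parameter, ignoring any quantization a converter might apply. The models are built with weights=None so nothing is downloaded, because the parameter count does not depend on the weight values. weight_size_mb is a hypothetical helper.

```python
import tensorflow as tf

def weight_size_mb(model):
    """Hypothetical helper: approximate float32 weight size in megabytes (4 bytes per parameter)."""
    return model.count_params() * 4 / 1e6

# weights=None builds the architecture without downloading ImageNet weights;
# the parameter count is the same either way.
backbones = {
    "ResNet50": tf.keras.applications.ResNet50(
        weights=None, include_top=False, pooling="avg", input_shape=(224, 224, 3)),
    "MobileNetV2": tf.keras.applications.MobileNetV2(
        weights=None, include_top=False, pooling="avg", input_shape=(224, 224, 3)),
}

sizes = {}
for name, net in backbones.items():
    sizes[name] = weight_size_mb(net)
    print(f"{name}: {net.count_params():,} parameters, ~{sizes[name]:.0f} MB as float32")
```

The ResNet50 figure should come out roughly ten times the MobileNetV2 one, and it is all weight data the browser has to download before the first prediction.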
What you can do, on the other hand, is save your Python Keras model and then import the frozen model in JS for prediction.
Update: You are using ResNet50 as the feature extractor for your model. In that case, the frozen model that you save needs to contain both ResNet50 and your model's topology and weights.
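Here is a minimal sketch of that idea, assuming TensorFlow 2.x with tf.keras. It stays in Keras rather than dropping to raw TensorFlow as the first option below suggests, and chains a headless ResNet50 and the classifier into one Keras model so that a single saved file holds both. build_head() is a hypothetical stand-in (global average pooling plus a 133-way softmax) for your trained classifier, which in the question takes (1, 1, 2048) inputs. build_end_to_end() is likewise an illustrative name.

```python
import numpy as np
import tensorflow as tf

def build_head(num_classes=133):
    """Hypothetical stand-in for the trained dog-breed classifier that takes (1, 1, 2048) features."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(1, 1, 2048)),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])

def build_end_to_end(head, weights="imagenet"):
    """Chain headless ResNet50 and the classifier into one model that can be saved as a single file."""
    backbone = tf.keras.applications.ResNet50(
        weights=weights, include_top=False, pooling="avg", input_shape=(224, 224, 3))
    inputs = tf.keras.Input(shape=(224, 224, 3))  # expects caffe-style preprocessed pixels
    features = backbone(inputs)  # (batch, 2048)
    features = tf.keras.layers.Reshape((1, 1, 2048))(features)  # replaces the expand_dims calls
    outputs = head(features)  # (batch, num_classes)
    return tf.keras.Model(inputs, outputs)
```

With weights="imagenet" and your trained head in place of build_head(), the combined model could be saved with model.save("dog_breed.h5") and converted with the tensorflowjs_converter command-line tool using --input_format keras. The preprocessing is deliberately left outside the model, so the browser code would still flip the channels and subtract the ImageNet means before calling predict.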
1- Instead of having two separate architectures in Python, create a single network using TensorFlow directly rather than Keras. The frozen model will then contain ResNet. This might not work well in the browser because ResNet is quite large (I have not tested it myself).
2- Instead of using ResNet in the browser, consider using coco-ssd or mobilenet, which can be used in the browser as feature extractors. You can see how to use them in the official repo.
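For the mobilenet route, here is a rough Python-side sketch, assuming TensorFlow 2.6 or later with tf.keras. It swaps the backbone for a headless MobileNetV2 and puts a new softmax head on top. The head has to be retrained, because MobileNetV2's pooled output has 1280 features instead of ResNet50's 2048, so the existing ResNet50-based classifier cannot be reused as-is. The Rescaling layer is used here as an assumed equivalent of mobilenet_v2.preprocess_input (x / 127.5 - 1). build_mobilenet_classifier is a hypothetical name.

```python
import numpy as np
import tensorflow as tf

def build_mobilenet_classifier(num_classes=133, weights="imagenet"):
    """Hypothetical builder: frozen headless MobileNetV2 plus a new, trainable softmax head."""
    backbone = tf.keras.applications.MobileNetV2(
        weights=weights, include_top=False, pooling="avg", input_shape=(224, 224, 3))
    backbone.trainable = False  # train only the new head
    inputs = tf.keras.Input(shape=(224, 224, 3))  # raw RGB pixels in [0, 255]
    # Scale pixels to [-1, 1], the range MobileNetV2 expects (x / 127.5 - 1)
    x = tf.keras.layers.Rescaling(1.0 / 127.5, offset=-1.0)(inputs)
    features = backbone(x)  # (batch, 1280)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(features)
    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return model
```

You would then call model.fit(images, breed_ids) on the dog dataset, with integer class ids as labels to match sparse_categorical_crossentropy. coco-ssd is an object detector rather than a classification backbone, so this sketch only covers the mobilenet option.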