Reputation: 191
For my project I'm trying to run inference with my trained model, stored in saved_model.pb. I doubt the mistake is in my code (shown below); more likely it is caused by an installation problem:
from scipy.io import wavfile
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Read the audio input once, before building the graph
samplerate, data = wavfile.read('sound.wav')

with tf.Graph().as_default() as graph:  # Set default graph as graph
    with tf.Session() as sess:
        # Load the graph in graph_def
        print("load graph")
        # Load the protobuf file from disk and parse it into a GraphDef
        with tf.gfile.GFile("saved_model.pb", 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())  # the DecodeError in the traceback is raised here
        # Import graph_def into the current default graph
        # (in a frozen graph the weights are typically embedded as constants)
        tf.import_graph_def(graph_def, name="")
        # Print the name of operations in the graph
        for op in graph.get_operations():
            print("Operation Name :", op.name)          # Operation name
            print("Tensor Stats :", str(op.values()))   # Tensor name
        # INFERENCE here
        l_input = graph.get_tensor_by_name('Inputs/fifo_queue_Dequeue:0')    # Input tensor
        l_output = graph.get_tensor_by_name('upscore32/conv2d_transpose:0')  # Output tensor
        print("Shape of input : ", l_input.shape)
        # Run the model on a single input
        session_out = sess.run(l_output, feed_dict={l_input: data})
        # class_names is not defined in this snippet
        print("Predicted class:", class_names[session_out[0].argmax()])
The traceback is as follows:
Traceback (most recent call last):
File "/home/pi/model_inference/test.py", line 11, in <module>
graph_def.ParseFromString(f.read())
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/message.py", line 199, in ParseFromString
return self.MergeFromString(serialized)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1145, in MergeFromString
if self._InternalParse(serialized, 0, length) != length:
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1212, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 754, in DecodeField
if value._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1212, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 733, in DecodeRepeatedField
if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1212, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 888, in DecodeMap
if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/python_message.py", line 1199, in InternalParse
buffer, new_pos, wire_type) # pylint: disable=protected-access
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 989, in _DecodeUnknownField
(data, pos) = _DecodeUnknownFieldSet(buffer, pos)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 968, in _DecodeUnknownFieldSet
(data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
File "/usr/local/lib/python3.7/dist-packages/google/protobuf/internal/decoder.py", line 993, in _DecodeUnknownField
raise _DecodeError('Wrong wire type in tag.')
google.protobuf.message.DecodeError: Wrong wire type in tag.
Note that I'm running this on a Raspberry Pi 4 (i.e. on Linux). I'd be grateful for any hint on what to do. Thanks in advance!
Upvotes: 1
Reputation: 61
Try using tf.saved_model.load(path), where path is the path to your SavedModel folder (the one containing saved_model.pb together with the assets and variables folders), not the path to the .pb file itself. Follow this link to the TensorFlow Object Detection API inference tutorial.
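A minimal sketch of the suggestion above, assuming TensorFlow 2.x is installed. The helper names looks_like_saved_model_dir and load_saved_model are hypothetical, and the layout check (saved_model.pb or saved_model.pbtxt next to a variables folder) is only a heuristic for the usual SavedModel export, so drop it if your export looks different. The point it illustrates is that tf.saved_model.load takes the folder, not the .pb file.

```python
import os


def looks_like_saved_model_dir(path):
    """Heuristic check for the usual SavedModel directory layout.

    A SavedModel is a folder, not a single file: it normally holds
    saved_model.pb (or saved_model.pbtxt) next to a variables/ folder,
    plus an optional assets/ folder.
    """
    if not os.path.isdir(path):
        return False
    has_proto = any(
        os.path.isfile(os.path.join(path, name))
        for name in ("saved_model.pb", "saved_model.pbtxt")
    )
    has_variables = os.path.isdir(os.path.join(path, "variables"))
    return has_proto and has_variables


def load_saved_model(path):
    """Load a SavedModel folder with the TF 2.x API (assumes TF 2.x is installed)."""
    if not looks_like_saved_model_dir(path):
        raise ValueError(
            f"{path!r} does not look like a SavedModel folder; "
            "pass the directory that contains saved_model.pb, not the file itself"
        )
    import tensorflow as tf  # imported lazily so the check above works without TF

    return tf.saved_model.load(path)


# Hypothetical usage ("my_model" is a placeholder folder name):
#   model = load_saved_model("my_model")
#   infer = model.signatures["serving_default"]  # the key may differ for your model
```

Failing early with a clear message when the .pb file is passed instead of its folder is a design choice here; tf.saved_model.load would also fail on a file path, just less readably.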
Upvotes: 2
Reputation: 2526
It looks like your file "saved_model.pb" is not a serialized (wire-format) protocol buffer of the message type GraphDef. Maybe you can check how it was saved and find instructions on how to load it back? Just guessing from the name: could it be a Keras model, so that you have to use tf.keras.models.load_model?
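To illustrate the point above: when bytes written for one message type are parsed against another type's field definitions, tags can end up with wire types the decoder does not expect, which is the kind of mismatch "Wrong wire type in tag" reports. A minimal plain-Python sketch (no TensorFlow or protobuf package needed) can peek at the first tag of a .pb file. It assumes the field layouts of the SavedModel proto (field 1 is saved_model_schema_version, an int64, so wire type 0) and the GraphDef proto (field 1 is node, a repeated NodeDef, so wire type 2). It only inspects the first tag, so it is a heuristic, not a validator; the function names are hypothetical.

```python
# Protobuf wire types (from the protobuf encoding spec)
WIRE_TYPES = {0: "varint", 1: "64-bit", 2: "length-delimited", 5: "32-bit"}


def read_varint(buf, pos=0):
    """Decode a base-128 varint starting at buf[pos]; return (value, next_pos)."""
    value = 0
    shift = 0
    while True:
        byte = buf[pos]  # IndexError here means the buffer is truncated
        value |= (byte & 0x7F) << shift
        pos += 1
        if not byte & 0x80:  # high bit clear: last byte of the varint
            return value, pos
        shift += 7


def first_field(buf):
    """Return (field_number, wire_type) of the first tag in a serialized message."""
    if not buf:
        raise ValueError("empty buffer")
    tag, _ = read_varint(buf)
    return tag >> 3, tag & 0x7


def guess_pb_kind(buf):
    """Heuristic guess at what a .pb file holds, based only on its first tag.

    SavedModel: field 1 is saved_model_schema_version (int64 -> varint).
    GraphDef:   field 1 is node (repeated NodeDef -> length-delimited).
    """
    field, wire = first_field(buf)
    if (field, wire) == (1, 0):
        return "SavedModel"
    if (field, wire) == (1, 2):
        return "GraphDef?"  # consistent with GraphDef, but other messages match too
    return "unknown"


# Hypothetical usage:
#   with open("saved_model.pb", "rb") as f:
#       print(guess_pb_kind(f.read(16)))
```

If this reports "SavedModel", the file should be loaded through its folder (for example with tf.saved_model.load) rather than parsed as a GraphDef.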
Upvotes: 0