I have built a generic Python class for interacting with trained neural networks that were saved using "tf.saved_model.builder.SavedModelBuilder".
When I instantiate the class once with a given neural net, everything works correctly. However, when I instantiate it a second time with a second neural net that has a different architecture, TensorFlow throws a shape-mismatch error: "Assign requires shapes of both tensors to match. lhs shape= [100,2] rhs shape= [400,4]"
These shapes belong to the two different neural nets, but I don't see why TensorFlow would remember anything about the first net.
Is there an easy way to fix this? If not, what is the correct way to use multiple neural networks in one project?
Here's the class code:
import tensorflow as tf

# prevents TensorFlow from using the GPU
config = tf.ConfigProto(
    device_count={'GPU': 0}
)

class TFService():
    def __init__(self, netName, inputName, outputName):
        # opens a TensorFlow session to use continuously
        self.session = tf.Session(config=config)
        # loads the trained neural net
        importDir = 'ocr/neural_nets/{}'.format(netName)
        tf.saved_model.loader.load(
            self.session,
            [tf.saved_model.tag_constants.SERVING],
            importDir
        )
        # saves the input and output tensors for the net
        self.x = tf.get_default_graph().get_tensor_by_name(inputName)
        self.y_pred = tf.get_default_graph().get_tensor_by_name(outputName)

    def getPredictions(self, inputData):
        # the object to feed the neural net
        feed_dict = {self.x: inputData}
        # runs the neural net and returns an array with the predictions
        results = self.session.run(self.y_pred, feed_dict=feed_dict)
        return results
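The class above looks its tensors up in tf.get_default_graph(), which is one process-wide graph shared by every TFService instance, so both nets end up in the same graph. The sketch below illustrates that sharing under stated assumptions: it uses the TensorFlow 1.x-style API through tf.compat.v1 in graph mode, and the helper add_fake_net is hypothetical, standing in for loading a SavedModel by simply creating a variable named "W".

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, as in TensorFlow 1.x


def add_fake_net(shape):
    # Hypothetical stand-in for tf.saved_model.loader.load: it adds a
    # variable called "W" to whichever graph is currently the default.
    return tf1.Variable(tf1.zeros(shape), name="W")


# Both "nets" are built into the single default graph of the process.
w_first = add_fake_net([100, 2])
w_second = add_fake_net([400, 4])

# "W" was already taken, so the second variable gets a uniquified name
# (e.g. "W_1"); a lookup by the saved name "W" finds the first net's variable.
print(w_first.name, w_second.name)
print(w_first.graph is w_second.graph)  # True: one shared graph
```

When the second SavedModel is restored into that shared graph, its saved values can end up being assigned to the first net's same-named variables, which is consistent with the shape mismatch in the error message.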
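Use a separate graph for each net. Otherwise every model is loaded into the same default graph, where the variables of the second net clash with those of the first.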
You can do something like:
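    def __init__(self, netName, inputName, outputName):
        # a private graph for this net, so its variables cannot clash
        # with those of any other net loaded in the same process
        self.graph = tf.Graph()
        # opens a TensorFlow session to use continuously,
        # with self.graph as the session's graph
        self.session = tf.Session(config=config, graph=self.graph)
        # loads the trained neural net into self.graph
        importDir = 'ocr/neural_nets/{}'.format(netName)
        with self.graph.as_default():
            tf.saved_model.loader.load(
                self.session,
                [tf.saved_model.tag_constants.SERVING],
                importDir
            )
        # saves the input and output tensors for the net
        self.x = self.graph.get_tensor_by_name(inputName)
        self.y_pred = self.graph.get_tensor_by_name(outputName)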
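To check the one-graph-per-net approach end to end, here is a self-contained sketch that exports two toy models with different weight shapes and loads both in the same process. It assumes TensorFlow 1.15 or 2.x used through tf.compat.v1 in graph mode; export_toy_net is a hypothetical helper, the tensor names "x:0" and "y_pred:0" come from that helper, and this TFService variant takes the export directory directly instead of netName.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, as in TensorFlow 1.x
config = tf1.ConfigProto(device_count={'GPU': 0})  # CPU only, as in the question


def export_toy_net(export_dir, in_dim, out_dim):
    """Hypothetical helper: saves a one-layer net whose weight is named "W"."""
    graph = tf1.Graph()
    with graph.as_default(), tf1.Session(graph=graph) as sess:
        x = tf1.placeholder(tf.float32, [None, in_dim], name="x")
        w = tf1.Variable(tf1.ones([in_dim, out_dim]), name="W")
        tf1.identity(tf1.matmul(x, w), name="y_pred")
        sess.run(tf1.global_variables_initializer())
        builder = tf1.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(
            sess, [tf1.saved_model.tag_constants.SERVING])
        builder.save()


class TFService:
    def __init__(self, importDir, inputName, outputName):
        # one private graph per net, so equal variable names cannot collide
        self.graph = tf1.Graph()
        self.session = tf1.Session(config=config, graph=self.graph)
        with self.graph.as_default():
            tf1.saved_model.loader.load(
                self.session, [tf1.saved_model.tag_constants.SERVING], importDir)
        self.x = self.graph.get_tensor_by_name(inputName)
        self.y_pred = self.graph.get_tensor_by_name(outputName)

    def getPredictions(self, inputData):
        return self.session.run(self.y_pred, feed_dict={self.x: inputData})


root = tempfile.mkdtemp()
small_dir = os.path.join(root, "small")  # must not exist yet for the builder
large_dir = os.path.join(root, "large")
export_toy_net(small_dir, 100, 2)  # W has shape [100, 2]
export_toy_net(large_dir, 400, 4)  # W has shape [400, 4]

small = TFService(small_dir, "x:0", "y_pred:0")
large = TFService(large_dir, "x:0", "y_pred:0")
print(small.getPredictions(np.ones((1, 100), np.float32)).shape)  # (1, 2)
print(large.getPredictions(np.ones((1, 400), np.float32)).shape)  # (1, 4)
```

Both models use the same variable and tensor names, yet each loads cleanly because each lives in its own tf.Graph with its own session.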