Reputation: 53
I am using TensorFlow for Poets to detect features in clothing images. I have trained 4 different models(sleeve, shape, length & hemline). Now i pass image urls to each of the model and store the result. Since i have huge data (100k images), so using spark to broadcast 4 models once and passing image RDD to detect the features. It's taking exponential time. It starts from 3 secs/images & keep increasing the execution time. When scripts already sensed 10k images, its execution time reaches to 8 secs/ images. I am new to Tensorflow, will be very thankful if get any idea to make the execution time linear.
def getLabelDresses(file_name):
resultDict = {}
t = read_tensor_from_image_file(file_name,
input_height=input_height,
input_width=input_width,
input_mean=input_mean,
input_std=input_std)
input_name = "import/" + input_layer
output_name = "import/" + output_layer
with tf.Graph().as_default() as g:
graph_def = tf.GraphDef()
graph_def.ParseFromString(model_data_hemline.value)
tf.import_graph_def(graph_def)
input_operation_hemline = g.get_operation_by_name(input_name);
output_operation_hemline = g.get_operation_by_name(output_name);
with tf.Session() as sess:
results = sess.run(output_operation_hemline.outputs[0],{input_operation_hemline.outputs[0]: t})
results = np.squeeze(results)
top_k = results.argsort()[-1:][::-1]
labels = load_labels(label_file_hemline)
resultDict['hemline'] = labels[top_k[0]]
with tf.Graph().as_default() as g:
graph_def = tf.GraphDef()
graph_def.ParseFromString(model_data_shape.value)
tf.import_graph_def(graph_def)
input_operation_shape = g.get_operation_by_name(input_name);
output_operation_shape = g.get_operation_by_name(output_name);
with tf.Session() as sess:
results = sess.run(output_operation_shape.outputs[0],{input_operation_shape.outputs[0]: t})
results = np.squeeze(results)
top_k = results.argsort()[-1:][::-1]
labels = load_labels(label_file_shape)
resultDict['shape'] = labels[top_k[0]]
with tf.Graph().as_default() as g:
graph_def = tf.GraphDef()
graph_def.ParseFromString(model_data_length.value)
tf.import_graph_def(graph_def)
input_operation_length = g.get_operation_by_name(input_name);
output_operation_length = g.get_operation_by_name(output_name);
with tf.Session() as sess:
results = sess.run(output_operation_length.outputs[0],{input_operation_length.outputs[0]: t})
results = np.squeeze(results)
top_k = results.argsort()[-1:][::-1]
labels = load_labels(label_file_length)
resultDict['length'] = labels[top_k[0]]
with tf.Graph().as_default() as g:
graph_def = tf.GraphDef()
graph_def.ParseFromString(model_data_sleeve.value)
tf.import_graph_def(graph_def)
input_operation_sleeve = g.get_operation_by_name(input_name);
output_operation_sleeve = g.get_operation_by_name(output_name);
with tf.Session() as sess:
results = sess.run(output_operation_sleeve.outputs[0],{input_operation_sleeve.outputs[0]: t})
results = np.squeeze(results)
top_k = results.argsort()[-1:][::-1]
labels = load_labels(label_file_sleeve)
resultDict['sleeve'] = labels[top_k[0]]
return resultDict;
model_file_hemline = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/hemline/retrained_graph_hemline.pb"
label_file_hemline = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/hemline/retrained_labels_hemline.txt"
model_file_length = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/length/retrained_graph_length.pb"
label_file_length = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/length/retrained_labels_length.txt"
model_file_shape = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/shape/retrained_graph_shape.pb"
label_file_shape = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/shape/retrained_labels_shape.txt"
model_file_sleeve = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/sleeve/retrained_graph_sleeve.pb"
label_file_sleeve = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/sleeve/retrained_labels_sleeve.txt"
with gfile.FastGFile(model_file_hemline, "rb") as f:
model_data = f.read()
model_data_hemline = sc.broadcast(model_data)
with gfile.FastGFile(model_file_length, "rb") as f:
model_data = f.read()
model_data_length = sc.broadcast(model_data)
with gfile.FastGFile(model_file_shape, "rb") as f:
model_data = f.read()
model_data_shape = sc.broadcast(model_data)
with gfile.FastGFile(model_file_sleeve, "rb") as f:
model_data = f.read()
model_data_sleeve = sc.broadcast(model_data)
def calculate(row):
path = "/tmp/"+row.guid
url = row.modelno
print(path, url)
if(url is not None):
import urllib.request
urllib.request.urlretrieve(url, path)
t1=time.time()
result = getLabelDresses(path)
print(time.time()-t1)
print(result)
return row
return row
product2.rdd.map(calculate).collect()
Upvotes: 0
Views: 68
Reputation: 2156
Every call to getLabelDresses
in your code adds operations to the graph.
Split your code into setup (model loading) part, executed once and execution part that is executed for each image. The latter should only contain calls to Session.run
.
Another option is to clear graph before processing next image using tf.reset_default_graph
. But it is less preferable.
Upvotes: 1