剑荡八荒

Reputation: 13

TensorFlow: using the Dataset API to replace feed_dict

While studying a TensorFlow project, I found this line of code:

cls_prob, box_pred = sess.run([output_cls_prob, output_box_pred], feed_dict={input_img: blob})

But this line of code takes a lot of time (about 15 seconds on CPU).

From what I have read, the Dataset API could solve this performance problem. How should I use it?

Source of 'blob':

img_data = cv2.imread('./imgs/001.jpg')
img_scale = float(600) / min(img_data.shape[0], img_data.shape[1])
if np.round(img_scale * max(img_data.shape[0], img_data.shape[1])) > 1200:
    img_scale = float(1200) / max(img_data.shape[0], img_data.shape[1])
img_data = cv2.resize(img_data, None, None, fx=img_scale, fy=img_scale, interpolation=cv2.INTER_LINEAR)
img_orig = img_data.astype(np.float32, copy=True)
blob = np.zeros((1, img_data.shape[0], img_data.shape[1], 3),dtype=np.float32)
blob[0, 0:img_data.shape[0], 0:img_data.shape[1], :] = img_orig
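The rescaling rule above (shorter side brought to 600, longer side capped at 1200) can be factored into a small helper. A sketch in plain Python, where the function name and defaults are my own:

```python
def compute_scale(height, width, target_min=600, cap_max=1200):
    """Scale factor that brings the shorter side to `target_min`,
    reduced if the longer side would then exceed `cap_max`."""
    scale = float(target_min) / min(height, width)
    if round(scale * max(height, width)) > cap_max:
        scale = float(cap_max) / max(height, width)
    return scale
```

This keeps the magic numbers in one place and makes the cap-on-the-long-side behavior easy to test in isolation.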

Source of 'output_cls_prob', 'output_box_pred' and 'input_img':

# Actually, these come from a PB (frozen graph) model loaded earlier...
input_img = sess.graph.get_tensor_by_name('Placeholder:0')
output_cls_prob = sess.graph.get_tensor_by_name('Reshape_2:0')
output_box_pred = sess.graph.get_tensor_by_name('rpn_bbox_pred/Reshape_1:0')
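For context, tensor handles like these are typically obtained after importing a frozen graph. A sketch (the .pb path is elided; `name=''` keeps the original tensor names so `'Placeholder:0'` resolves):

```python
with tf.gfile.GFile('.../frozen_model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    tf.import_graph_def(graph_def, name='')
    input_img = sess.graph.get_tensor_by_name('Placeholder:0')
```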

Parameter types:

blob: numpy.ndarray
output_cls_prob: tensorflow.python.framework.ops.Tensor
output_box_pred: tensorflow.python.framework.ops.Tensor
input_img: tensorflow.python.framework.ops.Tensor

Upvotes: 1

Views: 943

Answers (2)

rlys

Reputation: 480

tf.data is the recommended API for TensorFlow input pipelines. There is a tutorial on tensorflow.org; for your case, the section "Decoding image data and resizing it" may be the most useful. For example, you could do something like:

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_jpeg(image_string)
  image_resized = tf.image.resize_images(image_decoded, [new_width, new_height])
  image_resized = tf.expand_dims(image_resized, 0)  # Adds a batch dimension of size 1
  return image_resized

# A vector of filenames.
filenames = tf.constant(["./imgs/001.jpg", ...])

dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.map(_parse_function)

And instead of having input_img be a placeholder, change:

input_img = tf.placeholder(tf.float32)
output_class_prob, output_class_pred = (... use input_img ...)

to:

iterator = dataset.make_one_shot_iterator()
input_img = iterator.get_next()
output_class_prob, output_class_pred = (... use input_img ...)
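Putting the pieces together, the whole pipeline might look like the following sketch (assuming TF 1.x; the network wiring is left as a placeholder, as above):

```python
filenames = tf.constant(["./imgs/001.jpg"])
dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.map(_parse_function)   # _parse_function as defined above
dataset = dataset.prefetch(1)            # overlap preprocessing with execution

iterator = dataset.make_one_shot_iterator()
input_img = iterator.get_next()          # replaces the Placeholder
output_cls_prob, output_box_pred = (... build the network on input_img ...)

with tf.Session() as sess:
    # No feed_dict needed: the iterator feeds the graph directly.
    cls_prob, box_pred = sess.run([output_cls_prob, output_box_pred])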

Upvotes: 1

m33n

Reputation: 1751

First of all, you should know that the Dataset API has its greatest performance impact when multiple GPUs are used... Otherwise it is almost identical to feed_dict. I recommend reading this other answer from a TF developer; it covers almost everything one needs to know to form a mental picture of the benefits of this new API.

Upvotes: 0
