user1519665

Reputation: 511

tf.reshape versus (tf.expand_dims + tf.squeeze... etc)

Are there any performance improvements to be had from using tf.expand_dims(), tf.squeeze(), etc. instead of tf.reshape()?

It feels like, for readability, tf.reshape() is often the best choice: you can perform any number/combination of reshaping steps in one call, and you're absolutely sure what the final shape will be.
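
Here is a made-up illustration of the trade-off I mean (the shapes are arbitrary, purely for the example):

import tensorflow as tf

x = tf.zeros((1, 3, 1, 2))           # arbitrary example shape

# Step by step: drop the two unit axes, then add one back in front
y = tf.squeeze(x, axis=2)            # shape (1, 3, 2)
y = tf.squeeze(y, axis=0)            # shape (3, 2)
y = tf.expand_dims(y, 0)             # shape (1, 3, 2)

# Single call: the target shape is spelled out explicitly
z = tf.reshape(x, (1, 3, 2))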

However, I've read that tf.reshape() makes copies of the data internally. Do tf.expand_dims() and tf.squeeze() avoid this? Are there performance improvements or other reasons to prefer the alternatives to tf.reshape()?

Upvotes: 2

Views: 1800

Answers (1)

Vlad

Reputation: 8595

In TF 1.x, in particular TF 1.12.0, all of the methods have roughly the same performance on CPU:

import tensorflow as tf
with tf.device('cpu:0'):
    tensor = tf.random.normal(shape=(1, 3, 2))

    newaxis = tensor[tf.newaxis, ...]
    expanded_dims = tf.expand_dims(tensor, 0)
    reshaped = tf.reshape(tensor, (1, ) + tuple(tensor.get_shape().as_list()))

    squeezed = tf.squeeze(tensor)
    reshaped2 = tf.reshape(tensor, (3, 2))

sess = tf.Session()
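# %timeit is an IPython/Jupyter magic, so the benchmark lines below assume an interactive session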
%timeit -n 10000 sess.run(newaxis) # 84.3 µs ± 767 ns per loop 
%timeit -n 10000 sess.run(expanded_dims) # 83.3 µs ± 837 ns per loop
%timeit -n 10000 sess.run(reshaped) # 83.5 µs ± 946 ns per loop

%timeit -n 10000 sess.run(squeezed) # 81.9 µs ± 852 ns per loop
%timeit -n 10000 sess.run(reshaped2) # 83.9 µs ± 852 ns per loop

On GPU, tf.newaxis is the fastest way to add an axis and tf.squeeze() the fastest way to remove a unit-sized one:

import tensorflow as tf
with tf.device('gpu:0'):
    tensor = tf.random.normal(shape=(1, 3, 2))

    newaxis = tensor[tf.newaxis, ...] # <-- Fastest to add new axis
    expanded_dims = tf.expand_dims(tensor, 0)
    reshaped = tf.reshape(tensor, (1, ) + tuple(tensor.get_shape().as_list()))

    squeezed = tf.squeeze(tensor) # <-- Fastest to remove unit-sized dims
    reshaped2 = tf.reshape(tensor, (3, 2))

sess = tf.Session()
%timeit -n 10000 sess.run(newaxis) # 133 µs ± 1.58 µs per loop
%timeit -n 10000 sess.run(expanded_dims) # 140 µs ± 1.4 µs per loop
%timeit -n 10000 sess.run(reshaped) # 153 µs ± 1.22 µs per loop

%timeit -n 10000 sess.run(squeezed) # 134 µs ± 1.86 µs per loop
%timeit -n 10000 sess.run(reshaped2) # 153 µs ± 1.19 µs per loop

In TF 2.0 (eager execution, CPU), tf.expand_dims() is the fastest way to add a dimension and tf.squeeze() the fastest way to remove one:

import tensorflow as tf

tensor = tf.random.normal(shape=(1, 3, 2))

%timeit -n 10000 tf.expand_dims(tensor, 0) # 7.07 µs ± 162 ns per loop
%timeit -n 10000 tf.reshape(tensor, (1, ) + tuple(tensor.shape.as_list())) # 21.3 µs ± 326 ns per loop
%timeit -n 10000 tensor[tf.newaxis, ...] # 42.9 µs ± 565 ns per loop

%timeit -n 10000 tf.squeeze(tensor) # 9.85 µs ± 166 ns per loop
%timeit -n 10000 tf.reshape(tensor, shape=(3, 2)) # 18.2 µs ± 386 ns per loop 
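
As a sanity check, here is a minimal TF 2.x sketch (separate from the timings above) showing that the variants produce identical tensors, so the choice really is only about readability and speed:

import tensorflow as tf

tensor = tf.random.normal(shape=(1, 3, 2))

# Three ways to add a leading axis of size 1 -> shape (1, 1, 3, 2)
a = tf.expand_dims(tensor, 0)
b = tensor[tf.newaxis, ...]
c = tf.reshape(tensor, (1,) + tuple(tensor.shape.as_list()))
tf.debugging.assert_equal(a, b)
tf.debugging.assert_equal(a, c)

# Two ways to remove the unit-sized leading axis -> shape (3, 2)
d = tf.squeeze(tensor)
e = tf.reshape(tensor, (3, 2))
tf.debugging.assert_equal(d, e)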

Upvotes: 4
