Alex Rothberg

Reputation: 10993

"Quantize" Tensorflow Graph to float16

How do you convert a TensorFlow graph from using float32 to float16? Currently there are graph optimizations for quantization and for conversion to eight-bit ints.

Trying to load float32 weights into a float16 graph fails with:

DataLossError (see above for traceback): Invalid size in bundle entry: key model/conv5_1/biases; stored size 1536; expected size 768
     [[Node: save/RestoreV2_16 = RestoreV2[dtypes=[DT_HALF], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/RestoreV2_16/tensor_names, save/RestoreV2_16/shape_and_slices)]]
     [[Node: save/RestoreV2_3/_39 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_107_save/RestoreV2_3", tensor_type=DT_HALF, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
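(Note that the stored size is exactly double the expected size: the checkpoint holds 4-byte float32 values where the graph expects 2-byte float16 values.)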

Upvotes: 11

Views: 5596

Answers (2)

Ramiro R.C.

Reputation: 473

I had this issue, but I was loading a sub-graph which contained some variables that needed to be loaded or converted and some that did not. Based on @Jendrik's answer, here is a function that returns the assign operations, given a dictionary that maps the saved variables to the new graph:

import numpy as np
import tensorflow as tf

def assign_and_convert_halfPrecision(restore_dictionary, CHECKPOINT_PATH):

    # Iterate over the dictionary containing the variables to load
    for variable_name_old, variable_new in restore_dictionary.items():

        # Load the variable's value (a numpy array) from the checkpoint
        var = tf.contrib.framework.load_variable(CHECKPOINT_PATH, variable_name_old)

        # Assign to the new graph
        if (var.dtype == np.float32) and (variable_new.dtype.base_dtype == tf.float16):
            # The variable is float16 in the new graph, so cast it before assigning
            tf.add_to_collection('assignOps', variable_new.assign(tf.cast(var, tf.float16)))
        else:
            # The variable in the old graph is float16 or the new variable is float32,
            # so we load it directly
            tf.add_to_collection('assignOps', variable_new.assign(var))

    # Return the assignment operations
    return tf.get_collection('assignOps')

To use it, just do:

# Create a trivial dictionary (all custom loading can be added here, like change of scope names)
restore_dictionary = dict()
for a in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=''):
    restore_dictionary[a.name[:-2]] = a

# Create the assignment and conversion op
assign_operation = assign_and_convert_halfPrecision(restore_dictionary, CHECKPOINT_PATH)

# Load
sess.run(assign_operation)

The loading can be controlled by modifying the dictionary: leave out variables that should not be loaded, or rewrite the keys to change the scope of the variables to load, as sketched below.
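For example, a sketch of a scope rename, assuming the checkpoint was saved under a hypothetical scope 'old_model' while the new graph uses 'model':

# Map checkpoint names under 'old_model' (hypothetical) to the new graph's
# variables under 'model', skipping optimizer slots that need no restoring
restore_dictionary = dict()
for a in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='model'):
    checkpoint_name = a.name[:-2].replace('model', 'old_model', 1)
    if 'Adam' in checkpoint_name:
        continue  # skip variables that should not be loaded
    restore_dictionary[checkpoint_name] = a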

Upvotes: 0

Jendrik

Reputation: 186

I don't think my solution is the best, nor the most straightforward, but since nobody else has posted anything:

What I did was train the network in full precision and save the weights in a checkpoint. Then I built a copy of the network, setting all desired variables to a dtype of tf.float16 and removing all the training nodes. A minimal sketch of what such a half-precision copy might look like (the build function, layer, and shapes here are hypothetical, not my actual network):
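import tensorflow as tf

def build_model(images, dtype=tf.float32):
    # Same architecture as the training graph, but with a configurable dtype
    with tf.variable_scope('model'):
        weights = tf.get_variable('conv1/weights', [3, 3, 3, 64], dtype=dtype)
        biases = tf.get_variable('conv1/biases', [64], dtype=dtype)
        conv = tf.nn.conv2d(images, weights, strides=[1, 1, 1, 1], padding='SAME')
        return tf.nn.bias_add(conv, biases)

# Inference-only graph: no optimizer or other training nodes, and every
# variable is created directly as float16
inputs = tf.placeholder(tf.float16, [None, 224, 224, 3])
outputs = build_model(inputs, dtype=tf.float16)

Finally, I loaded the variables and cast them the following way: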

import numpy as np
import tensorflow as tf

# Names of all variables stored in the full-precision checkpoint
previous_variables = [
    var_name for var_name, _
    in tf.contrib.framework.list_variables('path-to-checkpoint-file')]

sess.run(tf.global_variables_initializer())
for variable in tf.global_variables():
    if variable.op.name in previous_variables:
        # Load the variable's value (a numpy array) from the checkpoint
        var = tf.contrib.framework.load_variable(
            'path-to-checkpoint-file', variable.op.name)
        if var.dtype == np.float32:
            # Cast float32 checkpoint values to the graph's float16 variables
            tf.add_to_collection('assignOps', variable.assign(
                tf.cast(var, tf.float16)))
        else:
            # Dtypes already match, so assign directly
            tf.add_to_collection('assignOps', variable.assign(var))
# Run all assignments in a single session call
sess.run(tf.get_collection('assignOps'))

This obviously has issues if there are float32 tensors that you don't want to convert, which I luckily don't have, as I want to convert all my nodes to float16 precision. In case you have some, you could filter further with additional if statements, for example with a keep-list as sketched below. I hope this answers your question.
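A sketch of such a filter, replacing the inner if above (the name fragments are hypothetical; they assume those variables were also created as float32 in the new graph):

# Hypothetical keep-list: variables matching these name fragments stay float32
KEEP_FP32 = ('moving_mean', 'moving_variance')

if var.dtype == np.float32 and not any(
        fragment in variable.op.name for fragment in KEEP_FP32):
    tf.add_to_collection('assignOps', variable.assign(
        tf.cast(var, tf.float16)))
else:
    tf.add_to_collection('assignOps', variable.assign(var))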

Upvotes: 8
