Federico Màlato

Reputation: 125

How do I reload weights and biases on a CNN with Tensorflow?

I trained a model with TensorFlow and exported the meta graph. When I import the trained graph, restore the checkpoint, and try to evaluate a saved tensor, the following error occurs:

"C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\python.exe" C:/Users/fredd/PycharmProjects/CNN/detectionDemo.py
Traceback (most recent call last):
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\client\session.py", line 1327, in _do_call
return fn(*args)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\client\session.py", line 1312, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\client\session.py", line 1420, in _call_tf_sessionrun
status, run_metadata)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 516, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'x' with dtype float and shape [16,96,128,3]
 [[Node: x = Placeholder[dtype=DT_FLOAT, shape=[16,96,128,3], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "C:/Users/fredd/PycharmProjects/CNN/detectionDemo.py", line 62, in <module>
print(sess.run('y_pred:0'))
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\client\session.py", line 905, in run
run_metadata_ptr)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\client\session.py", line 1140, in _run
feed_dict_tensor, options, run_metadata)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\client\session.py", line 1321, in _do_run
run_metadata)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\client\session.py", line 1340, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'x' with dtype float and shape [16,96,128,3]
 [[Node: x = Placeholder[dtype=DT_FLOAT, shape=[16,96,128,3], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'x', defined at:
File "C:/Users/fredd/PycharmProjects/CNN/detectionDemo.py", line 60, in <module>
saver = tf.train.import_meta_graph('results/steering_model.meta')
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\training\saver.py", line 1927, in import_meta_graph
**kwargs)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\framework\meta_graph.py", line 741, in import_scoped_meta_graph
producer_op_list=producer_op_list)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\util\deprecation.py", line 432, in new_func
return func(*args, **kwargs)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\framework\importer.py", line 577, in import_graph_def
op_def=op_def)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\framework\ops.py", line 3290, in create_op
op_def=op_def)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\framework\ops.py", line 1654, in __init__
self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'x' with dtype float and shape [16,96,128,3]
 [[Node: x = Placeholder[dtype=DT_FLOAT, shape=[16,96,128,3], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

How can I fix this? Also, is there a way to visualize the graph I created?

EDIT

The full code that loads the model is:

sess = tf.Session()
saver = tf.train.import_meta_graph('results/steering_model.meta')
saver.restore(sess, 'results/steering_model')
print(sess.run('y_pred:0'))

The full training code of my CNN is:

data = dataset.read_train_sets(train_path, 128, 96, classes, validation_size)

session = tf.Session()

x = tf.placeholder(tf.float32, shape=[batch_size, 96, 128, 3], name='x')

layer_conv1 = cnn.create_convolutional_layer(input=x,
                                     num_input_channels=3,
                                     conv_filter_size=3,
                                     num_filters=128)

layer_conv2 = cnn.create_convolutional_layer(input=layer_conv1,
                                     num_input_channels=128,
                                     conv_filter_size=3,
                                     num_filters=128)

layer_conv3 = cnn.create_convolutional_layer(input=layer_conv2,
                                     num_input_channels=128,
                                     conv_filter_size=3,
                                     num_filters=128)

layer_flat = cnn.create_flatten_layer(layer_conv3)

layer_fc1 = cnn.create_fc_layer(input=layer_flat,
                        num_inputs=layer_flat.get_shape()[1:4].num_elements(),
                        num_outputs=32,
                        use_relu=True)

layer_fc2 = cnn.create_fc_layer(input=layer_fc1,
                        num_inputs=32,
                        num_outputs=num_classes,
                        use_relu=True)

y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')
y_true_cls = tf.argmax(y_true, dimension=1)

y_pred = tf.nn.softmax(layer_fc2,name='y_pred')
y_pred_cls = tf.argmax(y_pred, dimension=1)
session.run(tf.global_variables_initializer())

cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=layer_fc2,
                                                labels=y_true)
cost = tf.reduce_mean(cross_entropy)

optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cost)
optimizer2 = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(cost)
correct_prediction = tf.equal(y_pred_cls, y_true_cls)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

def show_progress(epoch, feed_dict_train, feed_dict_validate, val_loss):
    acc = session.run(accuracy, feed_dict=feed_dict_train)
    val_acc = session.run(accuracy, feed_dict=feed_dict_validate)
    msg = "Training Epoch {0} --- Training Accuracy: {1:>6.1%}, Validation Accuracy: {2:>6.1%},  Validation Loss: {3:.3f}"
    print(msg.format(epoch + 1, acc, val_acc, val_loss))

total_iterations = 0

saver = tf.train.Saver()

def train(num_iteration):
    global total_iterations
    initOp = tf.global_variables_initializer()
    session.run(initOp)

    for i in range(total_iterations,
               total_iterations + num_iteration):

        x_batch, y_true_batch, _, cls_batch = data.train.next_batch(batch_size)
        x_valid_batch, y_valid_batch, _, valid_cls_batch = data.valid.next_batch(batch_size)

        feed_dict_tr = {x: x_batch,
                    y_true: y_true_batch}
        feed_dict_val = {x: x_valid_batch,
                     y_true: y_valid_batch}

        session.run(optimizer, feed_dict=feed_dict_tr)

        val_loss = session.run(cost, feed_dict=feed_dict_val)
        epoch = i

        show_progress(epoch, feed_dict_tr, feed_dict_val, val_loss)
        saver.save(session, 'results/steering_model')

    total_iterations += num_iteration

session.run(tf.global_variables_initializer())
train(500)

The network is saved successfully, but after the import I can't use any of the variables previously saved.

Upvotes: 0

Views: 216

Answers (1)

DomJack

Reputation: 4183

The error isn't related to saving or loading; it comes from your session.run call. The graph you restore contains a placeholder (x), and y_pred depends on it, so you must feed it through the feed_dict argument of Session.run, just as you would if you had constructed the graph manually. You can retrieve the placeholder and the output tensor by name with graph.get_tensor_by_name:

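import numpy as np

# after import_meta_graph and saver.restore(sess, ...)
graph = tf.get_default_graph()
x = graph.get_tensor_by_name('x:0')
y_pred = graph.get_tensor_by_name('y_pred:0')

# the saved placeholder has the fixed shape [16, 96, 128, 3]
batch_size = 16
x_data = np.random.normal(size=(batch_size, 96, 128, 3))  # use actual data

print(sess.run(y_pred, feed_dict={x: x_data}))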
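Because the placeholder was saved with a fixed batch dimension (the error message shows shape [16,96,128,3]), every feed has to contain exactly 16 images. A minimal sketch of one way to handle an arbitrary number of images is to pad the last batch with zeros and discard the padded predictions afterwards. The helper name fixed_size_batches and the random stand-in data are assumptions made for illustration; they are not part of the original code.

```python
import numpy as np

BATCH_SIZE = 16  # from the placeholder shape [16, 96, 128, 3] in the error message


def fixed_size_batches(images, batch_size=BATCH_SIZE):
    """Yield (batch, n_valid) pairs; every batch has exactly batch_size rows.

    The last batch is zero-padded, and n_valid is the number of real rows in it.
    (Hypothetical helper, not part of TensorFlow.)
    """
    images = np.asarray(images, dtype=np.float32)
    for start in range(0, len(images), batch_size):
        chunk = images[start:start + batch_size]
        n_valid = len(chunk)
        if n_valid < batch_size:
            pad_shape = (batch_size - n_valid,) + chunk.shape[1:]
            chunk = np.concatenate([chunk, np.zeros(pad_shape, dtype=chunk.dtype)], axis=0)
        yield chunk, n_valid


# Stand-in data: 20 random "images" with the height, width and channels the graph expects.
images = np.random.rand(20, 96, 128, 3).astype(np.float32)
batches = list(fixed_size_batches(images))
```

With the restored sess, x and y_pred from above, each batch could then be evaluated as sess.run(y_pred, feed_dict={x: batch})[:n_valid]. An alternative, if you can retrain and re-export, is to declare the placeholder with shape=[None, 96, 128, 3] so that any batch size can be fed.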

Upvotes: 2
