Reputation: 754
I am using TensorFlow-Slim's resnet_v2 to extract image features. The checkpoint resnet_v2_152.ckpt is from: resnet_v2_152.ckpt. This is my code:
import tensorflow as tf
import tensorflow.contrib.slim.python.slim.nets.resnet_v2 as resnet_v2
def cnn_model_fn(features, labels, mode):
    """Estimator model function returning the ResNet-v2-152 output as predictions.

    Args:
        features: Input batch, forwarded unchanged to ``resnet_v2_152``.
        labels: Not referenced in the body.
        mode: A ``tf.estimator.ModeKeys`` value; only ``PREDICT`` is supported.

    Returns:
        A ``tf.estimator.EstimatorSpec`` whose ``predictions`` is the first
        value returned by ``resnet_v2_152``.

    Raises:
        NotImplementedError: If ``mode`` is anything other than ``PREDICT``.
    """
    training = mode == tf.estimator.ModeKeys.TRAIN
    # NOTE(review): the network is built without wrapping the call in
    # resnet_v2.resnet_arg_scope(); this is presumably why the graph ends up
    # with conv ``biases`` variables that the checkpoint lacks - confirm.
    net, _end_points = resnet_v2.resnet_v2_152(inputs=features, is_training=training)
    if mode != tf.estimator.ModeKeys.PREDICT:
        raise NotImplementedError('only support predict!')
    return tf.estimator.EstimatorSpec(mode=mode, predictions=net)
def parse_filename(filename):
    """Read a JPEG file and return it as a 3-channel image resized to 256x256.

    Args:
        filename: Path of the image file, passed to ``tf.read_file``.

    Returns:
        The output of ``tf.image.resize_images`` for the decoded image.
    """
    raw_bytes = tf.read_file(filename)
    rgb_image = tf.image.decode_jpeg(raw_bytes, channels=3)
    return tf.image.resize_images(rgb_image, [256, 256])
def dataset_input_fn(dataset, num_epochs=None, batch_size=128, shuffle=False, buffer_size=1000, seed=None):
    """Build an Estimator ``input_fn`` that iterates over ``dataset``.

    Args:
        dataset: Dataset object providing ``shuffle``, ``repeat``, ``batch``
            and ``make_one_shot_iterator``.
        num_epochs: Value passed to ``dataset.repeat``.
        batch_size: Value passed to ``dataset.batch``.
        shuffle: When true, shuffle individual elements before repeating and
            batching.
        buffer_size: Shuffle buffer size; only used when ``shuffle`` is true.
        seed: Seed forwarded to ``dataset.shuffle``; only used when
            ``shuffle`` is true.

    Returns:
        A zero-argument callable that builds the pipeline and returns the
        iterator's ``get_next()`` result.
    """
    def input_fn():
        pipeline = dataset
        if shuffle:
            # Shuffle single elements first: shuffling after batch() would
            # only reorder whole batches, and the seed was previously ignored.
            pipeline = pipeline.shuffle(buffer_size, seed=seed)
        pipeline = pipeline.repeat(num_epochs).batch(batch_size)
        iterator = pipeline.make_one_shot_iterator()
        return iterator.get_next()

    return input_fn
# Gather the validation image paths in a deterministic (sorted) order.
filenames = sorted(tf.gfile.Glob('/root/data/COCO/download/val2014/*'))

# Map every path to its decoded, resized image.
dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.map(parse_filename)

# Single pass over the data, one image per batch, original order.
input_fn = dataset_input_fn(dataset, num_epochs=1, batch_size=1, shuffle=False)

estimator = tf.estimator.Estimator(model_fn=cnn_model_fn, model_dir=None)
es = estimator.predict(
    input_fn=input_fn,
    checkpoint_path='/root/data/checkpoints/resnet_v2_152_2017_04_14/resnet_v2_152.ckpt',
)

# Pull only the first prediction from the result iterator.
print(next(es))
print("Done!")
I got the following error:
2017-09-10 22:06:36.875590: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Tensor name "resnet_v2_152/block1/unit_1/bottleneck_v2/conv1/biases" not found in checkpoint files /root/data/checkpoints/resnet_v2_152_2017_04_14/resnet_v2_152.ckpt
[[Node: save/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_1/tensor_names, save/RestoreV2_1/shape_and_slices)]]
Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1327, in _do_call
return fn(*args)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1306, in _run_fn
status, run_metadata)
File "/usr/lib/python3.5/contextlib.py", line 66, in __exit__
next(self.gen)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.NotFoundError: Tensor name "resnet_v2_152/block1/unit_1/bottleneck_v2/conv1/biases" not found in checkpoint files /root/data/checkpoints/resnet_v2_152_2017_04_14/resnet_v2_152.ckpt
[[Node: save/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_1/tensor_names, save/RestoreV2_1/shape_and_slices)]]
[[Node: save/RestoreV2_242/_309 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_1240_save/RestoreV2_242", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
I think I could solve this by initializing conv1/biases to 0, but the TensorFlow Estimator API does not seem to provide a way to do that. How can I fix this?
Upvotes: 3
Views: 902
Reputation: 923
I think you want to load the pre-trained weights rather than just initialize the ResNet variables. Consider using a tf.train.Scaffold object.
The model function should look like this:
def cnn_model_fn(features, labels, mode):
    """Estimator model function that builds ResNet-v2-152 under its arg scope.

    The network is created inside ``resnet_v2.resnet_arg_scope()`` and a
    ``tf.train.Scaffold`` is attached whose ``init_fn`` assigns variables from
    the checkpoint file.

    Args:
        features: Input batch, forwarded to ``resnet_v2_152``.
        labels: Not referenced in the body.
        mode: A ``tf.estimator.ModeKeys`` value; only ``PREDICT`` is supported.

    Returns:
        A ``tf.estimator.EstimatorSpec`` with ``predictions={'logits': ...}``
        and the scaffold described above.

    Raises:
        NotImplementedError: If ``mode`` is anything other than ``PREDICT``.
    """
    # Fail fast, before building the graph or opening the checkpoint.
    if mode != tf.estimator.ModeKeys.PREDICT:
        raise NotImplementedError('only support predict!')

    # ``slim`` is not bound by the imports shown with this code, so take it
    # from the already-imported ``tf`` package.
    slim = tf.contrib.slim

    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
        logits, _end_points = resnet_v2.resnet_v2_152(
            features,
            is_training=mode == tf.estimator.ModeKeys.TRAIN)

    checkpoint_file = 'resnet_v2_152.ckpt'
    # NOTE(review): tf.global_variables() may contain variables that are not
    # stored in the checkpoint - confirm, or narrow the list if it does.
    init_fn = slim.assign_from_checkpoint_fn(
        checkpoint_file,
        tf.global_variables())
    saver = tf.train.Saver(max_to_keep=10)
    scaffold = tf.train.Scaffold(init_fn=init_fn, saver=saver)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions={'logits': logits},
                                      scaffold=scaffold)
Upvotes: 1