Nagabhushan S N

Reputation: 7267

Altering the tensorflow graph and resuming training

I'm trying to load the pretrained weights of the MCnet model and resume training. The pretrained model provided here was trained with parameters K=4, T=7, but I want a model with K=4, T=1. Instead of training from scratch, I want to initialize from this pretrained model. However, since the graph has changed, restoring the checkpoint fails:

InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Assign requires shapes of both tensors to match. lhs shape= [5,5,15,64] rhs shape= [5,5,33,64]
     [[node save/Assign_13 (defined at /media/nagabhushan/Data02/SNB/IISc/Research/04_Gaming_Video_Prediction/Workspace/VideoPrediction/Literature/01_MCnet/src/snb/mcnet.py:108) ]]

Is it possible to load the pretrained model with the new graph?

What I have tried:
Previously, I wanted to port the pretrained model from an older version of TensorFlow to a newer one. An answer I got on SO helped me port the model. The idea is to build the new graph and then load, from the saved checkpoint, only those variables that exist in the new graph:

import tensorflow as tf

with tf.Session() as sess:
    # Build the new graph with the desired parameters
    _ = MCNET(image_size=[240, 320], batch_size=8, K=4, T=1, c_dim=3, checkpoint_dir=None, is_train=True)
    tf.global_variables_initializer().run(session=sess)

    # List all variables stored in the checkpoint
    ckpt_vars = tf.train.list_variables(model_path.as_posix())
    ass_ops = []
    for dst_var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        for (ckpt_var, ckpt_shape) in ckpt_vars:
            # Copy only variables whose name and shape both match the checkpoint
            if dst_var.name.split(":")[0] == ckpt_var and dst_var.shape == ckpt_shape:
                value = tf.train.load_variable(model_path.as_posix(), ckpt_var)
                ass_ops.append(tf.assign(dst_var, value))

    # Assign the variables and save the new checkpoint
    sess.run(ass_ops)
    saver = tf.train.Saver()
    saver.save(sess, save_path.as_posix())

I tried the same approach here and it worked, in the sense that I got a new model for K=4, T=1. But I'm not sure whether it is valid: will the copied weights still make sense? Is this the right way to do it?

Info about the Model:
MCnet is a model used for video prediction i.e. given K past frames, it can predict the next T frames.

Any help is appreciated.

Upvotes: 3

Views: 125

Answers (1)

learner

Reputation: 3472

The MCnet model has a generator and a discriminator. The generator is LSTM-based, so there is no problem loading its weights when the number of timesteps T changes. The discriminator, however, as they've coded it, is convolutional. To apply convolutional layers to a video, they concatenate the frames along the channel dimension.

With K=4, T=7 you get a video of length 11 with 3 channels per frame; concatenated along the channel dimension, that is an image with 33 channels. The first layer of the discriminator is therefore defined with 33 input channels, and its weights have that dimension. But with K=4, T=1 the video length is 5, the concatenated image has 15 channels, and so the weights need 15 input channels. That is the shape mismatch you're observing.

To fix this, you can copy the weights of the first 15 channels only (for lack of a better way I can think of). Code below:

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    # Build the new graph with the desired parameters
    _ = MCNET(image_size=[240, 320], batch_size=8, K=4, T=1, c_dim=3, checkpoint_dir=None, is_train=True)
    tf.global_variables_initializer().run(session=sess)

    ckpt_vars = tf.train.list_variables(model_path.as_posix())
    ass_ops = []
    for dst_var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        for (ckpt_var, ckpt_shape) in ckpt_vars:
            if dst_var.name.split(":")[0] == ckpt_var:
                value = tf.train.load_variable(model_path.as_posix(), ckpt_var)
                if dst_var.shape == ckpt_shape:
                    # Shapes match: copy the weights as they are
                    ass_ops.append(tf.assign(dst_var, value))
                else:
                    dst_shape = dst_var.shape.as_list()
                    if dst_shape[2] <= value.shape[2]:
                        # New graph needs fewer input channels: keep the first ones
                        adjusted_value = value[:, :, :dst_shape[2]]
                    else:
                        # New graph needs more input channels: pad with random values
                        adjusted_value = np.random.random(dst_shape).astype(value.dtype)
                        adjusted_value[:, :, :value.shape[2], ...] = value
                    ass_ops.append(tf.assign(dst_var, adjusted_value))

    # Assign the variables and save the new checkpoint
    sess.run(ass_ops)
    saver = tf.train.Saver()
    saver.save(sess, save_path.as_posix())
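To see why the kernel shapes come out as [5,5,33,64] and [5,5,15,64], here is a small NumPy sketch of the channel arithmetic described above (the frame size and kernel shapes are taken from the error message; this is an illustration, not the actual MCnet code):

    import numpy as np

    # K past frames + T future frames, each an RGB image
    K, T, c_dim = 4, 7, 3
    frames = [np.zeros((240, 320, c_dim)) for _ in range(K + T)]   # 11 frames
    video_image = np.concatenate(frames, axis=-1)                  # 11 * 3 = 33 channels
    print(video_image.shape)  # (240, 320, 33)

    # Slicing a [5, 5, 33, 64] conv kernel down to its first 15 input channels
    # yields the [5, 5, 15, 64] shape the K=4, T=1 graph expects
    old_kernel = np.random.random((5, 5, 33, 64))
    new_kernel = old_kernel[:, :, :15]
    print(new_kernel.shape)   # (5, 5, 15, 64)

With K=4, T=1 the same concatenation gives (K + T) * c_dim = 5 * 3 = 15 channels, which is exactly the lhs shape in your error.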

Upvotes: 4
