Stefan

Reputation: 444

TensorFlow 'trainable=False' is ignored

I want to pass a fixed weight matrix to a 2D transposed convolution (tf.nn.conv2d_transpose) in TensorFlow. I tried setting trainable=False as follows, but TF seems to ignore the option:

w = tf.Variable(w, trainable=False, dtype=tf.float32, name='upscaleW')
data = tf.nn.conv2d_transpose(data, w, outshapeF, strides, padding="SAME", data_format=data_format, name='UpsamplingDeconv2D')

The weights keep drifting during training: the 1's become 0.98, then 0.96, and so on, and the 0's become 0.012, etc.

When I call tf.trainable_variables(), upscaleW is not in the list; I can only find it in tf.global_variables(). So it is not even registered as a trainable variable, yet it still changes. I can't figure out how to freeze the weights.

Could this line be at fault? https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/python/ops/nn_ops.py#L1075

Upvotes: 2

Views: 1754

Answers (2)

prosti

Reputation: 46409

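Setting trainable=False keeps the variable out of the GraphKeys.TRAINABLE_VARIABLES collection in the graph. Since an optimizer's minimize() defaults its var_list to that collection, the variable won't be updated during back-propagation.

It should work; I'm not aware of any recent bugs affecting this.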
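The collection bookkeeping can be sketched as a toy, pure-Python model. This is NOT TensorFlow's implementation: the `Graph` class and its `variable` and `resolve_var_list` methods are hypothetical. It only assumes, as the TF1 `Optimizer.minimize` documentation states, that an omitted var_list falls back to the TRAINABLE_VARIABLES collection. It also reproduces what the question observed: upscaleW shows up among the global variables but not the trainable ones.

```python
# Toy model (NOT TensorFlow's implementation) of TF1 graph collections:
# every variable goes into GLOBAL_VARIABLES, but only trainable ones are
# also added to TRAINABLE_VARIABLES.
from collections import defaultdict

GLOBAL_VARIABLES = "variables"
TRAINABLE_VARIABLES = "trainable_variables"


class Graph:
    """Hypothetical stand-in for a graph that tracks named collections."""

    def __init__(self):
        self.collections = defaultdict(list)

    def variable(self, name, trainable=True):
        # Every variable is global; trainable=False only skips the
        # TRAINABLE_VARIABLES collection.
        self.collections[GLOBAL_VARIABLES].append(name)
        if trainable:
            self.collections[TRAINABLE_VARIABLES].append(name)
        return name

    def resolve_var_list(self, var_list=None):
        # Assumed to mirror minimize(): with no explicit var_list, fall
        # back to the TRAINABLE_VARIABLES collection.
        if var_list is None:
            return list(self.collections[TRAINABLE_VARIABLES])
        return list(var_list)


g = Graph()
g.variable("upscaleW", trainable=False)
g.variable("kernel")
print(g.collections[GLOBAL_VARIABLES])     # ['upscaleW', 'kernel']
print(g.collections[TRAINABLE_VARIABLES])  # ['kernel']
print(g.resolve_var_list())                # ['kernel']
```

Note that the flag only affects collection membership. Nothing stops a caller from passing the frozen variable explicitly.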

Upvotes: 1

Stefan

Reputation: 444

Never mind, my mistake. In my code I was calling minimize(var_list=tf.contrib.framework.get_variables()) instead of using get_trainable_variables(). An explicit var_list tells the optimizer exactly which variables to update, so it overrides the trainable=False argument.
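The effect of the two var_list choices can be illustrated with a toy, pure-Python gradient-descent loop. This is NOT TensorFlow code. `Var`, `sgd_step` and `run` are hypothetical helpers, and the constant gradient of 0.2 with learning rate 0.1 is an arbitrary assumption chosen so that the frozen weight drifts 1.0 → 0.98 → 0.96 → 0.94, much like the symptom in the question.

```python
# Toy model (NOT TensorFlow's API): each variable carries a `trainable`
# flag, and an SGD step updates exactly the variables it is handed.
from dataclasses import dataclass


@dataclass
class Var:
    name: str
    value: float
    trainable: bool = True


def sgd_step(var_list, grads, lr=0.1):
    """One gradient-descent step on every variable in var_list.

    Like an explicit var_list passed to minimize(), this ignores the
    `trainable` flag: whatever is in the list gets updated.
    """
    for v in var_list:
        v.value -= lr * grads[v.name]


def run(select, steps=3):
    upscale_w = Var("upscaleW", 1.0, trainable=False)  # meant to stay fixed
    kernel = Var("kernel", 1.0)                        # ordinary weight
    all_vars = [upscale_w, kernel]
    grads = {"upscaleW": 0.2, "kernel": 0.2}           # assumed constant grads
    for _ in range(steps):
        sgd_step(select(all_vars), grads)
    return upscale_w.value, kernel.value


# Bug: hand the optimizer every variable (analogous to get_variables()).
buggy = run(lambda vs: vs)
# Fix: hand it only the trainable ones (analogous to get_trainable_variables()).
fixed = run(lambda vs: [v for v in vs if v.trainable])
print(buggy)  # upscaleW has drifted to about 0.94
print(fixed)  # upscaleW is still exactly 1.0
```

In real TF1 code, the equivalent fix is to omit var_list (so the TRAINABLE_VARIABLES default applies) or to pass only trainable variables.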

Upvotes: 2
