Reputation: 1421
I have a quick question. I am developing a model in TensorFlow and need to use the iteration number in a formula during the graph construction phase. I know how to use global_step, but I am not using a built-in optimizer; I am calculating my own gradients with
grad_W, grad_b = tf.gradients(xs=[W, b], ys=cost)
grad_W = grad_W + rnd.normal(0, 1.0 / (1 + epoch)**0.55)
and then using
new_W = W.assign(W - learning_rate * (grad_W))
new_b = b.assign(b - learning_rate * (grad_b))
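To make the intended update concrete, here is a framework-free sketch (plain Python, not TensorFlow) of one noisy gradient-descent step for a single scalar weight. It assumes the noise is zero-mean Gaussian with standard deviation 1 / (1 + epoch)^0.55, as in the formula above; the quadratic cost and the names noise_std and noisy_sgd_step are hypothetical and only for illustration.

```python
import random

def noise_std(epoch):
    # Annealed standard deviation from the formula above: 1 / (1 + epoch)^0.55
    return 1.0 / (1 + epoch) ** 0.55

def noisy_sgd_step(w, grad, learning_rate, epoch, rng):
    # Perturb the gradient with zero-mean Gaussian noise, then take a plain descent step
    noisy_grad = grad + rng.gauss(0.0, noise_std(epoch))
    return w - learning_rate * noisy_grad

rng = random.Random(0)  # seeded so the run is reproducible
w = 1.0
for epoch in range(5):
    grad = 2.0 * w  # hypothetical gradient of cost(w) = w**2
    w = noisy_sgd_step(w, grad, 0.1, epoch, rng)
```

In plain Python the epoch is simply passed in on every call. In a TensorFlow 1.x graph the update ops are built once, so the epoch has to live in the graph as a tensor for the noise level to change between runs.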
and I would like to use the epoch value in that formula before updating my weights. What is the best way to do this? I have a sess.run() call and would like to pass the epoch number to the model, but I cannot use a tensor directly in the formula above. In my run call
_, _, cost_ = sess.run([new_W, new_b, cost],
                       feed_dict={X_: X_train_tr, Y: labels_, learning_rate: learning_r})
I would like to pass the epoch number as well. How do you usually do this?
Thanks in advance, Umberto
EDIT:
Thanks for the hints. This seems to work:
grad_W = grad_W + tf.random_normal(grad_W.shape, 0.0,
                                   1.0 / tf.pow(0.01 + tf.cast(epochv, tf.float32), 0.55))
but I still have to check whether it is what I need and whether it works as intended. Ideas and feedback would be great!
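One thing worth checking in the edited version: it uses an offset of 0.01 instead of 1, which changes the noise level considerably at the start. A quick plain-Python comparison of the two schedules (noise_std is a hypothetical helper, not a TensorFlow function):

```python
def noise_std(epoch, offset):
    # Standard deviation of the gradient noise: 1 / (offset + epoch)^0.55
    return 1.0 / (offset + epoch) ** 0.55

original = [noise_std(e, 1.0) for e in range(4)]  # offset used in the original formula
edited = [noise_std(e, 0.01) for e in range(4)]   # offset used in the EDIT
```

At epoch 0 the original offset gives a standard deviation of exactly 1.0, while the 0.01 offset gives 0.01^-0.55 = 10^1.1, roughly 12.6. After that the edited schedule behaves roughly like the original one shifted by one epoch. Whether that larger initial noise is intended depends on your model.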
Upvotes: 0
Views: 308
Reputation: 3476
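You can define epoch as a non-trainable tf.Variable in your graph and increment it at the end of each epoch: define an operation with tf.assign_add that does the increment, and run it at the end of each epoch.
In that case you will also need to use tf.random_normal instead of rnd.normal. Assuming rnd is numpy.random, rnd.normal is evaluated only once, when the graph is built, whereas tf.random_normal is an op inside the graph that draws new noise every time the update is run and can read the current value of epoch.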
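As a framework-free analogy (plain Python, not TensorFlow code), the sketch below contrasts a value computed once at "graph construction" time with one that reads a counter each time it is evaluated. Counter and build_noise_std are hypothetical names standing in for a non-trainable tf.Variable and a graph op, respectively.

```python
class Counter:
    """Stand-in for a non-trainable tf.Variable holding the epoch."""

    def __init__(self, value=0):
        self.value = value

    def assign_add(self, delta):
        # Analogous to running the op created by tf.assign_add
        self.value += delta
        return self.value


def build_noise_std(epoch_counter):
    # Returns a callable that is re-evaluated on every "run", like a graph op
    def noise_std():
        return 1.0 / (1 + epoch_counter.value) ** 0.55
    return noise_std


epoch = Counter()
noise_std = build_noise_std(epoch)
# Computed once, like calling rnd.normal while building the graph
frozen_std = 1.0 / (1 + epoch.value) ** 0.55

history = []
for _ in range(3):
    history.append(noise_std())  # reads the current epoch each time
    epoch.assign_add(1)          # end of epoch: increment the counter
```

After the loop, history holds a decreasing sequence, while frozen_std is still 1.0, which is the behaviour the question ran into.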
Example:
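import tensorflow as tf  # TensorFlow 1.x API

epoch = tf.Variable(0, trainable=False)  # 0 is the initial value
# Op that adds 1 to 'epoch' each time it is run
epoch_incr_op = tf.assign_add(epoch, 1, name='incr_epoch')

# Define any operations that depend on 'epoch'.
# Note we need to cast the integer 'epoch' to float to use it in tf.pow
grad_W = grad_W + tf.random_normal(tf.shape(grad_W), 0.0,
                                   1.0 / tf.pow(1 + tf.cast(epoch, tf.float32), 0.55))
# Build the update ops after the noisy gradient so that they use it
new_W = W.assign(W - learning_rate * grad_W)
new_b = b.assign(b - learning_rate * grad_b)

# Run the initializer after all variables (including 'epoch') are defined
sess.run(tf.global_variables_initializer())

# Training loop; num_epochs stands in for your own stopping condition
for e in range(num_epochs):
    _, _, cost_ = sess.run([new_W, new_b, cost],
                           feed_dict={X_: X_train_tr, Y: labels_, learning_rate: learning_r})
    # At the end of each epoch, increment the epoch counter
    sess.run(epoch_incr_op)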
Upvotes: 1