FightGravity

Reputation: 89

Tensorflow - Training on condition

I am training a neural network with TensorFlow (1.12) in a supervised fashion. I'd like to train only on specific examples. The examples are created on the fly by cutting out subsequences, hence I want to do the conditioning within TensorFlow.

This is my original part of code:

train_step, gvs = minimize_clipped(optimizer, loss,
                                   clip_value=FLAGS.gradient_clip,
                                   return_gvs=True)
gradients = [g for (g,v) in gvs]
gradient_norm = tf.global_norm(gradients)
tf.summary.scalar('gradients/norm', gradient_norm)
eval_losses = {'loss1': loss1,
               'loss2': loss2}

The training step is later executed as:

batch_eval, _ = sess.run([eval_losses, train_step])

I was thinking about inserting something like

train_step_fake = ????
eval_losses_fake = tf.zeros_like(tensor)
train_step_new = tf.cond(my_cond, train_step, train_step_fake)
eval_losses_new = tf.cond(my_cond, eval_losses, eval_losses_fake)

and then doing

batch_eval, _ = sess.run([eval_losses_new, train_step_new])

However, I am not sure how to create a fake train_step.

Also, is this a good idea in general, or is there a smoother way of doing this? I am using a TFRecords pipeline, but no other high-level modules (like Keras, tf.estimator, eager execution, etc.).

Any help is obviously greatly appreciated!

Upvotes: 1

Views: 421

Answers (1)

Stewart_R

Reputation: 14485

Answering the specific question first. It's certainly possible to perform your training step only based on the tf.cond outcome. Note, though, that the 2nd and 3rd parameters of tf.cond are callables (typically lambdas), so it's more something like:

train_step_new = tf.cond(my_cond, lambda: train_step, lambda: train_step_fake)
eval_losses_new = tf.cond(my_cond, lambda: eval_losses, lambda: eval_losses_fake)

Your instinct that this may not be the right thing to do is correct though.

It's much more preferable (both in terms of efficiency and in terms of reading and reasoning about your code) to filter out the data you want to ignore before it gets to your model in the first place.

This is something you could achieve using the Dataset API, which has a really useful filter() method. If you are already using the Dataset API to read your TFRecords, this should be as simple as adding something along the lines of:

dataset = dataset.filter(lambda x: {whatever op you were going to use in tf.cond})

If you are not yet using the Dataset API, now is probably the time to read up on it and consider adopting it, rather than butchering the model with a tf.cond() acting as a filter.

Upvotes: 1
