Kid_Learning_C

Reputation: 3603

In tensorflow estimator class, what does it mean to train one step?

Specifically, within one step, how does it train the model? What is the quitting condition for the gradient descent and backpropagation?

Docs here: https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#train

e.g.

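  import tensorflow as tf  # TF 1.x-era API (tf.estimator.inputs)

  # cnn_model_fn, X_train, y_train and logging_hook are defined elsewhere.
  mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn)

  train_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={"x": X_train},
      y=y_train,
      batch_size=50,    # 50 examples per batch
      num_epochs=None,  # cycle through the data indefinitely
      shuffle=True)

  mnist_classifier.train(
      input_fn=train_input_fn,
      steps=100,        # number of training steps to run
      hooks=[logging_hook])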

I understand that training one step means that we feed the neural network model batch_size data points once. My question is: within this one step, how many times does it perform gradient descent? Does it do backpropagation and gradient descent just once, or does it keep performing gradient descent until the model weights reach an optimum for this batch of data?

Upvotes: 2

Views: 716

Answers (3)

Dan D.

Reputation: 8567

The input function emits batches (when num_epochs=None, num_batches is infinite):

num_batches = num_epochs * (num_samples / batch_size)

One step processes one batch. If steps > num_batches, training stops after num_batches steps, because the input function runs out of data.
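To make the step/batch arithmetic concrete, here is a small plain-Python sketch. It assumes the final, partial batch of the data is still emitted (hence the rounding up), which the formula above leaves open, and it uses a hypothetical 1000-sample dataset. The helper names are illustrative, not part of the TensorFlow API.

```python
import math


def num_batches(num_samples, batch_size, num_epochs):
    """Number of batches an input function yields (infinite when num_epochs is None)."""
    if num_epochs is None:
        return math.inf
    # Assumption: a final partial batch is still emitted, so round up.
    return math.ceil(num_epochs * num_samples / batch_size)


def steps_actually_run(steps, num_samples, batch_size, num_epochs):
    """Training stops after `steps` or when the input runs out, whichever comes first."""
    return min(steps, num_batches(num_samples, batch_size, num_epochs))


# Hypothetical 1000 samples, batches of 50, num_epochs=None: all 100 steps run.
print(steps_actually_run(100, 1000, 50, None))  # 100
# With num_epochs=2 the input only has 40 batches, so training stops after 40 steps.
print(steps_actually_run(100, 1000, 50, 2))  # 40
```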

Upvotes: 0

DocDriven

Reputation: 3974

In addition to @David Parks's answer: using batches to perform gradient descent is referred to as (mini-batch) stochastic gradient descent. Instead of updating the weights after each training sample, you average the gradients over the batch and use this averaged gradient to update your weights.

For example, if you have 1000 training samples and use batches of 200, you calculate the average gradient over 200 samples and update your weights with it. That means one pass over the data (one epoch) takes only 5 weight updates instead of 1000. On sufficiently large datasets, this usually makes training much faster.
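A minimal NumPy sketch of mini-batch gradient descent, assuming a hypothetical noise-free least-squares linear regression problem (the model, learning rate and helper name are illustrative, not TensorFlow internals). It mirrors the 1000-sample, batch-of-200 example: one averaged-gradient update per batch, so 5 updates per epoch.

```python
import numpy as np


def minibatch_sgd_epoch(w, X, y, batch_size, lr, rng):
    """One epoch of mini-batch gradient descent on mean squared error.

    Returns the updated weights and the number of weight updates performed.
    """
    order = rng.permutation(X.shape[0])  # reshuffle the samples every epoch
    updates = 0
    for start in range(0, X.shape[0], batch_size):
        idx = order[start:start + batch_size]
        Xb, yb = X[idx], y[idx]
        residual = Xb @ w - yb                   # forward pass on this batch only
        grad = 2.0 * Xb.T @ residual / len(idx)  # gradient averaged over the batch
        w = w - lr * grad                        # exactly one update per batch
        updates += 1
    return w, updates


rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = X @ true_w  # hypothetical data: noise-free linear targets

w, updates = minibatch_sgd_epoch(np.zeros(3), X, y, batch_size=200, lr=0.1, rng=rng)
print(updates)  # 5 updates for 1000 samples in batches of 200

for _ in range(49):
    w, _ = minibatch_sgd_epoch(w, X, y, batch_size=200, lr=0.1, rng=rng)
print(np.round(w, 3))
```

With a batch size equal to the dataset size this becomes full-batch gradient descent (1 update per epoch); with a batch size of 1 it becomes per-sample SGD (1000 updates per epoch).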

Michael Nielsen explains this concept very nicely in his book Neural Networks and Deep Learning.

Upvotes: 3

David Parks

Reputation: 32071

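1 step = 1 gradient update, and each gradient update requires one forward pass and one backward pass over the batch. So within one step, backpropagation and gradient descent happen exactly once; the weights are not optimized to convergence on that batch.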
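To illustrate that a step is a single forward pass, a single backward pass and a single update, here is a plain-Python sketch for a hypothetical one-feature linear model y = w*x + b with mean squared error. The helper names and the `global_step` counter are illustrative stand-ins, not TensorFlow's implementation. After one step the loss on the batch drops, but it is not driven to its minimum.

```python
def train_step(params, batch, lr):
    """One training step: one forward pass, one backward pass, one update.

    params: {"w": float, "b": float} for the hypothetical model y = w * x + b.
    batch: list of (x, y) pairs.
    Returns the new parameters and the loss computed in the forward pass.
    """
    w, b = params["w"], params["b"]
    n = len(batch)
    # Forward pass: predictions and mean squared error on this batch.
    preds = [w * x + b for x, _ in batch]
    loss = sum((p - y) ** 2 for p, (_, y) in zip(preds, batch)) / n
    # Backward pass: gradients of the loss with respect to w and b (chain rule).
    dw = sum(2 * (p - y) * x for p, (x, y) in zip(preds, batch)) / n
    db = sum(2 * (p - y) for p, (_, y) in zip(preds, batch)) / n
    # A single gradient-descent update; there is no inner loop to convergence.
    return {"w": w - lr * dw, "b": b - lr * db}, loss


def batch_loss(params, batch):
    """Mean squared error of the model on a batch (forward pass only)."""
    return sum((params["w"] * x + params["b"] - y) ** 2 for x, y in batch) / len(batch)


batch = [(1.0, 3.0), (2.0, 5.0), (3.0, 7.0)]  # points on y = 2x + 1
params = {"w": 0.0, "b": 0.0}
global_step = 0  # illustrative counter: incremented once per update

params, loss_before = train_step(params, batch, lr=0.05)
global_step += 1
loss_after = batch_loss(params, batch)
print(global_step, round(loss_before, 3), round(loss_after, 3))
```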

The stopping condition is generally left up to you and is arguably more art than science. Commonly you will plot (TensorBoard is handy here) your cost, your training accuracy, and, periodically, your validation accuracy. The low point of the validation loss (roughly, the peak of validation accuracy) is generally a good point to stop. Depending on your dataset, validation loss may drop and at some point rise again, or it may simply flatten out, at which point the stopping condition often correlates with the developer's degree of impatience.

Here's a good discussion of stopping conditions; a Google search will turn up plenty more.

https://stats.stackexchange.com/questions/231061/how-to-use-early-stopping-properly-for-training-deep-neural-network
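A minimal sketch of patience-based early stopping, assuming you record the validation loss at each periodic evaluation (lower is better). The function name and the `patience` parameter are illustrative; frameworks provide their own hooks or callbacks for this.

```python
def early_stopping_point(val_losses, patience):
    """Return (best_eval, stop_eval) for patience-based early stopping.

    val_losses: validation loss measured at each periodic evaluation.
    patience: evaluations without improvement tolerated before stopping
              (an illustrative knob, not a value from the answer).
    """
    best_loss = float("inf")
    best_idx = None
    since_best = 0
    for i, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_idx, since_best = loss, i, 0
        else:
            since_best += 1
            if since_best >= patience:
                return best_idx, i  # stop here; keep the weights from best_idx
    return best_idx, len(val_losses) - 1  # never triggered: ran out of evaluations


# Hypothetical curve: validation loss falls, bottoms out, then rises again.
val_losses = [0.9, 0.7, 0.55, 0.5, 0.52, 0.53, 0.56, 0.6]
print(early_stopping_point(val_losses, patience=3))  # (3, 6)
```

In practice you would save a checkpoint at each evaluation and restore the one at `best_eval`.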

Another common approach to stopping is to drop the learning rate every time validation accuracy has not improved for some "reasonable" number of steps. Once the learning rate is effectively 0, you call it quits.
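The learning-rate variant can be sketched the same way. Assumptions (not from the answer): validation accuracy is checked at fixed intervals, the rate is multiplied by a hypothetical `factor` after `patience` evaluations without improvement, and "effectively 0" means falling below a hypothetical `min_lr`.

```python
def plateau_schedule(val_accuracies, lr, factor=0.1, patience=2, min_lr=5e-4):
    """Decay the learning rate when validation accuracy stops improving.

    factor, patience and min_lr are illustrative defaults, not recommended values.
    Returns the learning rate in effect after each evaluation and the index of
    the evaluation at which training stops (None if it never stops).
    """
    best = float("-inf")
    stale = 0  # evaluations since the last improvement
    lrs = []
    for i, acc in enumerate(val_accuracies):
        if acc > best:
            best, stale = acc, 0
        else:
            stale += 1
            if stale >= patience:
                lr *= factor  # plateau: drop the learning rate
                stale = 0
        lrs.append(lr)
        if lr < min_lr:
            return lrs, i  # learning rate is effectively 0: call it quits
    return lrs, None


# Hypothetical validation accuracies that plateau twice.
accs = [0.80, 0.85, 0.85, 0.85, 0.86, 0.86, 0.86, 0.86, 0.86]
lrs, stop = plateau_schedule(accs, lr=0.1)
print(stop, lrs)
```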

Upvotes: 2
