Swapnil

Reputation: 191

How many iterations are needed to train TensorFlow on the entire MNIST data set (60,000 images)?

The MNIST set consists of 60,000 training images. While training my TensorFlow model, I want to run the train step enough times to cover the entire training set. The deep learning example on the TensorFlow website uses 20,000 iterations with a batch size of 50 (totaling 1,000,000 training examples). When I try more than 30,000 iterations, my number predictions fail (the model predicts 0 for all handwritten digits). My question is: how many iterations should I use with a batch size of 50 to train the TensorFlow model on the entire MNIST set?

from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST data set with one-hot encoded labels.
self.mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
for i in range(FLAGS.training_steps):
    # Draw the next mini-batch of 50 (image, label) pairs.
    batch = self.mnist.train.next_batch(50)
    self.train_step.run(feed_dict={self.x: batch[0], self.y_: batch[1], self.keep_prob: 0.5})
    # Checkpoint the model every 1,000 steps.
    if (i + 1) % 1000 == 0:
        saver.save(self.sess, FLAGS.checkpoint_dir + 'model.ckpt', global_step=i)
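
For reference, the arithmetic relating iterations, batch size, and full passes (epochs) over the training set:

# Iterations needed for one full pass (epoch) over the 60,000-image set.
train_size = 60000
batch_size = 50
iters_per_epoch = train_size // batch_size  # 1,200 iterations per epoch
print(20000 / iters_per_epoch)              # the tutorial's 20,000 steps ~= 16.7 epochs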

Upvotes: 6

Views: 6922

Answers (4)

tdMJN6B2JtUe

Reputation: 428

I have found that with MNIST, training on 3,833 images per epoch (validating on the remaining 56,167, because 60,000**0.75 is just over 3,833) tends to converge well before 500 epochs. By "converge," I mean that validation loss does not decrease for 50 consecutive epochs of training with batch size 16. See this repo for an example of using early stopping with tf.keras; it mattered a lot to me in this case because I was doing model search and did not have time to train a single model for very long.
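
A minimal sketch of that setup with tf.keras; the two-layer model here is illustrative (the answer does not specify an architecture), while the batch size of 16, the patience of 50 on validation loss, and the 3,833/56,167 split follow the description above:

import tensorflow as tf

# Load MNIST and flatten to 784-dimensional float vectors in [0, 1].
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0

# Illustrative two-layer classifier; any tf.keras model works here.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Stop once validation loss has not improved for 50 consecutive epochs.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=50)

# Keras takes the *last* fraction of the data for validation, so this
# leaves ~3,833 images for training and ~56,167 for validation.
model.fit(x_train, y_train,
          batch_size=16,
          epochs=500,
          validation_split=56167 / 60000,
          callbacks=[early_stop])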

Upvotes: 0

LITDataScience

Reputation: 392

You can use something like a no_improve_epoch counter and set it to, say, 3. That simply means that if there is no improvement of more than 1% over 3 consecutive epochs, training stops.

best_score = 0
nepoch_no_imprv = 0  # epochs since the last improvement

with tf.Session() as sess:
    sess.run(cls.init)
    if cls.config.reload == 'True':
        cls.logger.info("Reloading the latest trained model...")
        saver.restore(sess, cls.config.model_output)
    cls.add_summary(sess)
    for epoch in range(cls.config.nepochs):
        cls.logger.info("Epoch {} out of {}".format(epoch + 1, cls.config.nepochs))
        dev = train
        acc, f1 = cls.run_epoch(sess, train, dev, tags, epoch)

        cls.config.lr *= cls.config.lr_decay  # decay the learning rate each epoch

        if f1 >= best_score:
            nepoch_no_imprv = 0
            if not os.path.exists(cls.config.model_output):
                os.makedirs(cls.config.model_output)
            saver.save(sess, cls.config.model_output)
            best_score = f1
            cls.logger.info("- new best score!")
        else:
            nepoch_no_imprv += 1
            if nepoch_no_imprv >= cls.config.nepoch_no_imprv:
                cls.logger.info("- early stopping: {} epochs without improvement".format(
                    nepoch_no_imprv))
                break
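
Note that the snippet above resets the counter on any f1 >= best_score; a hypothetical helper (not part of the original code) for the ">1% improvement" check described above might look like:

def improved(score, best_score, min_rel_delta=0.01):
    # Count as an improvement only if the new score beats the best
    # so far by more than 1% (relative).
    return score > best_score * (1.0 + min_rel_delta)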

Sequence Tagging (GitHub)

Upvotes: 0

Burak

Reputation: 61

With machine learning you tend to run into serious diminishing returns. For example, here is the test set accuracy from one of my CNNs:

Epoch 0 current test set accuracy :  0.5399
Epoch 1 current test set accuracy :  0.7298
Epoch 2 current test set accuracy :  0.7987
Epoch 3 current test set accuracy :  0.8331
Epoch 4 current test set accuracy :  0.8544
Epoch 5 current test set accuracy :  0.8711
Epoch 6 current test set accuracy :  0.888
Epoch 7 current test set accuracy :  0.8969
Epoch 8 current test set accuracy :  0.9064
Epoch 9 current test set accuracy :  0.9148
Epoch 10 current test set accuracy :  0.9203
Epoch 11 current test set accuracy :  0.9233
Epoch 12 current test set accuracy :  0.929
Epoch 13 current test set accuracy :  0.9334
Epoch 14 current test set accuracy :  0.9358
Epoch 15 current test set accuracy :  0.9395
Epoch 16 current test set accuracy :  0.942
Epoch 17 current test set accuracy :  0.9436
Epoch 18 current test set accuracy :  0.9458

As you can see, the returns start to fall off after ~10 epochs*, though this may vary with your network and learning rate. How many epochs are worth running depends on how critical the result is and how much time you have, but I have found 20 to be a reasonable number.

*I have always used the word "epoch" to mean one entire pass through a data set, though I am not sure how accurate that definition is; each epoch here is ~429 training steps with batches of size 128.
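
For what it's worth, that step count is consistent with a ~55,000-image training split (the old TensorFlow tutorials held out 5,000 of the 60,000 images for validation); this split is an assumption, not something stated above:

# One epoch = one full pass over the training split.
train_size = 55000   # assumed split: 60,000 images minus 5,000 for validation
batch_size = 128
print(train_size // batch_size)  # 429 steps per epoch, matching the footnote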

Upvotes: 4

Yao Zhang

Reputation: 5781

I think that depends on your stopping criterion. You can stop training when the loss no longer improves, or you can hold out a validation data set and stop training when validation accuracy no longer improves (a minimal sketch of the latter follows).
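
A minimal, framework-agnostic sketch of that validation-based criterion; train_one_epoch and validation_accuracy are hypothetical placeholders standing in for your own training and evaluation code:

def train_one_epoch():
    pass  # run one pass over the training set

def validation_accuracy():
    return 0.0  # evaluate accuracy on the held-out validation set

best_acc, bad_epochs, patience = 0.0, 0, 5
for epoch in range(1000):
    train_one_epoch()
    acc = validation_accuracy()
    if acc > best_acc:
        best_acc, bad_epochs = acc, 0  # improvement: reset the counter
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break  # validation accuracy has stopped improving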

Upvotes: 2
