2D_

Reputation: 601

FLAGS = None meaning?

I am new to Python and TensorFlow, and I am following this MNIST tutorial in the TensorFlow documentation.

In the first part, I don't understand what FLAGS = None does here. I searched Google and came up empty. Maybe this is too obvious to others?

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

FLAGS = None


def main(_):
  # Import data
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

So what is FLAGS, and how is it used (e.g., FLAGS.data_dir)?

Any help would be appreciated!

Upvotes: 6

Views: 5444

Answers (2)

hpaulj

Reputation: 231415

Setting FLAGS = None just initializes the global variable FLAGS. If it is left as is, main will raise an AttributeError, because None has no data_dir attribute.

But if it is set via an argparse parser, as shown in the fuller examples, it is a simple object (an argparse Namespace) with a variety of attributes. main assumes that one of those attributes is called data_dir.
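A minimal, TensorFlow-free sketch of both cases. The helper `read_data_dir` is hypothetical and stands in for main; the parser mirrors the tutorial's `--data_dir` option, and '/tmp/mnist_example' is only an example path.

```python
import argparse

FLAGS = None  # global placeholder, as in the tutorial


def read_data_dir():
    # Hypothetical stand-in for main(): reads an attribute of the global FLAGS.
    return FLAGS.data_dir


# Case 1: FLAGS is still None, so the attribute lookup fails.
try:
    read_data_dir()
except AttributeError as err:
    error_message = str(err)
    print('AttributeError:', error_message)

# Case 2: FLAGS is an argparse.Namespace produced by a parser.
# An explicit list is parsed instead of the real command line.
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str,
                    default='/tmp/tensorflow/mnist/input_data',
                    help='Directory for storing input data')
FLAGS = parser.parse_args(['--data_dir', '/tmp/mnist_example'])

print(type(FLAGS).__name__)  # Namespace
print(read_data_dir())       # /tmp/mnist_example
```

The same function works in the second case only because FLAGS now refers to an object that has a data_dir attribute.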

If you print FLAGS right after the parsing step:

FLAGS, unparsed = parser.parse_known_args()
print(FLAGS)

you should see Namespace(data_dir='a directory', ...), where the value of data_dir was parsed from the command line.
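A small sketch of what parse_known_args returns: a (Namespace, leftover-arguments) pair. It parses an explicit list rather than the real command line, and '--learning_rate' is a hypothetical flag that this parser deliberately does not define.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str,
                    default='/tmp/tensorflow/mnist/input_data')

# '--learning_rate' is not defined on this parser (hypothetical flag);
# parse_known_args() puts it in `unparsed` instead of raising an error.
FLAGS, unparsed = parser.parse_known_args(['--learning_rate', '0.5'])

print(FLAGS)     # Namespace(data_dir='/tmp/tensorflow/mnist/input_data')
print(unparsed)  # ['--learning_rate', '0.5']
```

This is why the tutorial uses parse_known_args instead of parse_args: unknown arguments are passed along (via unparsed) rather than causing argparse to exit with an error.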

Upvotes: 3

Taku

Reputation: 33724

This is the full code you're looking at; I'll explain it in comments:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

FLAGS = None  # Placeholder value; reassigned in the __main__ block below


def main(_):  # The body runs only when main is called, so FLAGS is looked up at that point
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)  # By then FLAGS has been reassigned below, so it is no longer None

  x = tf.placeholder(tf.float32, [None, 784])
  W = tf.Variable(tf.zeros([784, 10]))
  b = tf.Variable(tf.zeros([10]))
  y = tf.matmul(x, W) + b

  y_ = tf.placeholder(tf.float32, [None, 10])

  cross_entropy = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

  sess = tf.InteractiveSession()
  tf.global_variables_initializer().run()
  # Train
  for _ in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

  # Test trained model
  correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                      y_: mnist.test.labels}))

if __name__ == '__main__': 
  parser = argparse.ArgumentParser()
  parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
                      help='Directory for storing input data')

  FLAGS, unparsed = parser.parse_known_args()  # Rebinds the global FLAGS to the Namespace (first item) returned by parse_known_args()

  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)  # Runs the app, which calls main with the leftover arguments

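What happens is that FLAGS gets reassigned here: FLAGS, unparsed = parser.parse_known_args(). Because that line runs at module level before tf.app.run calls main, main sees the parsed Namespace instead of None.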
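To see why the module-level assignment matters, here is a small TensorFlow-free sketch. `run_app` is a hypothetical stand-in for tf.app.run (which, roughly, calls main with the remaining argument list; the real one also exits afterwards), `build_parser` and `parse_without_global` are hypothetical helpers, and the paths are example values. The `if __name__ == '__main__':` guard is left out so the snippet behaves the same when imported.

```python
import argparse

FLAGS = None  # module-level placeholder, as in the tutorial


def main(_):
    # `_` receives the argument list and is ignored, as in the tutorial.
    # FLAGS is looked up in the module globals when main() runs.
    return FLAGS.data_dir


def run_app(main_fn, argv):
    # Hypothetical stand-in for tf.app.run: just calls main_fn with argv.
    return main_fn(argv)


def build_parser():
    # Hypothetical helper: same --data_dir option as the tutorial.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str,
                        default='/tmp/tensorflow/mnist/input_data')
    return parser


def parse_without_global(args):
    # Assigning FLAGS inside a function creates a *local* name;
    # the module-level FLAGS is left untouched.
    FLAGS, unparsed = build_parser().parse_known_args(args)
    return unparsed


parse_without_global(['--data_dir', '/tmp/ignored'])
flags_after_local = FLAGS
print(flags_after_local)  # None: the function only bound a local variable

# Module-level assignment (like the tutorial's __main__ block) rebinds the global.
FLAGS, unparsed = build_parser().parse_known_args(['--data_dir', '/tmp/demo'])
result = run_app(main, ['prog'] + unparsed)
print(result)  # /tmp/demo
```

Inside a function, you would need a `global FLAGS` statement before the assignment to get the same effect as the module-level line.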

Upvotes: 4
