Wei

Text input with TensorFlow

I am playing with TensorFlow and trying to build an RNN language model. I am struggling with how to read a raw text input file.

The TensorFlow guide mentions a few approaches, including:

  1. tf.data.Dataset.from_tensor_slices() - which assumes my data is available in memory (np.array?)
  2. tf.data.TFRecordDataset (no idea how to use this)
  3. tf.data.TextLineDataset (what's the difference from 2? The API pages are almost identical)

Confused by 2 and 3, I can only try approach 1, but I am facing the following issues:

  1. What if my data is too big to fit in memory?
  2. TF requires a fixed-length, padded format. How do I do that? Do I decide on a fixed length (e.g. 30), read each line into a list, truncate the list to 30 if it is longer than 30, pad it with 0s to make each line at least 30 long, and append the list to a NumPy array/matrix?
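The manual procedure from item 2 can be sketched in NumPy as follows, assuming each line has already been converted into a list of integer token ids (that tokenization step, the sample sequences and the helper name pad_or_truncate are hypothetical). Note that it still builds the whole matrix in memory, which is exactly the limitation raised in item 1.

```python
import numpy as np

MAX_LEN = 30  # the fixed length proposed in the question
PAD_ID = 0    # assumes token id 0 is reserved for padding

def pad_or_truncate(token_ids, max_len=MAX_LEN, pad_id=PAD_ID):
    """Hypothetical helper: cut a sequence to max_len, or right-pad it with pad_id."""
    clipped = list(token_ids[:max_len])
    return clipped + [pad_id] * (max_len - len(clipped))

# Made-up token-id sequences, one per line of the input file:
# a short line, a line longer than MAX_LEN, and an empty line.
lines = [[5, 12, 7], list(range(1, 41)), []]

# Every row now has exactly MAX_LEN entries, so NumPy can stack them.
matrix = np.array([pad_or_truncate(ids) for ids in lines], dtype=np.int32)
print(matrix.shape)  # (3, 30)
```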

I am sure these are such common problems that TensorFlow must provide built-in functions for them!

Answers (1)

Maxim

If your data is in text files (CSV, TSV or just a collection of lines), the best way to process it is with tf.data.TextLineDataset. tf.data.TFRecordDataset has a similar API, but it is meant for the binary TFRecord format (check out this nice post if you want some details).
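To make the difference between the three constructors concrete, here is a small sketch, assuming TensorFlow 2.x with eager execution, that reads the same three made-up sentences from a Python list, from a plain text file and from a TFRecord file. Storing each line as a serialized tf.train.Example under a "text" key is an arbitrary choice for illustration; a TFRecord file can hold any binary strings.

```python
import os
import tempfile

import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Made-up sample lines standing in for a real corpus.
lines = ["the cat sat", "on the mat", "a dog barked"]
tmp_dir = tempfile.mkdtemp()

# 1) In memory: the whole list is converted into one tensor up front.
in_memory = tf.data.Dataset.from_tensor_slices(lines)

# 2) Plain text file: TextLineDataset streams one line (newline stripped) per element.
text_path = os.path.join(tmp_dir, "corpus.txt")
with open(text_path, "w") as f:
    f.write("\n".join(lines) + "\n")
from_text = tf.data.TextLineDataset(text_path)

# 3) TFRecord file: each record is an opaque binary string. Here every record
#    is a serialized tf.train.Example with a single "text" feature (our choice).
record_path = os.path.join(tmp_dir, "corpus.tfrecord")
with tf.io.TFRecordWriter(record_path) as writer:
    for line in lines:
        example = tf.train.Example(features=tf.train.Features(feature={
            "text": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[line.encode("utf-8")])),
        }))
        writer.write(example.SerializeToString())

# TFRecordDataset only yields the raw bytes; decoding them is up to us.
feature_spec = {"text": tf.io.FixedLenFeature([], tf.string)}
from_tfrecord = tf.data.TFRecordDataset(record_path).map(
    lambda record: tf.io.parse_single_example(record, feature_spec)["text"])

for ds in (in_memory, from_text, from_tfrecord):
    print([t.numpy() for t in ds])
```

All three datasets yield the same byte strings; the difference is where the data lives and how much decoding you have to do yourself.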

A good example of text line processing via the Dataset API is the TensorFlow Wide & Deep Learning Tutorial (the code is here). Here's the input function used there:

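import tensorflow as tf  # TensorFlow 1.x API (tf.gfile, tf.decode_csv, make_one_shot_iterator)

# _CSV_COLUMNS, _CSV_COLUMN_DEFAULTS and _NUM_EXAMPLES are module-level
# constants defined elsewhere in the tutorial's script (not shown here).
def input_fn(data_file, num_epochs, shuffle, batch_size):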
  """Generate an input function for the Estimator."""
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have either run data_download.py or '
      'set both arguments --train_data and --test_data.' % data_file)

  def parse_csv(value):
    print('Parsing', data_file)
    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('income_bracket')
    return features, tf.equal(labels, '>50K')

  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  dataset = dataset.map(parse_csv, num_parallel_calls=5)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)

  iterator = dataset.make_one_shot_iterator()
  features, labels = iterator.get_next()
  return features, labels

Here's what's going on in this snippet:

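  • The tf.data.TextLineDataset(data_file) line creates a Dataset object, assigned to dataset. It is a wrapper, not a container for the contents, so the file is never read entirely into memory: lines are streamed from disk as they are needed. Keep in mind that shuffle() holds buffer_size elements in memory, so a buffer as large as the training set, as in this snippet, does end up holding all of it.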

  • The Dataset API lets you pre-process the data, e.g. with the shuffle, map, batch and other methods. Note that the API is functional and lazy: no data is processed when you call Dataset methods; they only define which transformations will be applied to the tensors once the session actually starts and the iterator is evaluated (see below).

  • Finally, dataset.make_one_shot_iterator() returns an iterator, and its get_next() method returns the tensors from which one can read the values. When you evaluate features and labels in a session, they hold the next batch of data after all the transformations.

  • Also note that if you train your model on a GPU, the data is streamed to the device directly, without an intermediate stop in the client (the Python script itself).
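The code above targets the TensorFlow 1.x graph API (make_one_shot_iterator plus a session). Here is a minimal sketch of the same shuffle, map, batch pipeline under TensorFlow 2.x, assuming eager execution, where a Dataset is a plain Python iterable. The inline CSV lines and the two-column parse_line function are made-up stand-ins for a real file and the tutorial's parse_csv.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Stand-in for a TextLineDataset so the sketch runs without a file;
# each made-up line is "feature,label".
dataset = tf.data.Dataset.from_tensor_slices(["3,1", "5,0", "7,1", "9,0"])

def parse_line(line):
    # Split one CSV line into two int32 scalars (defaults of 0 set the dtype).
    feature, label = tf.io.decode_csv(line, record_defaults=[[0], [0]])
    return feature, label

# Each call only returns a new Dataset describing the pipeline; no line is
# parsed until the loop below starts pulling elements.
pipeline = (dataset
            .shuffle(buffer_size=4)  # holds at most 4 elements in memory
            .map(parse_line)
            .batch(2))

# In TF 2.x the dataset is iterated directly; this plays the role of
# make_one_shot_iterator().get_next() evaluated in a session in TF 1.x.
for features, labels in pipeline:
    print(features.numpy(), labels.numpy())
```

Because shuffle() reshuffles on every pass by default, the order of the printed batches can change between runs, but each feature stays paired with its own label.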

Depending on your particular format, you probably won't need to parse CSV columns and can simply read the lines one by one.
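For a plain-text corpus like the one in the question, here is a minimal sketch, assuming TensorFlow 2.x, that reads the lines one by one, maps words to integer ids and lets padded_batch handle the fixed-length padding instead of building a NumPy matrix by hand. The corpus, the vocabulary, MAX_LEN, PAD_ID and UNK_ID are hypothetical placeholders, and using a tf.lookup.StaticHashTable for the word-to-id mapping is just one possible choice; a real model would build its vocabulary from the training data.

```python
import os
import tempfile

import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

MAX_LEN = 5  # hypothetical fixed length; longer lines are truncated
PAD_ID = 0   # id reserved for padding
UNK_ID = 1   # id for out-of-vocabulary words

# Made-up corpus written to a temporary file, one sentence per line.
path = os.path.join(tempfile.mkdtemp(), "corpus.txt")
with open(path, "w") as f:
    f.write("the cat sat\non the mat with a hat\ndog\n")

# Hypothetical vocabulary; ids start at 2 so they don't clash with PAD_ID/UNK_ID.
vocab = ["the", "cat", "sat", "on", "mat", "a", "hat"]
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(vocab),
        values=tf.range(2, len(vocab) + 2, dtype=tf.int64)),
    default_value=tf.constant(UNK_ID, dtype=tf.int64))

def line_to_ids(line):
    words = tf.strings.split(line)  # whitespace tokenization of one line
    ids = table.lookup(words)       # unknown words map to UNK_ID
    return ids[:MAX_LEN]            # truncate to at most MAX_LEN tokens

dataset = (tf.data.TextLineDataset(path)  # streams lines, never loads the whole file
           .map(line_to_ids)
           # Right-pad every sequence in a batch to MAX_LEN with PAD_ID.
           .padded_batch(2, padded_shapes=[MAX_LEN],
                         padding_values=tf.constant(PAD_ID, dtype=tf.int64)))

for batch in dataset:
    print(batch.numpy())
```

Truncating inside map() keeps every element at most MAX_LEN long, which padded_batch requires when padded_shapes is fixed; omitting the truncation and padded_shapes would instead pad each batch to its own longest line.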


Suggested reading: Importing Data guide.
