Reputation: 341
I am playing with TensorFlow and trying to build an RNN language model. I am struggling with how to read a raw text input file.
The TensorFlow guide mentions a few approaches, including:
1. tf.data.Dataset.from_tensor_slices() - which assumes my data is available in memory (np.array?)
2. tf.data.TFRecordDataset - no idea how to use this
3. tf.data.TextLineDataset - what's the difference with 2? The API pages are almost identical.
Confused by 2 and 3, I can only try approach 1, but I'm facing the following issues:
I am sure these are such common problems that TensorFlow must have provided built-in functions!
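As far as I can tell from the API pages, the three constructors are called like this (a minimal sketch; corpus.txt and data.tfrecord are placeholder file names):

import tensorflow as tf

# 1. In-memory data (a Python list, np.array, ...):
ds1 = tf.data.Dataset.from_tensor_slices(['the cat sat', 'on the mat'])

# 2. TFRecord binary files:
ds2 = tf.data.TFRecordDataset('data.tfrecord')

# 3. Plain text files, read one element per line:
ds3 = tf.data.TextLineDataset('corpus.txt')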
Upvotes: 2
Views: 2597
Reputation: 53766
If your data is in text files (csv, tsv, or just a collection of lines), the best way to process it is with tf.data.TextLineDataset; tf.data.TFRecordDataset has a similar API, but it is for the TFRecord binary format (check out this nice post if you want some details).
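To give a flavor of the TFRecord side (a hedged sketch; it assumes each record is a serialized tf.train.Example with a single bytes feature named 'text', which is not something your files are guaranteed to contain):

import tensorflow as tf

def parse_example(serialized):
    # Hypothetical schema: one bytes feature called 'text' per record.
    parsed = tf.parse_single_example(
        serialized, features={'text': tf.FixedLenFeature([], tf.string)})
    return parsed['text']

dataset = tf.data.TFRecordDataset('data.tfrecord')  # placeholder file name
dataset = dataset.map(parse_example)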
A good example of text-line processing via the Dataset API is the TensorFlow Wide & Deep Learning tutorial (the code is here). Here's the input function used there:
def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Generate an input function for the Estimator."""
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have either run data_download.py or '
      'set both arguments --train_data and --test_data.' % data_file)

  def parse_csv(value):
    print('Parsing', data_file)
    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('income_bracket')
    return features, tf.equal(labels, '>50K')

  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  dataset = dataset.map(parse_csv, num_parallel_calls=5)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)

  iterator = dataset.make_one_shot_iterator()
  features, labels = iterator.get_next()
  return features, labels
Here's what's going on in this snippet:
The tf.data.TextLineDataset(data_file) line creates a Dataset object, assigned to dataset. It is a wrapper, not a holder of the contents, so the data is never read entirely into memory.
The Dataset API lets you pre-process the data, e.g. with the shuffle, map and batch methods. Note that the API is functional: no data is processed when you call Dataset methods; they only define which transformations will be performed on the tensors when the session actually starts and an iterator is evaluated (see below).
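Since each of these methods returns a new Dataset, the pipeline above can equally be written as one chain; this sketch is behaviorally equivalent to the tutorial snippet (with the shuffle flag dropped for brevity):

dataset = (tf.data.TextLineDataset(data_file)
           .shuffle(buffer_size=_NUM_EXAMPLES['train'])
           .map(parse_csv, num_parallel_calls=5)
           .repeat(num_epochs)
           .batch(batch_size))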
Finally, dataset.make_one_shot_iterator() returns an iterator whose get_next() yields the output tensors. You can evaluate features and labels, and they will receive the values of the data batches after all transformations.
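For example, a driver loop might look like this (a sketch; features and labels are the tensors returned by input_fn above, and pulling three batches is just for illustration):

features, labels = input_fn(data_file, num_epochs=1, shuffle=True, batch_size=32)

with tf.Session() as sess:
    for _ in range(3):
        feature_batch, label_batch = sess.run([features, labels])
        print(label_batch.shape)  # each run() pulls the next batch through the pipeline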
Also note that if you train your model on a GPU, the data will be streamed to the device directly, without an intermediate stop in the client (the Python script itself).
Depending on your particular format, you probably won't need to parse the csv columns and can simply read the lines one by one.
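For instance, for a language model over raw text the pipeline might look like this (a sketch; corpus.txt, the whitespace tokenization and the batch size are assumptions, not requirements):

import tensorflow as tf

dataset = tf.data.TextLineDataset('corpus.txt')                      # one sentence per line
dataset = dataset.map(lambda line: tf.string_split([line]).values)   # whitespace word tokens
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.padded_batch(32, padded_shapes=[None])             # pad to the longest line in the batch

tokens = dataset.make_one_shot_iterator().get_next()  # string tensor, shape [batch_size, max_len_in_batch]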
Suggested reading: the Importing Data guide.
Upvotes: 1