Reputation: 51
I am excited about the new wide_n_deep learning model released by Google for TensorFlow, so I am trying it out by running the tutorial example.
Batch learning is a common technique in machine learning when the whole training set is large, so I modified the wide_n_deep learning tutorial example to train in batches, as follows:
from datetime import datetime
import numpy as np

index_in_epoch = 0
num_examples = df_train.shape[0]
for i in range(FLAGS.train_steps):
    startTime = datetime.now()
    print("start step %i" % i)
    start = index_in_epoch
    index_in_epoch += batch_size
    if index_in_epoch > num_examples:
        # End of the epoch: train on the leftover rows, if any.
        if start < num_examples:
            m.fit(input_fn=lambda: input_fn(df_train[start:num_examples]), steps=1)
        # Shuffle the rows; reindex returns a new frame, so reassign it.
        df_train = df_train.reindex(np.random.permutation(df_train.index))
        start = 0
        index_in_epoch = batch_size
    # Evaluate on the test set every fifth step.
    if i % 5 == 1:
        results = m.evaluate(input_fn=lambda: input_fn(df_test), steps=1)
        for key in sorted(results):
            print("%s: %s" % (key, results[key]))
    end = index_in_epoch
    # Train on the current batch.
    m.fit(input_fn=lambda: input_fn(df_train[start:end]), steps=1)
Put simply, I iterate over the entire training data set batch by batch, and for every batch I call the "fit" function to continue training the model.
The problem with this naive strategy is that it is far too slow. For instance, iterating over a 4-million-record data set 100 times with a batch size of 100k would take approximately 1 week of training and evaluation. So I doubt I am doing batch learning properly.
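To put that estimate in per-call terms, here is a rough back-of-envelope calculation. It only uses the figures quoted above, treats "approximately 1 week" as exactly 7 days, and ignores the separate evaluate calls, so the result is an order-of-magnitude indication rather than a measurement.

```python
# Figures quoted in the question.
num_records = 4000000
batch_size = 100000
num_epochs = 100
total_seconds = 7 * 24 * 3600  # assumption: "approximately 1 week" taken as 7 days

# Ceiling division, so a final partial batch would still count as one call.
batches_per_epoch = -(-num_records // batch_size)
total_fit_calls = batches_per_epoch * num_epochs
seconds_per_call = total_seconds / float(total_fit_calls)

print("batches per epoch: %d" % batches_per_epoch)
print("fit calls in total: %d" % total_fit_calls)
print("average seconds per fit call: %.1f" % seconds_per_call)
```

That works out to 4000 fit calls at roughly two and a half minutes each on average, which suggests most of the time is going somewhere other than the single training step each call performs.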
I would appreciate it if anyone could share their experience with batch learning for the wide_n_deep learning model.
Upvotes: 0
Views: 609
Reputation: 1080
Every fit/evaluate call creates a graph and a session, then does the operation. If you do that in the loop, it will be slow.
To make it faster, you need to provide an input_fn that feeds tensors batch by batch.
If you read data from a dataframe, you can use to_feature_columns_and_input_fn.
If you read data from a file that holds tf.Example records, you can use something like read_batch_examples in your input_fn.
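The idea of an input source that hands out batches on demand can be sketched in plain Python, without TensorFlow. This is only an illustration of the batching logic: iter_batches and its parameters are hypothetical names, not part of to_feature_columns_and_input_fn or read_batch_examples, and the generator stands in for what a batching input_fn would do inside a single training call.

```python
import random


def iter_batches(num_examples, batch_size, num_epochs, seed=0):
    """Yield lists of row positions, one list per batch.

    Rows are reshuffled at the start of every epoch, and the last
    batch of an epoch may be shorter than batch_size.
    """
    rng = random.Random(seed)
    positions = list(range(num_examples))
    for _ in range(num_epochs):
        rng.shuffle(positions)
        for start in range(0, num_examples, batch_size):
            yield positions[start:start + batch_size]


# Small demonstration: 10 rows, batches of 4, 2 epochs.
batches = list(iter_batches(num_examples=10, batch_size=4, num_epochs=2))
print(len(batches))               # 6
print([len(b) for b in batches])  # [4, 4, 2, 4, 4, 2]
```

With a pandas dataframe, each yielded list could be used to select rows, for example df_train.iloc[batch]. The point is that the consumer pulls batches from one long-lived source instead of the caller restarting training for every batch.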
Upvotes: 0