Timy Tang

Reputation: 51

TensorFlow Wide & Deep tutorial example with batch learning

I am excited about the new wide_n_deep learning model released by Google in TensorFlow, so I have been experimenting with it by running the tutorial example.

Batch learning is a common technique in machine learning when the entire training data set is large. I therefore tried to modify the provided wide_n_deep learning tutorial example to train in batches, as follows:

from datetime import datetime
import numpy as np

index_in_epoch = 0
num_examples = df_train.shape[0]
for i in xrange(FLAGS.train_steps):
    startTime = datetime.now()
    print("start step %i" % i)
    start = index_in_epoch
    index_in_epoch += batch_size
    if index_in_epoch > num_examples:
        if start < num_examples:
            # Train on the leftover rows at the end of the epoch.
            m.fit(input_fn=lambda: input_fn(df_train[start:num_examples]), steps=1)
        # reindex returns a new frame, so assign the shuffled result back.
        df_train = df_train.reindex(np.random.permutation(df_train.index))
        start = 0
        index_in_epoch = batch_size
    if i % 5 == 1:
        # Evaluate on the test set every fifth step.
        results = m.evaluate(input_fn=lambda: input_fn(df_test), steps=1)
        for key in sorted(results):
            print("%s: %s" % (key, results[key]))
    end = index_in_epoch
    # Train on the current batch.
    m.fit(input_fn=lambda: input_fn(df_train[start:end]), steps=1)

Simply put, I iterate over the entire training data set batch by batch, and for every batch I call the "fit" function to re-train the model.

The problem with this naive strategy is that it is far too slow. For instance, iterating over a 4-million-record data set 100 times with a batch size of 100k would take approximately 1 week of training and evaluation. So I doubt that I am using batch learning properly.

I would appreciate it if anyone could share their experience with batch learning for the wide_n_deep learning model.
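As a rough sanity check on those figures, the short calculation below works out how many "fit" calls the loop makes and what one week of wall-clock time implies per call. The variable names are only illustrative, and the one-week figure is the estimate quoted above, not a measurement; the result also includes the share of time spent in "evaluate".

```python
# Back-of-the-envelope cost of the loop, using the numbers quoted above.
num_records = 4000000          # size of the training set
batch_size = 100000            # rows per fit call
num_epochs = 100               # passes over the data
total_seconds = 7 * 24 * 3600  # "approximately 1 week" (an estimate)

# Ceiling division: number of batches needed to cover one epoch.
batches_per_epoch = -(-num_records // batch_size)
fit_calls = batches_per_epoch * num_epochs
# Average wall-clock time per fit call, evaluation time included.
seconds_per_call = total_seconds / float(fit_calls)

print(batches_per_epoch)  # 40
print(fit_calls)          # 4000
print(seconds_per_call)   # 151.2
```

So each call would average about two and a half minutes for a 100k-row batch, which suggests that a large part of the time goes to per-call overhead rather than to the training step itself.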

Upvotes: 0

Views: 609

Answers (1)

user1454804

Reputation: 1080

Every fit/evaluate call creates a graph and a session, then runs the operation. If you do that in a loop, it will be slow. To make it faster, you need to provide an input_fn that produces tensors batch by batch. If you read data from a dataframe, you can use to_feature_columns_and_input_fn. If you read data from a file that holds tf.Example records, you can use something like read_batch_examples in your input_fn.
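To illustrate the idea without depending on a particular TensorFlow version, here is a minimal pure-Python sketch of batching done inside the input source rather than in an outer loop around fit. The function name batch_slices and its parameters are hypothetical and are not part of any TensorFlow API; the point is only that a single producer hands out shuffled batches across all epochs, so the consumer is set up once.

```python
import random


def batch_slices(num_examples, batch_size, num_epochs, seed=0):
    """Yield one list of row indices per minibatch, reshuffling every epoch.

    Hypothetical helper for illustration only; not a TensorFlow function.
    """
    rng = random.Random(seed)
    indices = list(range(num_examples))
    for _ in range(num_epochs):
        rng.shuffle(indices)
        for start in range(0, num_examples, batch_size):
            # Slicing copies the indices, so later shuffles do not alter
            # batches that were already handed out. The last batch of an
            # epoch may be shorter than batch_size.
            yield indices[start:start + batch_size]


# One consumer pulls every batch from the same generator.
batches = list(batch_slices(num_examples=10, batch_size=4, num_epochs=2))
print(len(batches))               # 6
print([len(b) for b in batches])  # [4, 4, 2, 4, 4, 2]
```

In the estimator setting, the equivalent is a single fit call whose input_fn does this batching internally, which is what the helpers mentioned above are meant to provide.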

Upvotes: 0
