Reputation: 121
I am using the high-level Estimator in TensorFlow:
estim = tf.contrib.learn.Estimator(...)
estim.fit(some_input)
If some_input contains x, y, and batch_size, the code runs, but with a warning; so I tried to use input_fn instead, and managed to pass x and y through this input_fn, but not the batch_size. I didn't find any example for this.
Could anyone share a simple example that uses input_fn as input to estim.fit / estim.evaluate, and uses batch_size as well?
Do I have to use tf.train.batch? If so, how does it merge into the higher-level implementation (tf.layers)? I don't know the graph's tf.Graph() or session.
Below is the warning I got:
WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/monitors.py:657: calling evaluate (from tensorflow.contrib.learn.python.learn.estimators.estimator) with y is deprecated and will be removed after 2016-12-01.
Instructions for updating: Estimator is decoupled from Scikit Learn interface by moving into separate class SKCompat. Arguments x, y and batch_size are only available in the SKCompat class, Estimator will only accept input_fn.
Example conversion:
est = Estimator(...) -> est = SKCompat(Estimator(...))
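For reference, a minimal sketch of the conversion the warning suggests (my_model_fn, train_x, train_y, and the concrete numbers are placeholders, not from my code):

import tensorflow as tf

# Wrapping the Estimator in SKCompat keeps the scikit-learn-style
# arguments (x, y, batch_size) available without the deprecation warning.
est = tf.contrib.learn.SKCompat(
    tf.contrib.learn.Estimator(model_fn=my_model_fn))
est.fit(x=train_x, y=train_y, batch_size=128, steps=1000)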
Upvotes: 4
Views: 3773
Reputation: 1131
The link provided in Roi's own comment was indeed really helpful. Since I was struggling with the same question for a while, I would like to summarize the answer provided there as a reference:
def batched_input_fn(dataset_x, dataset_y, batch_size):
    def _input_fn():
        # Load the whole dataset into constant tensors.
        all_x = tf.constant(dataset_x, shape=dataset_x.shape, dtype=tf.float32)
        all_y = tf.constant(dataset_y, shape=dataset_y.shape, dtype=tf.float32)
        # Emit one (x, y) example at a time from the in-memory tensors...
        sliced_input = tf.train.slice_input_producer([all_x, all_y])
        # ...and regroup the examples into batches of `batch_size`.
        return tf.train.batch(sliced_input, batch_size=batch_size)
    return _input_fn
This can then be used like this example (using TensorFlow v1.1):
model = CustomModel(FLAGS.learning_rate)
estimator = tf.estimator.Estimator(model_fn=model.build(), params=model.params())

estimator.train(
    input_fn=batched_input_fn(train.features, train.labels, FLAGS.batch_size),
    steps=FLAGS.train_steps)
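The same factory can be reused for evaluation; a minimal sketch, assuming a test split with the same attributes and a FLAGS.eval_steps setting (both are my placeholders):

eval_results = estimator.evaluate(
    input_fn=batched_input_fn(test.features, test.labels, FLAGS.batch_size),
    steps=FLAGS.eval_steps)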
Unfortunately, this approach is about 10x slower than manual feeding (using TensorFlow's low-level API) or than using the whole dataset as a single batch (train.shape[0] == batch_size) without tf.train.slice_input_producer() and tf.train.batch() at all. At least on my machine (CPU only). I'm really wondering why this approach is so slow. Any ideas?
Edit:
I could speed it up a bit by passing num_threads > 1 to tf.train.batch(). On a VM with 2 CPUs, I'm able to double the performance with this batching mechanism compared to the default num_threads=1. But it is still 5x slower than manual feeding.
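Concretely, the only change is in the return line of _input_fn (the thread count of 4 here is an arbitrary example value; tune it to your machine):

        # Fill the batch queue with several threads in parallel.
        return tf.train.batch(sliced_input, batch_size=batch_size, num_threads=4)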
Results might differ on a native system, or on a system that uses all CPU cores for the input pipeline and the GPU for the model computation. It would be great if somebody could post their results in the comments.
Upvotes: 4