I'd like to understand why and where we use `tf.estimator.EstimatorSpec()`. I read the documentation on the TensorFlow website, but I can't get an intuitive idea of it. Please explain it to me in simple language.
I was a little bemused the first time I read the API, so I wrote this repo along with a basic explanation.
In short: `tf.estimator.Estimator` requires a `model_fn` as an input argument. That `model_fn` should be a function that maps `(features, labels, mode, [config, params]) -> tf.estimator.EstimatorSpec` (the `config` and `params` arguments are optional).
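To make that contract concrete, here is a toy, pure-Python sketch with no TensorFlow at all. `ToySpec`, `ToyEstimator`, `model_fn` and `input_fn` are hypothetical names made up for illustration, not TensorFlow's API; the mode strings mirror the values of `tf.estimator.ModeKeys` in TF 1.x. The point is only the division of labour: the estimator owns the train/evaluate/predict loops and calls `model_fn` with a mode, and `model_fn` answers with a spec bundling whatever that mode needs. Unlike the real thing, which builds a graph once and then runs it, this toy simply calls `model_fn` again on every step.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional

# Mode strings mirroring the values of tf.estimator.ModeKeys in TF 1.x.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


@dataclass
class ToySpec:
    """Hypothetical stand-in for EstimatorSpec: a bundle of per-mode results."""
    mode: str
    predictions: Optional[List[float]] = None
    loss: Optional[float] = None
    train_op: Optional[Callable[[], None]] = None


class ToyEstimator:
    """Hypothetical mini-Estimator: owns the loops, asks model_fn what to do."""

    def __init__(self, model_fn):
        self.model_fn = model_fn

    def train(self, input_fn, steps):
        for _ in range(steps):
            features, labels = input_fn()
            spec = self.model_fn(features, labels, TRAIN)
            spec.train_op()  # run one optimisation step
        return self

    def evaluate(self, input_fn):
        features, labels = input_fn()
        return {"loss": self.model_fn(features, labels, EVAL).loss}

    def predict(self, input_fn):
        features, _ = input_fn()  # labels are not needed for prediction
        return self.model_fn(features, None, PREDICT).predictions


# A one-parameter linear model y = w * x; the dict plays the role of a variable.
params = {"w": 0.0}


def model_fn(features, labels, mode):
    preds = [params["w"] * x for x in features]
    if mode == PREDICT:
        return ToySpec(mode, predictions=preds)
    n = len(features)
    loss = sum((p - y) ** 2 for p, y in zip(preds, labels)) / n  # mean squared error

    def train_op():
        # Gradient of the mean squared error with respect to w, then one SGD step.
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, labels, features)) / n
        params["w"] -= 0.1 * grad

    return ToySpec(mode, predictions=preds, loss=loss, train_op=train_op)


def input_fn():
    xs = [1.0, 2.0, 3.0]
    return xs, [2.0 * x for x in xs]  # targets follow y = 2x


estimator = ToyEstimator(model_fn).train(input_fn, steps=20)
print(estimator.evaluate(input_fn))
print(estimator.predict(input_fn))
```

Note that `model_fn` never runs a loop itself; it only describes what to compute for the mode it is asked about.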
The `EstimatorSpec` itself is a specification of an estimator: it contains everything the estimator needs to train, evaluate and predict, except for the input data itself, which is supplied through the `input_fn` argument of the `train`, `evaluate` and `predict` methods of `tf.estimator.Estimator`.
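According to the `EstimatorSpec` documentation, which fields are required depends on the mode: `loss` and `train_op` for training, `loss` for evaluation, and `predictions` for prediction. The plain-Python helper below just encodes those rules; `REQUIRED_FIELDS` and `missing_fields` are hypothetical names, not part of TensorFlow, and this is not TensorFlow's own validation code. It shows why a `model_fn` can return a spec with only `predictions` in predict mode, where there are no labels to compute a loss from.

```python
# Mode strings used by tf.estimator.ModeKeys in TF 1.x (PREDICT is the string "infer").
TRAIN, EVAL, PREDICT = "train", "eval", "infer"

# Fields that the EstimatorSpec docs list as required for each mode.
REQUIRED_FIELDS = {
    TRAIN: {"loss", "train_op"},
    EVAL: {"loss"},
    PREDICT: {"predictions"},
}


def missing_fields(mode, **spec_fields):
    """Return the required fields that are absent (or None) for `mode`, sorted."""
    if mode not in REQUIRED_FIELDS:
        raise ValueError(f"unknown mode {mode!r}")
    present = {name for name, value in spec_fields.items() if value is not None}
    return sorted(REQUIRED_FIELDS[mode] - present)


print(missing_fields(PREDICT, predictions={"preds": [7]}))
print(missing_fields(TRAIN, loss=0.3))
```

Here the first call reports nothing missing, while the second reports that a training spec still needs a `train_op`.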
Excerpt from the above repository:
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.layers, tf.train, tf.metrics)


def get_logits(image):
    """Get logits from image."""
    x = image
    for filters in (32, 64):
        x = tf.layers.conv2d(x, filters, 3)
        x = tf.nn.relu(x)
        x = tf.layers.max_pooling2d(x, 3, 2)
    x = tf.reduce_mean(x, axis=(1, 2))  # global average pooling over height, width
    logits = tf.layers.dense(x, 10)
    return logits


def get_estimator_spec(features, labels, mode):
    """Get an estimator specification.

    Args:
        features: mnist image batch, float32 tensor of shape
            (batch_size, 28, 28, 1)
        labels: mnist label batch, int32 tensor of shape (batch_size,)
        mode: one of `tf.estimator.ModeKeys`, i.e. {"train", "eval", "infer"}
    """
    if mode not in {"train", "infer", "eval"}:
        raise ValueError('mode should be in {"train", "infer", "eval"}')
    logits = get_logits(features)
    preds = tf.argmax(logits, axis=-1)
    probs = tf.nn.softmax(logits, axis=-1)
    predictions = dict(preds=preds, probs=probs, image=features)

    # Prediction ("infer") only needs predictions: there are no labels here.
    if mode == 'infer':
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Training and evaluation need a loss; training also needs a train_op.
    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    step = tf.train.get_or_create_global_step()
    train_op = optimizer.minimize(loss, global_step=step)
    accuracy = tf.metrics.accuracy(labels, preds)
    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions,
        loss=loss, train_op=train_op, eval_metric_ops=dict(accuracy=accuracy))


model_dir = '/tmp/mnist_simple'


def get_estimator():
    return tf.estimator.Estimator(get_estimator_spec, model_dir=model_dir)
```
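If the shapes in `get_logits` are unclear, this small arithmetic sketch traces them for a batch of 28x28 MNIST images. It assumes the TF 1.x defaults of `padding='valid'` for both `tf.layers.conv2d` and `tf.layers.max_pooling2d`, and stride 1 for the convolution; `valid_out` and `logits_shapes` are hypothetical helpers, not TensorFlow functions.

```python
def valid_out(size, kernel, stride=1):
    """Output length of an unpadded ('valid') convolution or pooling window."""
    return (size - kernel) // stride + 1


def logits_shapes(batch_size, height=28, width=28, channels=1):
    """Tensor shapes after each stage of get_logits, assuming 'valid' padding."""
    shapes = [(batch_size, height, width, channels)]
    h, w = height, width
    for filters in (32, 64):
        h, w = valid_out(h, 3), valid_out(w, 3)  # conv2d: kernel 3, stride 1
        shapes.append((batch_size, h, w, filters))
        h, w = valid_out(h, 3, 2), valid_out(w, 3, 2)  # max_pooling2d: pool 3, stride 2
        shapes.append((batch_size, h, w, filters))
    shapes.append((batch_size, 64))  # reduce_mean over axes (1, 2)
    shapes.append((batch_size, 10))  # dense layer with 10 units, one logit per digit
    return shapes


for shape in logits_shapes(8):
    print(shape)
```

So the final `reduce_mean` averages over a 4x4 grid, leaving 64 features per image for the dense layer.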