I'd like to understand why and where we use `tf.estimator.EstimatorSpec()`. I read the documentation on the TensorFlow website, but I can't get an intuitive idea of it.
Please explain it to me in simple language.
I was a little bemused the first time I read the API, so I wrote this repo along with a basic explanation.
In short: `tf.estimator.Estimator` requires a `model_fn` as an input argument. That `model_fn` should be a function that maps `(features, labels, mode, [config, params]) -> tf.estimator.EstimatorSpec` (the `config` and `params` arguments are optional).
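The mapping can be illustrated without TensorFlow at all. The sketch below is a toy, not TensorFlow code: `call_model_fn`, `plain_model_fn` and `scaled_model_fn` are hypothetical names, and a plain dict stands in for the `EstimatorSpec`. It imitates one real detail: `tf.estimator.Estimator` looks at the signature of your `model_fn` and only passes `params` and `config` when the function declares them, which is why those two arguments are optional.

```python
import inspect
from typing import Any, Callable, Dict, Optional


def call_model_fn(model_fn: Callable[..., Dict[str, Any]], features: Any,
                  labels: Any, mode: str,
                  params: Optional[Dict[str, Any]] = None,
                  config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical dispatcher: pass params/config only if model_fn declares them."""
    declared = inspect.signature(model_fn).parameters
    kwargs: Dict[str, Any] = {}
    if "params" in declared:
        kwargs["params"] = params
    if "config" in declared:
        kwargs["config"] = config
    return model_fn(features, labels, mode, **kwargs)


def plain_model_fn(features, labels, mode):
    # Only the three required arguments; a dict stands in for EstimatorSpec.
    return {"mode": mode, "predictions": [2 * x for x in features]}


def scaled_model_fn(features, labels, mode, params):
    # Declares `params`, so the dispatcher forwards them.
    scale = params["scale"]
    return {"mode": mode, "predictions": [scale * x for x in features]}


print(call_model_fn(plain_model_fn, [1, 2], None, "infer"))
# {'mode': 'infer', 'predictions': [2, 4]}
print(call_model_fn(scaled_model_fn, [1, 2], None, "infer", params={"scale": 10}))
# {'mode': 'infer', 'predictions': [10, 20]}
```

In the real API you would pass the hyperparameters once, via the `params` argument of the `Estimator` constructor, and every call to your `model_fn` receives them.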
The `EstimatorSpec` itself is a specification of an estimator: it contains everything the estimator needs to train, evaluate and predict (predictions, loss, training op, evaluation metrics), except for the input data itself. The data is supplied separately, through the `input_fn` argument of the `train`/`evaluate`/`predict` methods of the `tf.estimator.Estimator` class.
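A toy, framework-free version of that division of labour may help the intuition. Everything here is hypothetical (`ToySpec`, `ToyEstimator`, `make_model_fn`, and a one-weight model y = w * x trained by plain gradient descent); it only mirrors the shape of the real API: the model function returns a spec holding predictions, a loss and a training op, while the data arrives through an `input_fn` passed to `train`, `evaluate` or `predict`.

```python
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple

InputFn = Callable[[], Tuple[List[float], Optional[List[float]]]]


class ToySpec(NamedTuple):
    """Hypothetical stand-in for tf.estimator.EstimatorSpec."""
    mode: str
    predictions: Optional[List[float]] = None
    loss: Optional[float] = None
    train_op: Optional[Callable[[], None]] = None


def make_model_fn(state: Dict[str, float]) -> Callable[..., ToySpec]:
    """Build a model_fn for y = w * x; `state` plays the role of the checkpoint."""
    def model_fn(features: List[float], labels: Optional[List[float]],
                 mode: str) -> ToySpec:
        w = state["w"]
        preds = [w * x for x in features]
        if mode == "infer":
            return ToySpec(mode, predictions=preds)  # no labels at inference time
        n = len(features)
        loss = sum((p - y) ** 2 for p, y in zip(preds, labels)) / n
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, labels, features)) / n

        def train_op() -> None:
            state["w"] -= 0.1 * grad  # one gradient-descent step on the MSE loss

        return ToySpec(mode, predictions=preds, loss=loss, train_op=train_op)
    return model_fn


class ToyEstimator:
    """Hypothetical estimator: it owns the model_fn, the caller owns the data."""

    def __init__(self, model_fn: Callable[..., ToySpec]) -> None:
        self.model_fn = model_fn

    def train(self, input_fn: InputFn, steps: int) -> None:
        for _ in range(steps):
            features, labels = input_fn()
            self.model_fn(features, labels, "train").train_op()

    def evaluate(self, input_fn: InputFn) -> Dict[str, float]:
        features, labels = input_fn()
        return {"loss": self.model_fn(features, labels, "eval").loss}

    def predict(self, input_fn: InputFn) -> List[float]:
        features, _ = input_fn()
        return self.model_fn(features, None, "infer").predictions


state = {"w": 0.0}
estimator = ToyEstimator(make_model_fn(state))
estimator.train(lambda: ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]), steps=100)  # y = 2x
print(estimator.predict(lambda: ([10.0], None)))  # close to [20.0]
```

In this toy, the estimator never looks inside the model and the model function never fetches data; the spec is the contract between the two, which is roughly the role `EstimatorSpec` plays.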
Excerpt from that repository:
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.layers, tf.train, tf.metrics)


def get_logits(image):
    """Get logits from image."""
    x = image
    for filters in (32, 64):
        x = tf.layers.conv2d(x, filters, 3)
        x = tf.nn.relu(x)
        x = tf.layers.max_pooling2d(x, 3, 2)
    x = tf.reduce_mean(x, axis=(1, 2))  # global average pooling
    logits = tf.layers.dense(x, 10)     # one logit per MNIST class
    return logits


def get_estimator_spec(features, labels, mode):
    """
    Get an estimator specification.

    Args:
        features: mnist image batch, float32 tensor of shape
            (batch_size, 28, 28, 1)
        labels: mnist label batch, int32 tensor of shape (batch_size,)
        mode: one of `tf.estimator.ModeKeys`, i.e. {"train", "eval", "infer"}

    Returns:
        tf.estimator.EstimatorSpec
    """
    ModeKeys = tf.estimator.ModeKeys
    if mode not in {ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT}:
        raise ValueError('mode should be in {"train", "infer", "eval"}')
    logits = get_logits(features)
    preds = tf.argmax(logits, axis=-1)
    probs = tf.nn.softmax(logits, axis=-1)
    predictions = dict(preds=preds, probs=probs, image=features)
    if mode == ModeKeys.PREDICT:  # == 'infer'; labels may be None here
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    step = tf.train.get_or_create_global_step()
    train_op = optimizer.minimize(loss, global_step=step)
    accuracy = tf.metrics.accuracy(labels, preds)  # (value, update_op) pair
    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions,
        loss=loss, train_op=train_op, eval_metric_ops=dict(accuracy=accuracy))


model_dir = '/tmp/mnist_simple'


def get_estimator():
    return tf.estimator.Estimator(get_estimator_spec, model_dir)
```
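`get_estimator_spec` returns early in `'infer'` mode because `EstimatorSpec` requires different fields per mode: according to its documentation, `loss` and `train_op` in TRAIN, `loss` in EVAL, and `predictions` in PREDICT. The plain-Python sketch below encodes just that rule; `REQUIRED_FIELDS` and `check_spec` are hypothetical helpers, not part of TensorFlow, and the real class does its own, more detailed validation.

```python
from typing import Any, Dict

# Mode strings match the values of tf.estimator.ModeKeys.
REQUIRED_FIELDS = {
    "train": ("loss", "train_op"),
    "eval": ("loss",),
    "infer": ("predictions",),
}


def check_spec(mode: str, **fields: Any) -> Dict[str, Any]:
    """Hypothetical check: raise if a field required for `mode` is missing."""
    if mode not in REQUIRED_FIELDS:
        raise ValueError(f"unknown mode {mode!r}")
    missing = [name for name in REQUIRED_FIELDS[mode] if fields.get(name) is None]
    if missing:
        raise ValueError(f"mode {mode!r} is missing {missing}")
    return dict(mode=mode, **fields)


check_spec("infer", predictions={"preds": [7]})  # OK: no labels, no loss needed
try:
    check_spec("train", loss=0.3)  # train_op is missing
except ValueError as err:
    print(err)  # mode 'train' is missing ['train_op']
```

This is why a single `model_fn` can serve all three methods: it builds only what the current mode needs and hands the rest back as `None`.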