Nova

Reputation: 11

Understanding the Estimator class of TensorFlow

I'd like to understand why and where we use tf.estimator.EstimatorSpec(). I read the documentation on the TensorFlow website, but I can't get an intuitive idea of it.

Please explain it to me in simple language.

Upvotes: 1

Views: 205

Answers (1)

DomJack

Reputation: 4183

I was a little bemused the first time I read the API, so I wrote this repo along with a basic explanation.

In short: tf.estimator.Estimator requires a model_fn as an input argument. That model_fn should be a function that maps (features, labels, mode, [config, params]) -> tf.estimator.EstimatorSpec (the config and params arguments are optional).
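To make the "config and params are optional" part concrete, here is a simplified pure-Python sketch (not TensorFlow code) of how an Estimator-like wrapper can inspect a model_fn's signature and pass config and params only when the function declares them. `call_model_fn`, `model_fn_basic` and `model_fn_with_params` are hypothetical names. My understanding is that tf.estimator.Estimator does something along these lines internally, but this is an illustration, not its implementation.

```python
import inspect


def call_model_fn(model_fn, features, labels, mode, config=None, params=None):
    """Hypothetical sketch: call model_fn with only the arguments it accepts."""
    accepted = inspect.signature(model_fn).parameters
    kwargs = {}
    if "labels" in accepted:
        kwargs["labels"] = labels
    elif labels is not None:
        # Labels were supplied but the model_fn has nowhere to receive them.
        raise ValueError("model_fn does not take labels, but labels were given")
    if "mode" in accepted:
        kwargs["mode"] = mode
    # config and params are optional: only passed if the signature declares them.
    if "config" in accepted:
        kwargs["config"] = config
    if "params" in accepted:
        kwargs["params"] = params
    return model_fn(features=features, **kwargs)


# Two hypothetical model functions with different signatures.
def model_fn_basic(features, labels, mode):
    return ("basic", mode)


def model_fn_with_params(features, labels, mode, params):
    return ("with_params", mode, params["learning_rate"])


print(call_model_fn(model_fn_basic, [1.0], [0], "train"))
print(call_model_fn(model_fn_with_params, [1.0], [0], "eval",
                    params={"learning_rate": 0.1}))
```

This is why both a three-argument model_fn and one that additionally takes params (supplied via the params argument of the Estimator constructor) are accepted.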

The EstimatorSpec itself is a specification of the model for a given mode. It contains everything the estimator needs to train, evaluate and predict (predictions, loss, train_op, eval_metric_ops) except for the input data itself, which is provided through the input_fn passed to the train/evaluate/predict methods of the tf.estimator.Estimator class.
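The division of labor described above (the spec bundles the model logic for each mode, while the data arrives separately through train/evaluate/predict) can be illustrated with a toy pure-Python analogue. Everything here is hypothetical and is not the TensorFlow API: ToySpec stands in for tf.estimator.EstimatorSpec, ToyEstimator for tf.estimator.Estimator, the `state` dict for the variables TensorFlow would keep in model_dir, and the model is a one-parameter linear fit y = w * x rather than a neural network.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.estimator.EstimatorSpec: everything the model
# needs for one mode, minus the data. loss/train_op/eval_metric_ops default to None.
ToySpec = namedtuple(
    "ToySpec", ["mode", "predictions", "loss", "train_op", "eval_metric_ops"],
    defaults=[None, None, None])


def toy_model_fn(features, labels, mode, state):
    """Hypothetical model_fn for y = w * x; `state` plays the role of model variables."""
    preds = [state["w"] * x for x in features]
    if mode == "infer":
        # Predict mode: no labels, so only predictions are specified.
        return ToySpec(mode, preds)
    errors = [p - y for p, y in zip(preds, labels)]
    loss = sum(e * e for e in errors) / len(errors)  # mean squared error

    def train_op():
        # One gradient-descent step on the mean squared error.
        grad = 2 * sum(e * x for e, x in zip(errors, features)) / len(errors)
        state["w"] -= 0.05 * grad

    mae = sum(abs(e) for e in errors) / len(errors)
    return ToySpec(mode, preds, loss, train_op, {"mae": mae})


class ToyEstimator:
    """Hypothetical stand-in for tf.estimator.Estimator."""

    def __init__(self, model_fn):
        self.model_fn = model_fn
        self.state = {"w": 0.0}  # stands in for variables saved in model_dir

    def train(self, input_fn, steps):
        for _ in range(steps):
            features, labels = input_fn()  # data comes from input_fn...
            spec = self.model_fn(features, labels, "train", self.state)
            spec.train_op()                # ...logic comes from the spec
        return self

    def evaluate(self, input_fn):
        features, labels = input_fn()
        spec = self.model_fn(features, labels, "eval", self.state)
        return dict(spec.eval_metric_ops, loss=spec.loss)

    def predict(self, input_fn):
        features, _ = input_fn()
        return self.model_fn(features, None, "infer", self.state).predictions


def train_input_fn():
    return [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]  # samples of y = 2x


est = ToyEstimator(toy_model_fn).train(train_input_fn, steps=100)
print(est.evaluate(train_input_fn))
print(est.predict(lambda: ([4.0], None)))
```

The model function is written once and the estimator decides which mode to call it in; as I understand it, the real Estimator likewise calls your model_fn with the appropriate mode and only uses the EstimatorSpec fields relevant to that mode.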

Excerpt from the above repository:

import tensorflow as tf  # TensorFlow 1.x API (tf.layers, tf.losses, tf.train, tf.metrics)


def get_logits(image):
    """Get logits from image."""
    x = image
    # Two conv -> ReLU -> max-pool blocks, with 32 and then 64 filters.
    for filters in (32, 64):
        x = tf.layers.conv2d(x, filters, 3)
        x = tf.nn.relu(x)
        x = tf.layers.max_pooling2d(x, 3, 2)
    # Average over the spatial dimensions, then map to 10 class logits.
    x = tf.reduce_mean(x, axis=(1, 2))
    logits = tf.layers.dense(x, 10)
    return logits


def get_estimator_spec(features, labels, mode):
    """
    Get an estimator specification.
    Args:
      features: mnist image batch, float32 tensor of shape
          (batch_size, 28, 28, 1)
      labels: mnist label batch, int32 tensor of shape (batch_size,);
          None in predict ("infer") mode
      mode: one of `tf.estimator.ModeKeys`, i.e. {"train", "eval", "infer"}
    Returns:
      tf.estimator.EstimatorSpec
    """
    if mode not in {"train", "infer", "eval"}:
        raise ValueError('mode should be in {"train", "infer", "eval"}')

    logits = get_logits(features)
    preds = tf.argmax(logits, axis=-1)
    probs = tf.nn.softmax(logits, axis=-1)
    predictions = dict(preds=preds, probs=probs, image=features)

    # Predict mode ("infer" is the value of tf.estimator.ModeKeys.PREDICT):
    # there are no labels, so only the predictions are specified.
    if mode == 'infer':
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Train and eval modes: labels are available, so define the loss.
    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)

    # train_op is used in train mode, eval_metric_ops in eval mode.
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    step = tf.train.get_or_create_global_step()
    train_op = optimizer.minimize(loss, global_step=step)

    accuracy = tf.metrics.accuracy(labels, preds)

    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions,
        loss=loss, train_op=train_op, eval_metric_ops=dict(accuracy=accuracy))


model_dir = '/tmp/mnist_simple'


def get_estimator():
    # get_estimator_spec is the model_fn; checkpoints are written to model_dir.
    return tf.estimator.Estimator(get_estimator_spec, model_dir)

Upvotes: 1
