haroldw

Reputation: 41

How do I get the trainable variable count when using tf.estimator?

I created a CNN classifier using the tf.estimator framework, but I cannot access the variables defined in the model: tf.trainable_variables() always returns an empty list, so the count is 0. How can I access the variables when using tf.estimator? In particular, how can I get the total number of parameters (the sum, over all variables, of the product of each variable's dimensions)?

Thanks, Harold

Upvotes: 4

Views: 1342

Answers (2)

Mockingbird

Reputation: 1031

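As mentioned in the other answer, use est.get_variable_names() and est.get_variable_value(name) to access the variables of your estimator est.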

Once you have the variables, you can use one of the following ways in order to get the total number of estimator parameters.

  • Either multiply the shape dimensions of each variable with numpy.prod (numpy imported as np) and sum the results:

    sum([np.prod(est.get_variable_value(var).shape) for var in est.get_variable_names()])

  • Or read the element count of each variable from numpy.ndarray.size and sum those:

    sum([est.get_variable_value(var).size for var in est.get_variable_names()])
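Either one-liner can be wrapped in a small helper. The following is only a sketch: the function name count_parameters and its skip argument are made up for illustration and are not part of TensorFlow. It assumes the estimator has already been trained, and that you want to leave out bookkeeping variables such as global_step, which get_variable_names() also returns. Which names to skip depends on your model and optimizer, so treat the default as an assumption.

```python
import math


def count_parameters(estimator, skip=("global_step",)):
    """Sum the element counts of an estimator's variables.

    `skip` is a hypothetical convenience: any variable whose name
    contains one of these substrings is left out of the total.
    """
    total = 0
    for name in estimator.get_variable_names():
        if any(token in name for token in skip):
            continue
        # get_variable_value returns an array; its shape is a tuple of ints.
        shape = estimator.get_variable_value(name).shape
        # math.prod(()) is 1, so scalar variables count as one element.
        total += math.prod(shape)
    return total
```

Calling count_parameters(est, skip=()) counts everything that get_variable_names() reports, which matches the two expressions above.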

Upvotes: 2

xingzhang ren

Reputation: 106

You can use get_variable_names() to get all variable names, and get_variable_value(name) to get the value of a variable by name.

For example:

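import tensorflow as tf

estimator = tf.estimator.Estimator(...)  # constructor arguments omitted
params = estimator.get_variable_names()
# Print each variable name together with the shape of its value.
for p in params:
    print(p, estimator.get_variable_value(p).shape)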

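For more information, see https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#get_variable_names and https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#get_variable_value.

Note: the graph must be created first. In practice this means the estimator must have been trained (or have a checkpoint in its model directory) before the variables can be read; otherwise get_variable_names() raises a ValueError.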
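A minimal sketch of guarding against an untrained estimator, assuming (per the API docs linked above) that get_variable_names() raises ValueError when no checkpoint has been produced yet. The helper name describe_variables is hypothetical, not a TensorFlow function.

```python
def describe_variables(estimator):
    """Return {name: shape} for a trained estimator, or None if untrained."""
    try:
        names = estimator.get_variable_names()
    except ValueError:
        # Assumed behaviour: no checkpoint yet, so there is nothing to read.
        return None
    return {
        name: tuple(estimator.get_variable_value(name).shape)
        for name in names
    }
```

If this returns None, train the estimator (or point it at a model directory that already holds a checkpoint) and call it again.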

Upvotes: 1
