Amir Hossein F

Reputation: 420

Getting the weights from TensorFlow's estimator.DNNClassifier

I understand that DNNClassifier is now trained via estimator.DNNClassifier. Previously, it was trained using contrib.learn.DNNClassifier, so we could extract the weights using get_variable_names(). But there is no such method in estimator.DNNClassifier. If contrib.learn is now deprecated, how do we get the weights from the new estimator.DNNClassifier?

Upvotes: 2

Views: 931

Answers (1)

de3

Reputation: 2000

Apparently, the weights are called 'kernels' (learnt from this question).

For example, for:

import tensorflow as tf

# feature_columns and input_fn_train are assumed to be defined elsewhere
estimator = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[2])

estimator.train(input_fn=input_fn_train)

You can use get_variable_value like this:

# Weights ("kernel") and bias of the first hidden layer
print(estimator.get_variable_value("dnn/hiddenlayer_0/kernel"))
print(estimator.get_variable_value("dnn/hiddenlayer_0/bias"))
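To pull out every layer at once rather than one name at a time, here is a minimal sketch. It assumes the estimator exposes get_variable_names() and get_variable_value(name) (both are methods on tf.estimator.Estimator) and that the dense layers follow the "dnn/<layer>/kernel" and "dnn/<layer>/bias" naming shown above (e.g. "dnn/hiddenlayer_0", and typically "dnn/logits" for the output layer). The helper name collect_dense_weights is hypothetical, not part of TensorFlow.

```python
import re


def collect_dense_weights(estimator):
    """Group the kernel and bias values of a trained DNN estimator by layer.

    Hypothetical helper. Assumes `estimator` provides get_variable_names()
    and get_variable_value(name), as tf.estimator.Estimator does, and that
    dense-layer variables are named "dnn/<layer>/kernel" or "dnn/<layer>/bias".
    """
    layers = {}
    for name in estimator.get_variable_names():
        # fullmatch skips anything with extra path parts, such as optimizer
        # slot variables, as well as unrelated names like "global_step".
        match = re.fullmatch(r"dnn/([^/]+)/(kernel|bias)", name)
        if match is None:
            continue
        layer, kind = match.groups()
        layers.setdefault(layer, {})[kind] = estimator.get_variable_value(name)
    return layers
```

With the trained estimator from above, `collect_dense_weights(estimator)["hiddenlayer_0"]["kernel"]` should hold the same array as the first print call. Filtering on the name pattern, rather than reading every variable, keeps optimizer state out of the result.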

Upvotes: 4
