I understand that DNNClassifier is now trained via tf.estimator.DNNClassifier. Previously it was trained using tf.contrib.learn.DNNClassifier, so we could extract the weights using get_variable_names(). But there is no such method in estimator.DNNClassifier. If contrib.learn is now deprecated, how do we get the weights from the new estimator.DNNClassifier?
Apparently, weights are called 'kernels': each layer's weight matrix is stored in a variable named kernel (learnt from this question).
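For example, for:
import tensorflow as tf

# feature_columns and input_fn_train are assumed to be defined as usual
estimator = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[2])  # a single hidden layer with 2 units
estimator.train(input_fn=input_fn_train)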
you can use get_variable_value like this:
# weight matrix and bias vector of the first hidden layer
print(estimator.get_variable_value("dnn/hiddenlayer_0/kernel"))
print(estimator.get_variable_value("dnn/hiddenlayer_0/bias"))
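If you don't know the exact variable names, tf.estimator.Estimator (which DNNClassifier subclasses) also provides get_variable_names(), so you can list every variable in the latest checkpoint and pick out the kernels and biases. Below is a minimal sketch that assumes the dnn/<layer>/<kernel|bias> naming scheme shown above; collect_dnn_weights is a hypothetical helper, not part of TensorFlow, and other TensorFlow versions may name the variables differently.

```python
def collect_dnn_weights(estimator):
    """Group a trained DNN estimator's variables by layer.

    Assumes (hypothetically) that weights are stored as "dnn/<layer>/kernel"
    and "dnn/<layer>/bias". Returns e.g.
    {"hiddenlayer_0": {"kernel": ..., "bias": ...}, "logits": {...}}.
    """
    layers = {}
    for name in estimator.get_variable_names():
        parts = name.split("/")
        # Keep only names of exactly the form dnn/<layer>/<kernel|bias>;
        # this skips global_step and optimizer slot variables such as
        # "dnn/hiddenlayer_0/kernel/t_0/Adagrad", which have more components.
        if len(parts) != 3 or parts[0] != "dnn" or parts[2] not in ("kernel", "bias"):
            continue
        layer, kind = parts[1], parts[2]
        layers.setdefault(layer, {})[kind] = estimator.get_variable_value(name)
    return layers
```

With the estimator trained above, collect_dnn_weights(estimator)["hiddenlayer_0"]["kernel"] should return the same array as estimator.get_variable_value("dnn/hiddenlayer_0/kernel").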