Strandtasche

Reputation: 118

Making predictions with TensorFlow's estimator.DNNRegressor

I am quite new to TensorFlow, and in order to learn it I am trying to implement a very simple DNNRegressor that predicts the movement of an object in 2D, but I can't seem to get the predict function to work.

For this purpose I have some input data: the x and y coordinates of the object at a number of previous time steps. I want the output to be a reasonable estimate of the object's location if it continues to move in the same direction at the same speed.

I am using tensorflow version 1.8.0

My regressor is defined like this:

CSV_COLUMN_NAMES = ['X_0', 'X_1', 'X_2', 'X_3', 'X_4', 'Y_0', 'Y_1', 'Y_2', 'Y_3', 'Y_4', 'Y_5']

my_feature_columns = []
for key in columnNames:
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))


regressor = estimator.DNNRegressor(feature_columns=my_feature_columns,
                                   label_dimension=1,
                                   hidden_units=hidden_layers,
                                   model_dir=MODEL_PATH,
                                   dropout=dropout,
                                   config=test_config)

My input is, like the one in the TensorFlow tutorial on premade estimators, a dict with the column names as keys. An example of this input can be seen here.

regressor.train(arguments) and regressor.evaluate(arguments) seem to work just fine, but predict does not.

Following the code on the TensorFlow site, I tried this:

y_pred = regressor.predict(input_fn=eval_input_fn(X_test, labels=None, batch_size=1))

and it seems like that works as well.

The problem I'm facing now is that I can't get anything out of that y_pred object.

When I run print(y_pred) I get <generator object Estimator.predict at 0x7fd9e8899888>, which suggests to me that I should be able to iterate over it, but

    for elem in y_pred:
        print(elem)

results in TypeError: unsupported callable

Again, I'm quite new to this, and I am sorry if the answer is obvious, but I would be very grateful if someone could tell me what I'm doing wrong here.

Upvotes: 0

Views: 1129

Answers (1)

Vijay Mariappan

Reputation: 17191

The input_fn to regressor.predict should be a function. See the definition:

input_fn: A function that constructs the features.

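You need to pass the function itself, not the result of calling it. Since your eval_input_fn takes arguments, wrap the call in a zero-argument lambda:

y_pred = regressor.predict(input_fn=lambda: eval_input_fn(X_test, labels=None, batch_size=1))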
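
To see why the error only shows up once you iterate, here is a small pure-Python sketch. It does not use TensorFlow: both predict and eval_input_fn below are hypothetical stand-ins written for illustration, assuming only that the real predict is a generator function (which the printed generator object in the question indicates) and that it expects input_fn to be callable. A generator's body does not run until the first item is requested, so the bad argument goes unnoticed at call time.

```python
import functools


def eval_input_fn(features, labels=None, batch_size=1):
    """Hypothetical stand-in: split a dict of feature lists into batches."""
    keys = list(features)
    n = len(features[keys[0]])
    return [{k: features[k][i:i + batch_size] for k in keys}
            for i in range(0, n, batch_size)]


def predict(input_fn):
    """Hypothetical stand-in for a lazy predict; this is a generator function."""
    # None of this body runs until the caller starts iterating.
    if not callable(input_fn):
        raise TypeError("input_fn must be callable")
    for batch in input_fn():
        yield {"predictions": batch}


X_test = {"X_0": [1.0, 2.0], "Y_0": [3.0, 4.0]}

# Wrong: eval_input_fn is called right away and its *result* is passed in.
# Creating the generator succeeds; the error appears on the first next().
broken = predict(input_fn=eval_input_fn(X_test, labels=None, batch_size=1))
try:
    next(broken)
    failed_on_iteration = False
except TypeError:
    failed_on_iteration = True

# Right: pass a zero-argument callable that is invoked later.
working = predict(
    input_fn=lambda: eval_input_fn(X_test, labels=None, batch_size=1))
results = list(working)

# functools.partial is an equivalent way to bind the arguments.
bound_fn = functools.partial(eval_input_fn, X_test, labels=None, batch_size=1)
also_working = list(predict(input_fn=bound_fn))
```

With the real estimator the same pattern should apply: build the generator with a lambda or functools.partial, then loop over it or wrap it in list(...) to pull out the predictions.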

Upvotes: 1
