Rose Perrone

Reputation: 63586

TensorFlow: How to define a one-hot feature column for a canned estimator

My one-hot encoding appears to incorrectly have 3 dimensions during training (I think it should have 2), which causes an OOM. How am I constructing the one-hot feature column incorrectly?

I get this error when I begin to train the neural net:

OOM when allocating tensor with shape[114171,829,829]

[[Node: dnn/input_from_feature_columns/input_layer/air_store_id_indicator/one_hot = OneHot[T=DT_FLOAT, TI=DT_INT64, axis=-1, _device="/job:localhost/replica:0/task:0/gpu:0"](dnn/input_from_feature_columns/input_layer/air_store_id_indicator/SparseToDense/_149, dnn/input_from_feature_columns/input_layer/air_store_id_indicator/one_hot/depth, dnn/input_from_feature_columns/input_layer/air_store_id_indicator/one_hot/on_value, dnn/input_from_feature_columns/input_layer/air_store_id_indicator/one_hot/off_value)]]

I tried to define a one-hot feature column for use in my DNNRegressor as follows:

tf.feature_column.indicator_column(
    tf.feature_column.categorical_column_with_identity(key='id', num_buckets=df_train['id'].unique().size))

In my input_fn to DNNRegressor::fit(), I populate the one-hot encoding like this:

labels, uniques = pd.factorize(df_train['id'])
returned_feature_columns[k] = tf.one_hot(labels, uniques.size, 1, 0)

When I print that one-hot encoding, its dimensions appear correct, because I have 114171 training examples, and 829 unique ids:

Tensor("one_hot:0", shape=(114171, 829), dtype=int32)

Upvotes: 0

Views: 925

Answers (1)

J.E.K

Reputation: 1461

The tensor you defined consumes too much memory. There is a 2GB limit on the tf.GraphDef protocol buffer. You should train your model with smaller batches. The higher-level Estimator API provides a convenient way to build an input_fn from a pandas DataFrame:

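# x_data and num_epochs are assumed to be defined elsewhere.
input_fn = tf.estimator.inputs.pandas_input_fn(
    x=pd.DataFrame({'x': x_data}),
    batch_size=128,  # the default; lower it if memory is still tight
    num_epochs=num_epochs,
    shuffle=True)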

For more details, see the documentation for tf.estimator.inputs.pandas_input_fn.
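The shape in the error message, [114171, 829, 829], also looks consistent with the feature being one-hot encoded twice: once by tf.one_hot in the input_fn and once more by the indicator column, which appears to apply its own one-hot step (the failing node is named air_store_id_indicator/one_hot). The sketch below is plain Python, not TensorFlow. The helpers one_hot and shape are hypothetical stand-ins written only to illustrate the shape arithmetic, and the four sample ids are made up.

```python
def one_hot(ids, depth):
    """One-hot encode a (possibly nested) list of integer ids along a new last axis."""
    if isinstance(ids, list):
        return [one_hot(x, depth) for x in ids]
    return [1 if i == ids else 0 for i in range(depth)]


def shape(x):
    """Return the shape of a regular nested list as a tuple."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)


ids = [0, 2, 1, 2]  # made-up integer ids for 4 examples
depth = 3           # number of unique ids in this toy example

encoded_once = one_hot(ids, depth)             # analogous to the tensor built in input_fn
encoded_twice = one_hot(encoded_once, depth)   # analogous to encoding that tensor again

print(shape(encoded_once))   # 2 dimensions: (examples, ids)
print(shape(encoded_twice))  # 3 dimensions: (examples, ids, ids)

# Size of the tensor from the error message, assuming 4-byte floats (T=DT_FLOAT).
n_examples, n_ids = 114171, 829
full_bytes = n_examples * n_ids * n_ids * 4
print(full_bytes / 1e9, "GB")
```

If that reading is right, the 3-dimensional tensor would need roughly 300 GB as float32, so smaller batches alone may not be enough. One thing to try, under that assumption, is passing the integer labels from pd.factorize as the 'id' feature and letting the indicator column do the encoding.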

Upvotes: 1
