Reputation: 97
I am trying to convert the labels of a tf.data.Dataset to one-hot encoded labels. I am using this dataset. I've added headers (sentiment, text) to the columns; everything else is unchanged.
Here is the code I use to encode the labels (positive, negative, neutral) as one-hot vectors of shape (3,):
```python
import tensorflow as tf

BATCH_SIZE = 64
AUTOTUNE = tf.data.experimental.AUTOTUNE

def _map_func(text, labels):
    # `labels` is a whole batch of label strings
    labels_enc = []
    for label in labels:
        if label == 'negative':
            label = -1
        elif label == 'neutral':
            label = 0
        else:
            label = 1
        label = tf.one_hot(
            label, 3, name='label', axis=-1)
        labels_enc.append(label)
    return text, labels_enc

raw_train_ds = tf.data.experimental.make_csv_dataset(
    './data/sentiment_data/train.csv', BATCH_SIZE, column_names=['sentiment', 'text'],
    label_name='sentiment', header=True
)
train_ds = raw_train_ds.cache().prefetch(buffer_size=AUTOTUNE)
train_ds = train_ds.map(_map_func)
```
I am getting the error: ValueError: Value [<tf.Tensor 'while/label:0' shape=(3,) dtype=float32>] is not convertible to a tensor with dtype <dtype: 'float32'> and shape (1, 3).
The second argument of _map_func(text, labels), labels, has shape (64,) and dtype string.
If I understand TensorFlow's tf.data.Dataset.map correctly, it creates a new dataset by applying the transformation function to each element. But, as the error says, the label column can't be converted from a column of single strings to a column of 3-float lists. Is there any way to force the type of the new column to accept the encoded labels?
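To check what _map_func actually receives, here is a minimal sketch that uses a small in-memory dataset (a hypothetical stand-in for the CSV file) batched by 4, the same way make_csv_dataset batches by BATCH_SIZE. It assumes TensorFlow 2.x with eager execution; report_shape is an illustrative helper, not part of the original code.

```python
import tensorflow as tf

# Hypothetical stand-in for the CSV labels: 8 strings, batched by 4,
# mimicking how make_csv_dataset already batches its output.
labels = tf.constant(['negative', 'neutral', 'positive', 'neutral'] * 2)
ds = tf.data.Dataset.from_tensor_slices(labels).batch(4)

def report_shape(batch):
    # Inside map, `batch` is the whole batch: a string tensor of shape (4,),
    # not a single label.
    return tf.shape(batch)

shapes = [s.numpy().tolist() for s in ds.map(report_shape)]
print(shapes)  # [[4], [4]]
```

Because each element is a whole batch, the for loop in _map_func is traced into a graph while loop, and appending to a Python list inside that loop is most likely what produces the "while/label" conversion error.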
Thanks for the help :)
Reputation: 97
I solved the issue by using a TensorFlow TensorArray like so:
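```python
import tensorflow as tf

def _map_func(text, labels):
    i = 0
    labels_enc = tf.TensorArray(tf.float32, size=0, dynamic_size=True,
                                clear_after_read=False)
    for label in labels:
        # tf.one_hot expects indices in [0, depth); an index of -1 gives an all-zero row
        if label == 'negative':
            idx = tf.constant(0)
        elif label == 'neutral':
            idx = tf.constant(1)
        else:
            idx = tf.constant(2)
        one_hot = tf.one_hot(idx, 3, name='label', axis=-1)
        # write() returns a new TensorArray, so the result must be reassigned
        labels_enc = labels_enc.write(i, one_hot)
        i = i + 1
    # stack() gives shape (batch, 3); concat() would flatten to (batch * 3,)
    return text, labels_enc.stack()
```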
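The difference between concat() and stack() is easy to see eagerly. A minimal sketch, assuming TensorFlow 2.x eager mode; the three label indices are made-up sample data:

```python
import tensorflow as tf

ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True,
                    clear_after_read=False)
for i, idx in enumerate([0, 2, 1]):
    # Keep the TensorArray returned by write(); this is required in graph mode
    ta = ta.write(i, tf.one_hot(idx, 3))

stacked = ta.stack()         # shape (3, 3): one one-hot row per label
concatenated = ta.concat()   # shape (9,): rows joined end to end
print(stacked.shape, concatenated.shape)
```

For a Keras model expecting one label vector per sample, the (batch, 3) layout from stack() is the one that lines up with the batch dimension.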
Reputation: 36604
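The mapping function is applied to each element of the dataset, so once each element is a single sample you don't need to create a list and loop through the batch items. Note that make_csv_dataset already batches its output (hence the (64,) label shape), so either call unbatch() on the dataset before this map and batch it again afterwards, or use vectorized ops. For one sample: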
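```python
import tensorflow as tf

def _map_func(text, label):
    # `label` is a single scalar string here, so if/elif works under AutoGraph.
    # Indices must be in [0, 3) for tf.one_hot; -1 would give an all-zero row.
    if label == 'negative':
        idx = 0
    elif label == 'neutral':
        idx = 1
    else:
        idx = 2
    label = tf.one_hot(idx, 3, name='label', axis=-1)
    return text, label
```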
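A vectorized alternative avoids per-sample Python branching entirely. This is a different technique from the answer above: a tf.lookup.StaticHashTable that maps class names to indices. It is sketched under the assumption of TensorFlow 2.x and the class order negative/neutral/positive; CLASS_NAMES, encode and the sample texts are hypothetical. Because lookup and one_hot work element-wise, the same function handles a single label or a batch of 64, so it can be mapped directly over the batched make_csv_dataset output without unbatching.

```python
import tensorflow as tf

# Hypothetical class order: index 0 = negative, 1 = neutral, 2 = positive.
CLASS_NAMES = ['negative', 'neutral', 'positive']

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(CLASS_NAMES),
        values=tf.constant([0, 1, 2], dtype=tf.int64)),
    default_value=-1)  # unknown strings -> -1 -> an all-zero one-hot row

def encode(text, label):
    # Works for a scalar label or a whole batch: no Python loop needed.
    return text, tf.one_hot(table.lookup(label), depth=len(CLASS_NAMES))

# Stand-in for make_csv_dataset output: (features dict, string labels), batched.
texts = tf.constant(['bad movie', 'it was ok', 'loved it', 'meh'])
labels = tf.constant(['negative', 'neutral', 'positive', 'neutral'])
ds = tf.data.Dataset.from_tensor_slices(({'text': texts}, labels)).batch(2)

encoded = ds.map(encode)
first_labels = next(iter(encoded))[1].numpy().tolist()
print(first_labels)  # [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
```

Unlike the if/else version, an unexpected label string here becomes an all-zero row instead of silently being treated as positive.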