AI-Lottery Winner

Reputation: 65

How to load the TensorFlow "iris" dataset and one-hot encode the labels

I'm trying to load the "iris" dataset directly from TensorFlow Datasets and I'm stuck. I'm used to working with CSVs.

import tensorflow as tf
import tensorflow_datasets as tfds

data = tfds.load("iris", split='train[:80%]', as_supervised=True)
data = data.batch(10)
features, labels = data
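# This last line fails: a batched tf.data.Dataset can't be unpacked into (features, labels) like a tuple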

I don't know how I'm supposed to separate the features and labels (X, y). The labels are in a different tensor from the features, but I don't know how to access them. I'd like to one-hot encode the labels and feed them into the model, but I'm stuck here.

The TensorFlow docs are sparse on how to do this. Any help is much appreciated.

Upvotes: 5

Views: 3650

Answers (1)

Nicolas Gervais

Reputation: 36604

You can one-hot encode your labels inside the .map() method with tf.one_hot, like this:

data = data.batch(10).map(lambda x, y: (x, tf.one_hot(y, depth=3)))

print(next(iter(data))[1])
<tf.Tensor: shape=(10, 3), dtype=float32, numpy=
array([[1., 0., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.]], dtype=float32)>
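
If you then want to look at the features and labels separately (the X, y from the question), you can iterate over the dataset: each element is one (features, labels) batch. A minimal sketch of that, continuing from the pipeline above:

# Each element of the mapped dataset is one batch of (features, one-hot labels)
for features, labels in data.take(1):
    print(features.shape)  # (10, 4) -- four iris measurements per row
    print(labels.shape)    # (10, 3) -- three classes, one-hot encoded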

Fully-working minimal example:

import tensorflow as tf
import tensorflow_datasets as tfds

data = tfds.load("iris", split='train[:80%]', as_supervised=True)
data = data.batch(10).map(lambda x, y: (x, tf.one_hot(y, depth=3))).repeat()

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax')
])

model.compile(loss='categorical_crossentropy', optimizer='adam', 
    metrics=['categorical_accuracy'])

history = model.fit(data, steps_per_epoch=8, epochs=10)
Epoch 10/10
1/8 [==>...........................] - ETA: 0s - loss: 0.8848 - cat_acc: 0.6000
8/8 [==============================] - 0s 4ms/step - loss: 0.8549 - cat_acc: 0.5250
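
If you also want to hold out the remaining 20% of the split for validation, a sketch along the same lines (the 'train[80%:]' split string is an assumption mirroring the one above, not something from the original example):

# Remaining 20% of the split, prepared the same way as the training data
val_data = tfds.load("iris", split='train[80%:]', as_supervised=True)
val_data = val_data.batch(10).map(lambda x, y: (x, tf.one_hot(y, depth=3)))

history = model.fit(data, validation_data=val_data, steps_per_epoch=8, epochs=10)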

Upvotes: 8
