Saran

Reputation: 1844

tf.data filter dataset using label predicate

I am trying to filter the CIFAR-10 training and test data by specific labels, as shown below:

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

Dataset

dataset = datasets.cifar10.load_data()

Split the dataset

train_data = tf.data.Dataset.from_tensor_slices((dataset[0][0],dataset[0][1]))
test_data = tf.data.Dataset.from_tensor_slices((dataset[1][0],dataset[1][1]))
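Each element of `train_data` is an `(image, label)` tuple. Because `cifar10.load_data()` returns labels with shape `(N, 1)`, each label is a length-1 tensor rather than a scalar. A minimal sketch for inspecting that structure, assuming small stand-in arrays (hypothetical, used only to skip the download) with the same shapes and dtypes as CIFAR-10:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in arrays shaped like CIFAR-10:
# images (N, 32, 32, 3) uint8, labels (N, 1) uint8.
x_train = np.zeros((8, 32, 32, 3), dtype=np.uint8)
y_train = np.arange(8, dtype=np.uint8).reshape(-1, 1)

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))

# element_spec describes ONE element: a tuple (image spec, label spec).
image_spec, label_spec = train_data.element_spec
print(image_spec.shape.as_list(), label_spec.shape.as_list())  # [32, 32, 3] [1]
```

The `(1,)` label shape is why the check below indexes `[0][0]`, and why comparing a label with `[0, 1, 2]` broadcasts to a length-3 result.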

Filter function

def filter_f(datas,filter_labels = tf.constant([0,1,2])):
  x = tf.not_equal(datas[1],filter_labels)
  x = tf.reduce_sum(tf.cast(x, tf.uint8))
  return tf.greater(x, tf.constant(0,tf.uint8))

dataset = train_data.filter(filter_f).batch(200)
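Apart from how `filter` binds the arguments (each `(image, label)` element is unpacked into two positional arguments, so `datas` receives the image and `filter_labels` receives the label), the counting test itself cannot exclude anything: `reduce_sum(not_equal(label, [0, 1, 2])) > 0` is true whenever the label differs from at least one of the three values, and every label does. A small NumPy sketch of just that logic (the helper names are illustrative, not from the question):

```python
import numpy as np

EXCLUDED = np.array([0, 1, 2])

def keep_original(label):
    # Same test as filter_f: count mismatches, keep the sample if the count is > 0.
    return int(np.sum(np.not_equal(label, EXCLUDED))) > 0

def keep_excluding(label):
    # Intended test: keep the sample only if it differs from EVERY excluded label.
    return bool(np.all(np.not_equal(label, EXCLUDED)))

# CIFAR-10 labels are length-1 arrays, e.g. np.array([0]).
all_labels = [np.array([c]) for c in range(10)]
kept_original = [int(l[0]) for l in all_labels if keep_original(l)]
kept_excluding = [int(l[0]) for l in all_labels if keep_excluding(l)]
print(kept_original)   # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(kept_excluding)  # [3, 4, 5, 6, 7, 8, 9]
```

Replacing the sum with an "all" test (`tf.reduce_all` in TensorFlow) gives the intended exclusion.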

I based `filter_f` on a similar issue. However, the filtered dataset still contains the unfiltered data, as this check shows:

labels = []
for images, batch_labels in tfds.as_numpy(dataset):
    labels.append(batch_labels[0][0])  # first label of each batch
print(labels)

Output (the first label of each batch):

[4, 7, 5, 6, 0, 5, 5, 6, 5, 3, 6, 7, 0, 0, 6, 3]
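The loop above samples only the first label of each batch. A sketch that collects every label and also records what `filter` actually passes to the original `filter_f`, assuming hypothetical stand-in data (three samples per class, every pixel set to 255 so that no pixel can equal a label) and using `Dataset.as_numpy_iterator()` in place of `tfds.as_numpy`:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for CIFAR-10: 3 samples of each class 0..9, labels shaped (N, 1).
# All pixels are 255, so no pixel value can ever equal a label (0..9).
images = np.full((30, 32, 32, 3), 255, dtype=np.uint8)
targets = np.tile(np.arange(10, dtype=np.uint8), 3).reshape(-1, 1)
train_data = tf.data.Dataset.from_tensor_slices((images, targets))

seen_shapes = []  # filled while tf.data traces the predicate

def filter_f(datas, filter_labels=tf.constant([0, 1, 2])):
    # Record the static shapes of the two arguments tf.data passes in.
    seen_shapes.append((datas.shape.as_list(), filter_labels.shape.as_list()))
    x = tf.not_equal(datas[1], filter_labels)  # compares a pixel row with the label
    x = tf.reduce_sum(tf.cast(x, tf.uint8))
    return tf.greater(x, tf.constant(0, tf.uint8))

dataset = train_data.filter(filter_f).batch(200)

# Collect every label in the filtered dataset, not just the first of each batch.
labels = np.concatenate([b[:, 0] for _, b in dataset.as_numpy_iterator()])
print(seen_shapes[0])                          # ([32, 32, 3], [1]): image and label
print(len(labels), sorted(set(labels.tolist())))  # 30 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

With real images a whole pixel row almost never equals the label, so in practice nearly every sample passes, which matches the unfiltered output above.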

To reproduce the result, please use this colab link

Upvotes: 0

Views: 3331

Answers (1)

thushv89

Reputation: 11333

I'm not sure what the exact underlying issue is. Nevertheless, if you just need to remove the samples belonging to specific classes, you can use the following:

dataset = train_data.filter(lambda x,y: tf.reduce_all(tf.not_equal(y, [0,1,2]))).batch(200)
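A sketch applying this predicate to both the training and the test split, as the question asks, and checking every remaining label. The stand-in arrays replace `cifar10.load_data()`, the names `drop_labels`, `keep_labels` and `all_labels` are illustrative, `keep_labels` shows the opposite selection for comparison, and `Dataset.as_numpy_iterator()` is used instead of `tfds.as_numpy`:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for the CIFAR-10 train/test splits (labels shaped (N, 1)).
y_train = np.tile(np.arange(10, dtype=np.uint8), 5).reshape(-1, 1)
y_test = np.tile(np.arange(10, dtype=np.uint8), 2).reshape(-1, 1)
x_train = np.zeros((len(y_train), 32, 32, 3), dtype=np.uint8)
x_test = np.zeros((len(y_test), 32, 32, 3), dtype=np.uint8)

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))

def drop_labels(x, y):
    # y has shape (1,); comparing with [0, 1, 2] broadcasts to shape (3,).
    # Keep the sample only if y differs from all three labels.
    return tf.reduce_all(tf.not_equal(y, [0, 1, 2]))

def keep_labels(x, y):
    # Opposite selection: keep only samples whose label is 0, 1 or 2.
    return tf.reduce_any(tf.equal(y, [0, 1, 2]))

def all_labels(ds):
    # Gather every label of a batched (image, label) dataset into one flat array.
    return np.concatenate([labels[:, 0] for _, labels in ds.as_numpy_iterator()])

train_kept = all_labels(train_data.filter(drop_labels).batch(200))
test_kept = all_labels(test_data.filter(drop_labels).batch(200))
train_only = all_labels(train_data.filter(keep_labels).batch(200))
print(sorted(set(train_kept.tolist())))  # [3, 4, 5, 6, 7, 8, 9]
print(sorted(set(train_only.tolist())))  # [0, 1, 2]
```

The Python list `[0, 1, 2]` is converted to the label's dtype (`uint8`). A typed constant such as `tf.constant([0, 1, 2])` would be `int32` and fail to compare with the `uint8` labels.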

Upvotes: 4
