yanachen

Reputation: 3753

How to filter a dataset when reading it in TensorFlow?

import tensorflow as tf

# read every CSV shard into batches of (features, label)
ds_train = tf.data.experimental.make_csv_dataset(
    file_pattern="./df_profile_seq_fill_csv/*.csv",
    batch_size=batch_size, column_names=use_cols, label_name='label',
    select_columns=select_cols,
    num_parallel_reads=30,
    shuffle_buffer_size=10000)

I read the data from CSV files. The label column holds integer class labels such as 0, 1, 2, ...

model.fit(ds_train, validation_data=ds_test, steps_per_epoch=10000,
          verbose=1, epochs=1000000)

I want to filter out all samples where label == 0, in both ds_train and ds_test. Is there a way to do this? Thanks.

Upvotes: 0

Views: 441

Answers (1)

Proko

Reputation: 2011

One way to do it is to first create the dataset from the CSV with batch_size=1 (batch_size is a required argument). Each "batch" is then a single example. You can filter those and then re-batch:

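import tensorflow as tf

class_number_to_get_rid_of = 0
TRAIN_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/train.csv"

train_file_path = tf.keras.utils.get_file("train.csv", TRAIN_DATA_URL)
# batch_size=1: every element is one example, and each feature has shape [1]
dataset = tf.data.experimental.make_csv_dataset(train_file_path, batch_size=1)
# keep only examples whose 'survived' value is not the unwanted class
dataset_filtered = dataset.filter(
    lambda p: tf.reduce_all(tf.not_equal(p['survived'], [class_number_to_get_rid_of])))
# unbatch() drops the size-1 dimension, so the new batches have shape [5] rather than [5, 1]
dataset_filtered = dataset_filtered.unbatch().batch(5)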
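The same idea applies to the question's setup without the batch_size=1 step. With label_name='label', make_csv_dataset yields (features, label) batches. You can unbatch() them, filter on the label, and batch() again.

The sketch below is illustrative and rests on a few assumptions:
- It writes a small hypothetical CSV (columns x and label) to a temp directory, so it runs offline.
- num_epochs=1 and shuffle=False are set only to make the output finite and deterministic.
- drop_label is a hypothetical helper name. You would apply it to both ds_train and ds_test.

```python
import csv
import os
import tempfile

import tensorflow as tf


def write_demo_csv(path):
    # Hypothetical toy data standing in for the question's CSV files.
    rows = [(0.1, 0), (0.2, 1), (0.3, 2), (0.4, 0), (0.5, 1), (0.6, 2), (0.7, 0)]
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["x", "label"])
        writer.writerows(rows)


def drop_label(dataset, label_to_drop, batch_size):
    # With label_name set, make_csv_dataset yields (features_dict, labels) batches.
    # unbatch() splits them into single (features, label) examples,
    # filter() keeps the examples whose label differs from label_to_drop,
    # and batch() regroups the remaining examples.
    return (dataset
            .unbatch()
            .filter(lambda features, label: tf.not_equal(label, label_to_drop))
            .batch(batch_size))


tmp_dir = tempfile.mkdtemp()
csv_path = os.path.join(tmp_dir, "demo.csv")
write_demo_csv(csv_path)

ds = tf.data.experimental.make_csv_dataset(
    csv_path,
    batch_size=4,
    label_name="label",
    num_epochs=1,   # finite, so the loop below terminates
    shuffle=False,  # deterministic order for the demo
)
ds_filtered = drop_label(ds, label_to_drop=0, batch_size=4)

kept_labels = []
for features, labels in ds_filtered:
    kept_labels.extend(labels.numpy().tolist())
print(kept_labels)  # → [1, 2, 1, 2]
```

By default make_csv_dataset repeats indefinitely (num_epochs=None). A filtered training set built that way therefore keeps producing batches for a fixed steps_per_epoch.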

Upvotes: 1
