Reputation: 3753
ds_train = tf.data.experimental.make_csv_dataset(
    file_pattern="./df_profile_seq_fill_csv/*.csv",
    batch_size=batch_size, column_names=use_cols, label_name='label',
    select_columns=select_cols,
    num_parallel_reads=30,
    shuffle_buffer_size=10000)
I read the data from CSV files, where the label column holds integer class labels such as 0, 1, 2, ...
model.fit(ds_train, validation_data=ds_test, steps_per_epoch=10000,
          verbose=1,
          epochs=1000000)
I want to filter out all the samples where label == 0, both for ds_train and ds_test.
Is there a way to do this? Thanks.
Upvotes: 0
Views: 441
Reputation: 2011
One way to do it is to first create the dataset from the CSV with a batch size of 1 (batch_size is a required argument). Then filter the "batches", which are now single examples, and re-batch:
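import tensorflow as tf

class_number_to_get_rid_of = 0
TRAIN_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/train.csv"
train_file_path = tf.keras.utils.get_file("train.csv", TRAIN_DATA_URL)
# batch_size=1 makes every "batch" a single example, so the filter acts per row
dataset = tf.data.experimental.make_csv_dataset(train_file_path, batch_size=1)
# Keep only the examples whose 'survived' value differs from the unwanted class
dataset_filtered = dataset.filter(lambda p: tf.reduce_all(tf.not_equal(p['survived'], [class_number_to_get_rid_of])))
# Re-batch the remaining examples to the desired batch size
dataset_filtered = dataset_filtered.batch(5)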
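Since the question passes label_name='label', the dataset there yields (features, label) pairs rather than a single dict, so the filter function needs two arguments. Below is a minimal sketch of that variant, which should also apply to ds_test. It uses unbatch() instead of batch_size=1, which is my own substitution, and the helper name drop_label and the toy CSV are hypothetical, made up only so the snippet runs without the real files:

```python
import csv
import os
import tempfile

import tensorflow as tf


def drop_label(dataset, label_to_drop, batch_size):
    """Remove every example whose label equals label_to_drop, then re-batch.

    Hypothetical helper: expects a batched dataset of (features, label) pairs,
    which is what make_csv_dataset yields when label_name is set.
    """
    return (
        dataset.unbatch()  # split batches back into single (features, label) examples
        .filter(lambda features, label: tf.not_equal(label, label_to_drop))
        .batch(batch_size)
    )


# Toy CSV standing in for the real files (an assumption for illustration only).
tmp_dir = tempfile.mkdtemp()
csv_path = os.path.join(tmp_dir, "toy.csv")
with open(csv_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["x", "label"])
    for i in range(12):
        writer.writerow([float(i), i % 3])  # labels cycle through 0, 1, 2

# num_epochs=1 keeps the toy dataset finite; the default (None) repeats forever.
ds = tf.data.experimental.make_csv_dataset(
    csv_path, batch_size=4, label_name="label", num_epochs=1, shuffle=False
)
ds_filtered = drop_label(ds, label_to_drop=0, batch_size=4)

# Collect the labels that survive the filter.
remaining_labels = [int(l) for _, labels in ds_filtered for l in labels.numpy()]
```

With unbatch() the re-batched tensors keep their usual shape of (batch_size,), whereas filtering with batch_size=1 and then calling batch(5) leaves an extra dimension of size 1 on each column.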
Upvotes: 1