M.K.

Reputation: 11

How to split a dataset by label, keeping each set of data together, in PyTorch

I have a dataset of apple images and their sugar levels. I took 6 photos of each apple, so each apple has 6 photos and one sugar level.

I want to split the dataset into train and validation sets so that all 6 photos of a given apple go into the same set, either train or validation. I don't know how to split it that way.

This is the CSV file for the dataset (screenshot not shown). The apple column is the label.

Thank you in advance!

Upvotes: 0

Views: 1180

Answers (1)

jhso

Reputation: 3283

You could find the unique apple IDs and split on those instead. The resulting train and validation DataFrames can then be passed into a dataset class, so the split falls between apples rather than randomly across the rows of the DataFrame, which is the standard approach.

import pandas as pd

apple_df = pd.read_csv(...)
# Unique apple IDs as a Series (unique() returns a NumPy array, which has no .sample())
apple_ids = apple_df['apple'].drop_duplicates()
apple_ids = apple_ids.sample(frac=1)  # shuffle
train_val_split = int(0.9 * len(apple_ids))  # 90% of apples go to train
train_apple_ids = apple_ids.iloc[:train_val_split]
val_apple_ids = apple_ids.iloc[train_val_split:]

import torch

class apple_dset(torch.utils.data.Dataset):
    def __init__(self, df):
        super().__init__()
        self.df = df

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        apple = self.df.iloc[idx]  # one row corresponds to one photo
        # do loading...
        return img, label

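# Keep only the rows whose apple ID belongs to each subset
# (pass the IDs directly to isin; wrapping them in another list breaks the match)
train_apple_df = apple_df.loc[apple_df['apple'].isin(train_apple_ids)]
val_apple_df = apple_df.loc[apple_df['apple'].isin(val_apple_ids)]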

train_apple_ds = apple_dset(train_apple_df)
val_apple_ds = apple_dset(val_apple_df)
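
To sanity-check the idea without the real CSV, here is a minimal sketch of the same group-wise split on made-up data. The helper name split_by_group, the image and sugar column names, and the toy values are all assumptions for illustration; only the apple column comes from the question. It fixes a seed so the shuffle is reproducible and asserts that no apple ends up in both subsets.

```python
import pandas as pd


def split_by_group(df, group_col, train_frac=0.9, seed=0):
    """Split df so all rows sharing a group_col value land in the same subset.

    Hypothetical helper, not part of pandas or PyTorch.
    """
    # One entry per group, shuffled reproducibly
    ids = df[group_col].drop_duplicates().sample(frac=1, random_state=seed)
    n_train = int(train_frac * len(ids))
    train_ids = ids.iloc[:n_train]
    val_ids = ids.iloc[n_train:]
    # Select rows by group membership, not by row position
    train_df = df[df[group_col].isin(train_ids)].reset_index(drop=True)
    val_df = df[df[group_col].isin(val_ids)].reset_index(drop=True)
    return train_df, val_df


# Toy data: 10 apples with 6 photos each (file names and sugar values are invented)
toy_df = pd.DataFrame({
    "apple": [a for a in range(10) for _ in range(6)],
    "image": [f"apple{a}_{i}.jpg" for a in range(10) for i in range(6)],
    "sugar": [10.0 + a for a in range(10) for _ in range(6)],
})

train_df, val_df = split_by_group(toy_df, "apple", train_frac=0.8, seed=42)

# No apple should appear on both sides of the split
assert set(train_df["apple"]).isdisjoint(set(val_df["apple"]))
print(len(train_df), len(val_df))
```

The two returned DataFrames can be passed to a dataset class in the same way as train_apple_df and val_apple_df above. If scikit-learn is available, its GroupShuffleSplit class is another way to get a grouped split.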
 
  

Upvotes: 1
