Charlie Parker

Reputation: 5301

How does one create a data set in pytorch and save it into a file to later be used?

I want to extract data from CIFAR-10 in a specific order according to some criterion f(Image, label) (for the sake of an example, let's say f(Image, label) simply computes the sum of all the pixels in Image). Then I want to generate one file for the train set and one file for the test set that I can later load in a DataLoader to use for training a neural net.

How do I do this? My current idea is simply to loop through the data with a DataLoader (shuffle off), remember the index and score of each image, sort the indices by score, then loop through everything again in that order, build one giant numpy array, and save it. After saving, I'd wrap the arrays with torch.utils.data.TensorDataset(X_train, Y_train) and feed that to a DataLoader.
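Concretely, something like this is what I have in mind (a rough sketch, assuming torchvision's CIFAR10; the file names and the pixel-sum criterion are just placeholders):

    import numpy as np
    import torch
    from torchvision import datasets, transforms

    # Score every CIFAR-10 training image with f(Image, label),
    # sort by that score, and dump the sorted tensors to disk.
    train_set = datasets.CIFAR10(root='./data', train=True, download=True,
                                 transform=transforms.ToTensor())

    scores = []
    for idx in range(len(train_set)):
        img, label = train_set[idx]      # img is a 3x32x32 tensor in [0, 1]
        scores.append(img.sum().item())  # f(Image, label) = sum of all pixels

    order = np.argsort(scores)           # indices sorted by score

    # Second pass: stack everything in sorted order and save.
    X_train = torch.stack([train_set[i][0] for i in order])
    Y_train = torch.tensor([train_set[i][1] for i in order])
    np.save('cifar10_train_sorted_X.npy', X_train.numpy(), allow_pickle=False)
    np.save('cifar10_train_sorted_Y.npy', Y_train.numpy(), allow_pickle=False)

The same loop with train=False would produce the test file.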

I think this should at least work for a small dataset like CIFAR-10, right?

Another very important thing for me: I also want to train on only the first K images (since they're already sorted, the first K have a special meaning I want to keep), so respecting the order while training on only that fraction is important.
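For that part I was thinking torch.utils.data.Subset could work once the data is sorted (a sketch; K is a placeholder and X_train / Y_train are the sorted tensors from above):

    import torch
    from torch.utils.data import DataLoader, Subset, TensorDataset

    K = 1000                                # train only on the first K images
    sorted_set = TensorDataset(X_train, Y_train)
    first_k = Subset(sorted_set, range(K))  # preserves the sorted order

    loader = DataLoader(first_k, batch_size=64, shuffle=True)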


https://discuss.pytorch.org/t/how-does-one-create-a-data-set-in-pytorch-and-save-it-into-a-file-to-later-be-used/16742

Upvotes: 2

Views: 13895

Answers (1)

mai

Reputation: 11

The simplest way to save this is to just read everything into an array and then call numpy.save('file.npy', data, allow_pickle=False); to load it back, use data = numpy.load('file.npy'). (numpy.save appends the .npy extension if it's missing, so pass the same full name to numpy.load.)

Remember to set the batch size to 1 and convert every tensor with tensor.numpy() (or .cpu().numpy() if it lives on the GPU).
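For example, a minimal dump loop (the dataset variable and file names are placeholders):

    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    # Walk the dataset one sample at a time and collect numpy copies.
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    images, labels = [], []
    for img, label in loader:
        images.append(img.squeeze(0).numpy())  # drop the batch dimension
        labels.append(label.item())

    np.save('images.npy', np.stack(images), allow_pickle=False)
    np.save('labels.npy', np.array(labels), allow_pickle=False)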

Once you've done this, it's fairly simple to rebuild a dataset from the saved arrays and point a new DataLoader at it.
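Something along these lines (assuming the arrays saved above):

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Load the saved arrays and wrap them back into a DataLoader.
    X = torch.from_numpy(np.load('images.npy'))
    Y = torch.from_numpy(np.load('labels.npy'))

    loader = DataLoader(TensorDataset(X, Y), batch_size=64, shuffle=False)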

Use numpy.load('file.npy', mmap_mode='r') if you need to get at the data without loading it all into RAM (helps with those pesky 600 GB datasets).
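For instance (placeholder file name):

    import numpy as np

    # Memory-mapped: slices are read from disk on demand,
    # nothing is pulled into RAM until you actually index into it.
    data = np.load('images.npy', mmap_mode='r')
    batch = np.array(data[0:64])  # copy just one batch into memory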

For those asking why numpy.save() whole datasets: sometimes your data needs post-processing, batching, and reshaping, and that takes a lot of CPU time. You don't want to spend that CPU time re-crunching your data every run before you send it to your model.

The next step up is to start using databases and servers, which do the same thing, but better and with more SQL!

In terms of using only a slice of the data, it's just a matter of reloading your dataset with dataset(data[:k], labels[:k]) instead of dataset(data, labels) (note data[:k], not data[k:], since you want the first k samples).
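Putting it together (a sketch; k and the file names are placeholders):

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    data = np.load('images.npy', mmap_mode='r')
    labels = np.load('labels.npy')

    k = 1000  # keep only the first k (already sorted) samples
    # np.array(...) copies the memmap slice into RAM so torch can wrap it.
    subset = TensorDataset(torch.from_numpy(np.array(data[:k])),
                           torch.from_numpy(labels[:k]))
    loader = DataLoader(subset, batch_size=64, shuffle=True)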

Upvotes: 1
