Reputation: 13
I have a PyTorch DataLoader and want to retrieve the Dataset object it wraps. Is this possible, and if so, how? Or does a Dataset object only exist for the pre-packaged datasets that ship with torch/torchvision?
The end goal is to easily plug data that I currently have as a DataLoader into code that was written to consume a Dataset (e.g. CIFAR10).
The original code looks like this:
from torchvision import transforms, datasets
from typing import *
import torch
import os
from torch.utils.data import Dataset
def get_dataset(dataset: str, split: str) -> Dataset:
    """Return the dataset called *dataset* for the requested *split*.

    Args:
        dataset: Name of the dataset; only "CIFAR10" is supported here.
        split: Which split to load, forwarded unchanged to the loader.

    Raises:
        ValueError: if *dataset* is not a supported name (previously the
            function fell through and silently returned None).
    """
    if dataset == "CIFAR10":
        return _cifar10(split)
    raise ValueError(f"Unknown dataset: {dataset!r}")
def _cifar10(split: str) -> Dataset:
    """Return the CIFAR-10 ``split`` ("train" or "test"), downloading it if needed.

    The files are cached under ``./dataset_cache``.

    Raises:
        ValueError: if *split* is neither "train" nor "test" (previously any
            split other than "train" silently returned None).
    """
    if split not in ("train", "test"):
        raise ValueError(f"Unknown CIFAR10 split: {split!r}")
    # torchvision's CIFAR10 selects the split with a boolean flag.
    return datasets.CIFAR10("./dataset_cache", train=(split == "train"), download=True)
# The split is the string "train"; the bare name `train` was an undefined
# variable and raised NameError.
dataset = get_dataset("CIFAR10", "train")
for i in range(len(dataset)):
    ...  # per-sample processing elided in the original
I have tried importing the whole dataset at once:
from torchvision import transforms, datasets
from typing import *
import torch
import os
from torch.utils.data import Dataset
def get_dataset(dataset: str, split: str) -> Dataset:
    """Return the dataset called *dataset* for the requested *split*.

    Args:
        dataset: Either "CIFAR10" or "mydataset".
        split: Which split to load, forwarded unchanged to the loader.

    Raises:
        ValueError: if *dataset* is not a supported name (previously the
            function fell through and silently returned None).
    """
    if dataset == "CIFAR10":
        return _cifar10(split)
    if dataset == "mydataset":
        return _mydataset(split)
    raise ValueError(f"Unknown dataset: {dataset!r}")
def _mydataset(split: str) -> Dataset:
    """Return the image-folder dataset stored in ``<dataset_directory>/<split>``.

    The previous version wrapped the ``ImageFolder`` in a ``DataLoader``
    (with a batch size equal to the total image count). A ``DataLoader`` is
    not subscriptable, so ``dataset[i]`` failed, and the result contradicted
    the ``-> Dataset`` annotation. ``ImageFolder`` is already a ``Dataset``
    supporting ``len()`` and indexing, so it is returned directly. Given an
    existing loader, the same object is reachable as ``loader.dataset``.

    The previous version also only assigned its result for ``split == "train"``,
    so any other split raised UnboundLocalError; the split now selects the
    subdirectory.
    """
    # NOTE(review): the original mixed `database_directory` (for counting) and
    # `dataset_directory` (for loading); assumed to be the same root -- confirm.
    return datasets.ImageFolder(os.path.join(dataset_directory, split))
# The split is the string "train"; the bare name `train` was an undefined
# variable and raised NameError.
dataset = get_dataset("mydataset", "train")
for i in range(len(dataset)):
    ...  # per-sample processing elided in the original
But this raises `TypeError: 'DataLoader' object is not subscriptable`.
Upvotes: 1
Views: 3607
Reputation: 40738
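You can access the `dataset` attribute of a `torch.utils.data.DataLoader` to get the underlying `torch.utils.data.Dataset` object it wraps, as the DataLoader source code shows. For example, if `loader = DataLoader(ImageFolder(path), batch_size=...)`, then `loader.dataset` is that `ImageFolder`, which supports `len()` and indexing.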
Upvotes: 3