Almog Levi

Reputation: 161

How to get the filename of a sample from a DataLoader?

I need to write a file with the results of testing a Convolutional Neural Network that I trained. The data is a speech data collection. The file format needs to be "file name, prediction", but I am having a hard time extracting the file name. I load the data like this:

import os

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

TEST_DATA_PATH = ...

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_dataset = torchvision.datasets.MNIST(
    root=TEST_DATA_PATH,
    train=False,
    transform=trans,
    download=True
)

test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

and I am trying to write to the file as follows:

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        file = os.listdir(TEST_DATA_PATH + "/all")[i]
        format = file + ", " + str(predicted.item()) + '\n'
        f.write(format)
f.close()

The problem with os.listdir(TEST_DATA_PATH + "/all")[i] is that it is not synchronized with the order in which test_loader loads the samples. What can I do?

Upvotes: 15

Views: 25655

Answers (3)

Mai Hai

Reputation: 1350

If you are using PyCharm or any IDE with a debugging tool, use it to take a look inside your data_loader; hopefully you can see a list of filenames, as in my case.

In my case, my data_loader was created by mmsegmentation. (Screenshot: data_loader created by mmsegmentation.)
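
If you prefer not to use the debugger, the same inspection can be done in code. A minimal sketch that just lists the attributes of the wrapped dataset (which attribute, if any, holds the filenames depends entirely on the library that built it):

# List what the wrapped dataset stores; one of these attributes usually
# holds the file names (the exact name depends on the library).
print(type(test_loader.dataset))
for name, value in vars(test_loader.dataset).items():
    print(name, type(value))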

Upvotes: 2

prosti

Reputation: 46439

In the general case, a DataLoader is there to provide you with batches from the Dataset(s) it wraps.

As @Berriel mentioned, in the case of single/multi-label classification problems the DataLoader doesn't have the image file names, just the tensors representing the images and the classes/labels.

However, the Dataset you hand to the DataLoader constructor can carry whatever you like: together with the images you may pack the targets/labels and the file names, or even a dataframe.

This way, the DataLoader can hand you back exactly what you need.
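
A minimal sketch of that idea, assuming you keep a dataframe with "filename" and "label" columns and supply your own loading function (DataFrameDataset and load_fn are illustrative names, not library APIs):

import os
from torch.utils.data import Dataset

class DataFrameDataset(Dataset):
    # df is assumed to have "filename" and "label" columns; load_fn is a
    # placeholder that turns a path into a tensor (e.g. an audio loader).
    def __init__(self, df, root, load_fn):
        self.df = df.reset_index(drop=True)
        self.root = root
        self.load_fn = load_fn

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        x = self.load_fn(os.path.join(self.root, row["filename"]))
        return x, int(row["label"]), row["filename"]

A DataLoader built on top of this will then yield (tensors, labels, filenames) batches.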

Upvotes: 1

Berriel

Reputation: 13641

Well, it depends on how your Dataset is implemented. For instance, in the torchvision.datasets.MNIST(...) case, you cannot retrieve the filename simply because there is no such thing as the filename of a single sample (MNIST samples are loaded in a different way).
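
Just to illustrate that point (a quick check you can run, not part of the solution): MNIST keeps everything as in-memory tensors loaded from the raw binary files, so there is no per-sample path to recover:

import torchvision

mnist = torchvision.datasets.MNIST(root=".", train=False, download=True)
print(mnist.data.shape)     # torch.Size([10000, 28, 28]) -- raw images as one tensor
print(mnist.targets.shape)  # torch.Size([10000])         -- labels only, no filenames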

As you did not show your Dataset implementation, I'll tell you how this could be done with the torchvision.datasets.ImageFolder(...) (or any torchvision.datasets.DatasetFolder(...)):

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        # works because shuffle=False and batch_size=1: batch i is sample i
        sample_fname, _ = test_loader.dataset.samples[i]
        f.write("{}, {}\n".format(sample_fname, predicted.item()))
f.close()

You can see that the path of the file is retrieved during the __getitem__(self, index) call, specifically in the DatasetFolder source.
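
For reference, a simplified paraphrase of what DatasetFolder.__getitem__ does (not the exact source, just the relevant part):

def __getitem__(self, index):
    # the filename is read from self.samples but never returned
    path, target = self.samples[index]
    sample = self.loader(path)
    if self.transform is not None:
        sample = self.transform(sample)
    if self.target_transform is not None:
        target = self.target_transform(target)
    return sample, target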

If you implemented your own Dataset (and perhaps would like to support shuffle and batch_size > 1), then I would return the sample_fname on the __getitem__(...) call and do something like this:

for i, (images, labels, sample_fname) in enumerate(test_loader, 0):
    # [...]
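
For the ImageFolder case, one way to return that third element is a small subclass (a sketch, not part of torchvision):

import torchvision

class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    # returns (image, label, path) instead of (image, label)
    def __getitem__(self, index):
        image, label = super().__getitem__(index)
        path, _ = self.samples[index]
        return image, label, path

With batch_size > 1, the default collate function batches the paths into a list of strings, which is what samples_fname holds in the loop below.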

This way you wouldn't need to care about shuffle. And if the batch_size is greater than 1, you would need to change the content of the loop to something more generic, e.g.:

f = open("test_y", "w")
for i, (images, labels, samples_fname) in enumerate(test_loader, 0):
    outputs = model(images)
    pred = torch.max(outputs, 1)[1]
    # one "filename, prediction" line per sample in the batch
    f.write("\n".join([
        ", ".join(x)
        for x in zip(samples_fname, map(str, pred.cpu().tolist()))
    ]) + "\n")
f.close()

Upvotes: 11
