Currently, I have a pre-trained model that uses a DataLoader to read batches of images for training.
```python
self.data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                               num_workers=1, pin_memory=True)
...
model.eval()
for step, inputs in enumerate(test_loader.data_loader):
    outputs = model(torch.cat([inputs], 1))
    ...
```
I want to run the model (make predictions) on images as they arrive from a queue. It should be similar to code that reads a single image and runs the model on it, something along these lines:
```python
from PIL import Image

new_input = Image.open(image_path)
model.eval()
outputs = model(torch.cat([new_input], 1))
```
Could you guide me on how to do this while applying the same transformations that the DataLoader applies?
You can do it with an IterableDataset:
```python
from torch.utils.data import IterableDataset


class MyDataset(IterableDataset):
    def __init__(self, image_queue):
        self.queue = image_queue

    def read_next_image(self):
        # Yield images until the queue is empty, then stop iterating
        while self.queue.qsize() > 0:
            # you can add transform here
            yield self.queue.get()

    def __iter__(self):
        return self.read_next_image()
```
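The `# you can add transform here` comment can be filled in by handing the transform to the dataset, so each queued image goes through exactly the same pipeline as the training data. This is a minimal sketch, not the answer's exact code: `QueueImageDataset` is a hypothetical name, `transform` is assumed to be any callable (e.g. a torchvision `transforms.Compose`), and it assumes the default `num_workers=0`, since a `queue.Queue` lives in one process and is not shared with DataLoader worker processes.

```python
import queue

import torch
from torch.utils.data import DataLoader, IterableDataset


class QueueImageDataset(IterableDataset):
    """Yields items from a queue.Queue, applying an optional transform to each.

    Iteration stops as soon as the queue is empty.
    """

    def __init__(self, image_queue, transform=None):
        self.queue = image_queue
        self.transform = transform  # e.g. the same Compose used for training

    def __iter__(self):
        while True:
            try:
                # Non-blocking get: raises queue.Empty once the queue is drained
                item = self.queue.get_nowait()
            except queue.Empty:
                return
            yield self.transform(item) if self.transform is not None else item
```

Usage would look like `DataLoader(QueueImageDataset(buffer, transform=trans), batch_size=1)`, where `buffer` holds PIL images and `trans` is the training transform.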
Then wrap it in a DataLoader with batch_size=1:
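```python
import queue

import torch
import torchvision.transforms.functional as TF
from PIL import Image

buffer = queue.Queue()
new_input = Image.open(image_path).convert('RGB')  # force 3 color channels
buffer.put(TF.to_tensor(new_input))  # [3, H, W] float tensor in [0, 1]
# ... Populate queue here

dataset = MyDataset(buffer)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

model.eval()
with torch.no_grad():  # inference only, no gradients needed
    for data in dataloader:
        outputs = model(data)  # data is a one-image batch of shape [1, 3, H, W] (3 = color channels)
```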
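For a single image, a DataLoader is not strictly required: applying the transform yourself and adding the batch dimension with `unsqueeze(0)` produces the same `[1, C, H, W]` input that `batch_size=1` would. This is a sketch under assumptions: `predict_single` is a hypothetical helper, and `transform` is assumed to be the same callable the original dataset used, returning a `[C, H, W]` tensor.

```python
import torch


def predict_single(model, image, transform, device="cpu"):
    """Run `model` on one image, shaped like a DataLoader batch of size 1.

    `transform` is assumed to be the training/test transform and must
    return a [C, H, W] tensor; `device` should match where the model lives.
    """
    x = transform(image)            # [C, H, W]
    x = x.unsqueeze(0).to(device)   # add the batch dimension -> [1, C, H, W]
    model.eval()                    # disable dropout, use running batch-norm stats
    with torch.no_grad():           # no autograd graph needed for inference
        return model(x)
```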
I don't know about DataLoader, but you can load a single image using the following function:
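```python
from PIL import Image


def safe_pil_loader(path, from_memory=False):
    """Load an image as RGB; fall back to a black 227x227 image if loading fails."""
    try:
        if from_memory:
            img = Image.open(path)
            res = img.convert('RGB')
        else:
            with open(path, 'rb') as f:
                img = Image.open(f)
                res = img.convert('RGB')
    except Exception:  # e.g. missing or corrupt file
        res = Image.new('RGB', (227, 227), color=0)
    return res
```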
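If images arrive from the queue as encoded bytes rather than file paths, they can be decoded entirely in memory by wrapping the bytes in `io.BytesIO`, which `Image.open` accepts like a file. A minimal sketch, assuming the queue carries raw PNG/JPEG bytes; `pil_from_bytes` is a hypothetical helper name.

```python
import io

from PIL import Image


def pil_from_bytes(data: bytes) -> Image.Image:
    """Decode an encoded image (PNG, JPEG, ...) held in memory as 3-channel RGB."""
    with Image.open(io.BytesIO(data)) as img:
        # convert() returns a new, fully loaded image, so it outlives the with-block
        return img.convert('RGB')
```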
To apply the transformations, you can do the following:
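```python
from torchvision import transforms

# `normalize` should be the same transforms.Normalize(...) used during training
trans = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    normalize,
])
img = trans(img)  # [3, 299, 299] tensor; add a batch dimension with img.unsqueeze(0)
```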