I'm running object-detection predictions on images in a for loop. I actually ran into the same issue with TensorFlow and hoped I could solve it with PyTorch. At least now I seem to have found what the issue is (naively assuming it's the same for TensorFlow).
I predict like this:
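import cv2
import numpy as np
import torch
from torchvision.models import detection
from tqdm import tqdm

# DEVICE, train and train_img_paths are defined elsewhere (not shown)
model = detection.fasterrcnn_resnet50_fpn(pretrained=True,
                                          progress=True, pretrained_backbone=True).to(DEVICE)
for i in tqdm(range(train.shape[0])):
    image = cv2.imread(train_img_paths[i])
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads images as BGR
    image = image.transpose((2, 0, 1))              # HWC -> CHW
    image = image / 255.0                           # scale pixel values to [0, 1]
    image = np.expand_dims(image, axis=0)           # add a batch dimension
    image = torch.FloatTensor(image)
    image = image.to(DEVICE)
    predictions = model(image)[0]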
Now, using the garbage collector, I found that every single image stays in the graph. Is there a way to avoid this?
I have not been able to use DataLoader or Dataset with the detection models (same with TensorFlow Hub).
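On the DataLoader point: torchvision's detection models take a list of 3D [C, H, W] tensors (which may differ in size) rather than one stacked 4D batch, so one common approach is a collate_fn that simply returns the list. The following is a minimal sketch; ImageDataset and collate_as_list are hypothetical names, and in-memory arrays stand in for the cv2.imread calls above so the example runs on its own.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ImageDataset(Dataset):
    """Hypothetical dataset wrapping a list of HxWx3 uint8 RGB arrays.

    In the question's setting each item would instead come from
    cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB).
    """

    def __init__(self, images):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        # HWC uint8 -> CHW float32 in [0, 1]
        return torch.from_numpy(image).permute(2, 0, 1).float() / 255.0


def collate_as_list(batch):
    # Keep the images as a plain list; the default collate would try to
    # stack them into one tensor, which fails when sizes differ.
    return list(batch)


# Two differently sized dummy images standing in for files on disk
images = [np.full((480, 640, 3), 255, dtype=np.uint8),
          np.zeros((300, 400, 3), dtype=np.uint8)]
loader = DataLoader(ImageDataset(images), batch_size=2, collate_fn=collate_as_list)
batch = next(iter(loader))
```

With model and DEVICE as in the question, each batch could then be passed as model([img.to(DEVICE) for img in batch]) inside torch.no_grad().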
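Don't forget to turn off gradient tracking when you're doing testing. Otherwise autograd builds a computation graph on every forward pass, and any output you keep holds a reference to that graph and the tensors it saved. You can do this either by wrapping your code like: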
with torch.no_grad():
    model.eval()
    out = model(x)
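To see what the context manager changes, here is a small check in which an nn.Linear layer stands in for the detection model (an assumption, so it runs without downloading weights). Outside no_grad the output carries a grad_fn, a link back into the autograd graph; inside, it has none. Note that model.eval() only switches layers such as dropout and batch norm to inference behavior and does not by itself disable gradient tracking, which is why the two are used together.

```python
import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for the detection model
x = torch.randn(1, 4)

# Normal forward pass: autograd records the graph
out_tracked = model(x)

# Same pass with gradient tracking disabled
with torch.no_grad():
    model.eval()
    out_untracked = model(x)

print(out_tracked.requires_grad, out_tracked.grad_fn is not None)  # True True
print(out_untracked.requires_grad, out_untracked.grad_fn)          # False None
```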
or if your code is a function, using a decorator to do the same thing:
@torch.no_grad()
def model_proc(model, x):
    model.eval()
    return model(x)
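Used in a prediction loop, the decorator form might look like the sketch below, again with a small stand-in model rather than the real detector. torch.inference_mode (PyTorch 1.9+) is shown as a stricter alternative: tensors created under it cannot later be used in operations recorded by autograd. model_proc_inference is a hypothetical name.

```python
import torch
from torch import nn


@torch.no_grad()
def model_proc(model, x):
    model.eval()
    return model(x)


@torch.inference_mode()
def model_proc_inference(model, x):
    # Stricter alternative to no_grad, available in PyTorch >= 1.9
    model.eval()
    return model(x)


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))  # stand-in model

results = []
for _ in range(3):
    x = torch.randn(1, 4)
    out = model_proc(model, x)
    # Safe to keep: no autograd graph is attached to `out`
    results.append(out.cpu())
```

Because nothing in results references a graph, each input can be freed once its iteration ends.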