Olli

How to delete tensors from pytorch graph?

I'm running object-detection predictions on images in a for loop. I actually ran into the same issue with TensorFlow and hoped I could solve it with PyTorch. At least now I seem to have found out what the issue is (naively assuming it's the same for TensorFlow).

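I predict like this:

import cv2
import numpy as np
import torch
from torchvision.models import detection
from tqdm import tqdm

# DEVICE, train and train_img_paths are defined elsewhere (not shown)
model = detection.fasterrcnn_resnet50_fpn(pretrained=True,
    progress=True, pretrained_backbone=True).to(DEVICE)
for i in tqdm(range(train.shape[0])):
    image = cv2.imread(train_img_paths[i])
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads images as BGR
    image = image.transpose((2, 0, 1))              # HWC -> CHW
    image = image / 255.0                           # scale pixels to [0, 1]
    image = np.expand_dims(image, axis=0)           # add a batch dimension
    image = torch.FloatTensor(image)
    image = image.to(DEVICE)
    predictions = model(image)[0]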

Using the garbage collector, I found that every single image stays in the graph. Is there a way to avoid this?

I have not been able to use a DataLoader or Dataset with the detection models (the same goes for TensorFlow Hub).

jhso

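When you're doing inference/testing, don't forget to turn off gradient tracking. Otherwise autograd records the forward pass and keeps the intermediate tensors it would need for a backward pass alive for as long as the outputs are referenced. Note that model.eval() alone does not do this; it only switches layers such as dropout and batch norm to inference behaviour. You can turn tracking off by either wrapping your code like: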

import torch

with torch.no_grad():
    model.eval()
    out = model(x)

or, if your code is in a function, by using torch.no_grad() as a decorator to do the same thing:

import torch

@torch.no_grad()
def model_proc(model, x):
    model.eval()
    return model(x)
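A minimal sketch of how this could be applied to the loop in the question. predict_all and to_cpu are hypothetical helper names, the images iterable stands in for the cv2 loading steps, and a tiny nn.Sequential stands in for fasterrcnn_resnet50_fpn so the example runs without downloading weights. Copying each result to the CPU is an extra choice, not part of the advice above, so that accumulated results don't occupy GPU memory.

```python
import torch
from torch import nn


def to_cpu(obj):
    """Recursively copy tensors (also inside dicts, lists and tuples) to the CPU."""
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {key: to_cpu(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_cpu(value) for value in obj)
    return obj


@torch.no_grad()  # nothing inside this function is recorded by autograd
def predict_all(model, images, device="cpu"):
    """Hypothetical helper: run the model on each [C, H, W] image, collect CPU outputs."""
    model.eval()  # inference behaviour for dropout / batch norm
    model.to(device)
    results = []
    for image in images:
        batch = image.unsqueeze(0).to(device)  # add a batch dimension
        results.append(to_cpu(model(batch)))
    return results


if __name__ == "__main__":
    torch.manual_seed(0)
    # Hypothetical stand-in for the detection model.
    model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.ReLU(), nn.Flatten())
    images = [torch.rand(3, 8, 8) for _ in range(5)]

    outputs = predict_all(model, images)
    print(outputs[0].requires_grad, outputs[0].grad_fn)  # False None

    # The same forward pass with gradient tracking left on, for comparison.
    tracked = model(images[0].unsqueeze(0))
    print(tracked.requires_grad, tracked.grad_fn is not None)  # True True
```

Checking grad_fn is a quick way to see whether an output is still attached to an autograd graph: None means nothing from that forward pass is being kept alive for a backward pass.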
