user10248589
user10248589

Reputation:

Reducing batch size in pytorch

I am new to programming in PyTorch. I am getting an error that says CUDA is out of memory, so I think I have to reduce the batch size. Can someone tell me how to do that in Python code? I also don't know what my current batch size is.

P.S. I am trying to run the super-resolution example from Deep Image Prior. Here's the code.

The error occurs when running the optimization. It says:

RuntimeError: Cuda out of memory.

from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import argparse
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
from models import *

import torch
import torch.optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


import warnings
warnings.filterwarnings("ignore")

from skimage.measure import compare_psnr
from models.downsampler import Downsampler

from utils.sr_utils import *

# Turn cuDNN on and let it auto-tune its convolution algorithms.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark =True
# Tensor type applied with .type(dtype) to every tensor/module built below;
# being a CUDA type, it places them on the GPU.
dtype = torch.cuda.FloatTensor

# Arguments forwarded to load_LR_HR_imgs_sr.
# NOTE(review): imsize = -1 presumably means "keep the image's own size" --
# confirm in load_LR_HR_imgs_sr. The image size is not limited anywhere in
# this script, so a large input image translates directly into GPU memory use.
imsize = -1
# Also passed to Downsampler and used to pick the iteration schedule.
factor = 16 # 8
enforse_div32 = 'CROP' # we usually need the dimensions to be divisible by a power of two (32 in this case)
# When True, intermediate and final image grids are plotted.
PLOT = True
path_to_image = '/home/smitha/deep-image-prior/tnew.tif'
# Dict of images; the keys read later in this script are 'LR_pil', 'HR_pil',
# 'LR_np', 'HR_np' and 'orig_np'.
imgs = load_LR_HR_imgs_sr(path_to_image , imsize, factor, enforse_div32)
# Baseline reconstructions are stored in the same dict under new keys.
imgs['bicubic_np'], imgs['sharp_np'], imgs['nearest_np'] = get_baselines(imgs['LR_pil'], imgs['HR_pil'])
if PLOT:
    # Show the HR image next to the three baselines and report the PSNR of
    # the bicubic and nearest baselines against the HR image.
    plot_image_grid([imgs['HR_np'], imgs['bicubic_np'], imgs['sharp_np'], imgs['nearest_np']], 4,12);
    print ('PSNR bicubic: %.4f   PSNR nearest: %.4f' %  (
                                    compare_psnr(imgs['HR_np'], imgs['bicubic_np']), 
                                    compare_psnr(imgs['HR_np'], imgs['nearest_np'])))
# Settings consumed by get_noise / get_net / Downsampler / get_params / optimize.
input_depth = 8           # first argument to both get_noise and get_net
INPUT =     'noise'       # second argument to get_noise
pad   =     'reflection'  # third argument to get_net
OPT_OVER =  'net'         # first argument to get_params
KERNEL_TYPE='lanczos2'    # kernel_type of the Downsampler
# NOTE(review): passed as the fourth argument to optimize(); presumably the
# learning rate -- 5 looks unusually large for Adam, confirm.
LR = 5
# closure() only adds the tv_loss term when this is > 0, so 0.0 disables it.
tv_weight = 0.0
OPTIMIZER = 'adam'        # first argument to optimize()
# Iteration count and input-noise scale are chosen per super-resolution factor;
# only 16 and 8 are supported.
if factor == 16: 
    num_iter = 10
    reg_noise_std = 0.01
elif factor == 8:
    num_iter = 40
    reg_noise_std = 0.05
else:
    # NOTE(review): asserts are stripped under `python -O`; in that mode an
    # unsupported factor would leave num_iter / reg_noise_std undefined.
    assert False, 'We did not experiment with other factors'
# Network input, generated once by get_noise. Its spatial size is taken from
# the HR image (size[1], size[0]), so a large HR image means a large input
# and correspondingly large activations. It is detached, so no gradient flows
# into the input itself.
net_input = get_noise(input_depth, INPUT, (imgs['HR_pil'].size[1],  imgs['HR_pil'].size[0])).type(dtype).detach()

NET_TYPE = 'skip' # UNet, ResNet
# Build the network from NET_TYPE so the constant above actually selects the
# architecture (the value was previously hard-coded in this call).
net = get_net(input_depth, NET_TYPE, pad,
              skip_n33d=128,
              skip_n33u=128,
              skip_n11=4,
              num_scales=5,
              upsample_mode='bilinear').type(dtype)

# Loss between the downsampled network output and the LR target.
mse = torch.nn.MSELoss().type(dtype)

# LR target image as a tensor of type dtype.
img_LR_var = np_to_torch(imgs['LR_np']).type(dtype)

# Maps the network's HR output to LR resolution inside closure().
downsampler = Downsampler(n_planes=3, factor=factor,  kernel_type=KERNEL_TYPE, phase=0.5, preserve_size=True).type(dtype)
def closure():
    """Run one optimisation step's forward/backward pass and return the loss.

    Passed to ``optimize`` as the step callback. Reads the module-level
    ``net``, ``downsampler``, ``mse``, ``img_LR_var``, ``imgs`` and settings,
    rebinds the globals ``net_input`` and ``i``, and appends one
    ``[PSNR_LR, PSNR_HR]`` pair to ``psnr_history`` per call.
    """
    global i, net_input

    # Perturb the saved input with freshly drawn noise; `noise` is refilled
    # in place by normal_() on every call.
    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    high_res = net(net_input)
    low_res = downsampler(high_res)

    # Data term compares the downsampled output with the LR target.
    loss = mse(low_res, img_LR_var)
    if tv_weight > 0:
        loss += tv_weight * tv_loss(high_res)
    loss.backward()

    # Log PSNR at both resolutions and record it.
    lr_score = compare_psnr(imgs['LR_np'], torch_to_np(low_res))
    hr_score = compare_psnr(imgs['HR_np'], torch_to_np(high_res))
    print('Iteration %05d    PSNR_LR %.3f   PSNR_HR %.3f' % (i, lr_score, hr_score), '\r', end='')
    psnr_history.append([lr_score, hr_score])

    # Every 100th call, show the current output next to HR and bicubic.
    if PLOT and i % 100 == 0:
        preview = np.clip(torch_to_np(high_res), 0, 1)
        plot_image_grid([imgs['HR_np'], imgs['bicubic_np'], preview], factor=13, nrow=3)

    i += 1
    return loss

# One [PSNR_LR, PSNR_HR] pair is appended here by every closure() call.
psnr_history = []
# NOTE: never read anywhere in this script, so it has no effect on memory use.
volatile=True
# Unperturbed copy of the input; closure() adds fresh noise to it each step.
net_input_saved = net_input.detach().clone()
# Scratch tensor with the input's shape, refilled in place by closure().
noise = net_input.clone()
# Iteration counter incremented by closure().
i = 0
p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)

# The final forward pass is inference only. Running it under no_grad() stops
# autograd from recording the graph and keeping intermediate activations
# alive, which lowers GPU memory use for this step.
with torch.no_grad():
    out_HR_np = np.clip(torch_to_np(net(net_input)), 0, 1)
result_deep_prior = put_in_center(out_HR_np, imgs['orig_np'].shape[1:])
plot_image_grid([imgs['HR_np'],
                 imgs['bicubic_np'],
                 out_HR_np], factor=4, nrow=1);

Upvotes: 1

Views: 6139

Answers (1)

Prune
Prune

Reputation: 77847

The batch size depends on the model. Typically, it's the first dimension of your input tensors. Your model uses different names than I'm used to, some of which are general terms, so I'm not sure of your model topology or usage.

Upvotes: 1

Related Questions