Reputation: 1
I'm working on a machine learning project in PyTorch where I need to optimize a model using the full batch gradient descent method. The key requirement is that the optimizer should use all the data points in the dataset for each update. My challenge with the existing torch.optim.SGD optimizer is that it doesn't inherently support using the entire dataset in a single update. This is crucial for my project as I need the optimization process to consider all data points to ensure the most accurate updates to the model parameters.
Additionally, I would like to retain the use of Nesterov momentum in the optimization process. I understand that one could potentially modify the batch size to equal the entire dataset, simulating a full batch update with the SGD optimizer. However, I'm interested in whether there's a more elegant or direct way to implement a true Gradient Descent optimizer in PyTorch that also supports Nesterov momentum.
Ideally, I'm looking for a solution or guidance on how to implement or configure an optimizer in PyTorch that performs a true full-batch gradient descent update (every step uses the entire dataset) while still supporting Nesterov momentum.
Upvotes: 0
Views: 919
Reputation: 392
The PyTorch SGD implementation is actually independent of the batching!
It only uses the gradients that were computed during the backward pass and stored in each parameter's .grad attribute.
So the batch size used for computing gradients and the batch size used for the optimizer update are decoupled.
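As a quick illustration of that accumulation behaviour (a toy example, not tied to any particular model), two backward() calls without an intervening zero_grad() simply sum their gradients into .grad:

import torch

# toy tensor just to show how .grad accumulates; nothing model-specific here
w = torch.tensor(2.0, requires_grad=True)
(3.0 * w).backward()   # d/dw (3*w)  = 3
(w * w).backward()     # d/dw (w^2)  = 2*w = 4, added on top of the 3
print(w.grad)          # tensor(7.) -- the sum of both gradients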
You can now either:
a) Put all your samples through the model as one big batch by setting the batch size to the dataset size (a sketch of this option follows the code below), or
b) Accumulate the gradients over many smaller batches before doing a single optimizer step (pseudo-code, with placeholder names):
import torch

model = YourModel()
data = YourDataSetOrLoader()
# lr/momentum values are placeholders; nesterov=True enables the Nesterov
# momentum asked for in the question
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)

for full_batch_step in range(100):
    # reset the accumulated gradients to zero before each full-batch step
    optim.zero_grad()
    for batch in data:
        # forward pass on the current mini-batch; assumes the model returns a scalar loss
        f = model(batch)
        # this adds the gradients w.r.t. the parameters for the current mini-batch
        # to the parameters' .grad attributes
        f.backward()
    # now that the gradients are summed over all samples, do one gradient descent step
    optim.step()
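For option (a), a minimal sketch, assuming dataset is a standard map-style Dataset returning (inputs, targets) pairs and loss_fn is your loss function (all of these names are placeholders):

from torch.utils.data import DataLoader

# option (a) sketch: one batch that contains the whole dataset
# dataset, model, optim and loss_fn are assumed to exist and are placeholders
loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)

for full_batch_step in range(100):
    inputs, targets = next(iter(loader))   # the single full-dataset batch
    optim.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optim.step()

Both variants end up doing the same kind of full-batch update; (b) just keeps the memory footprint of a single forward/backward pass bounded. One caveat for (b): if each mini-batch loss is a mean, the accumulated gradient is the sum of per-batch mean gradients, so you may want to divide each loss by the number of batches to match the true full-batch mean gradient.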
Upvotes: 2