arushi

Reputation: 1

How is optimizer step implemented for data parallelism in PyTorch?

I'm trying to implement data parallelism from scratch in PyTorch. To do this, I've implemented the following steps (a simplified sketch follows the list):

  1. Making model copies for each device
  2. Re-batching the data for each model copy
  3. Running forward passes
  4. Accumulating the gradients and averaging them
  5. Optimizer step on the averaged gradient <- you are here

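To make the question concrete, here is a minimal sketch of what I mean by steps 1-4. The toy model, data, and two CPU "replicas" are placeholders for the real per-device copies:

import copy

import torch
import torch.nn as nn

# Toy model and data standing in for the real ones
model = nn.Linear(10, 1)
data = torch.randn(8, 10)
target = torch.randn(8, 1)
n_replicas = 2

# 1. Make a model copy per "device" (kept on CPU here for simplicity)
models = [copy.deepcopy(model) for _ in range(n_replicas)]

# 2. Re-batch the data: one shard per replica
data_shards = data.chunk(n_replicas)
target_shards = target.chunk(n_replicas)

# 3. Forward (and backward) pass on each replica's shard
loss_fn = nn.MSELoss()
for replica, x, y in zip(models, data_shards, target_shards):
    loss = loss_fn(replica(x), y)
    loss.backward()

# 4. Average the gradients across replicas and write the average
#    back into every replica's .grad
for params in zip(*(replica.parameters() for replica in models)):
    avg_grad = torch.stack([p.grad for p in params]).mean(dim=0)
    for p in params:
        p.grad = avg_grad.clone()

# 5. Optimizer step on the averaged gradient <- this is the part I'm stuck on
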
I am trying to figure out how to combine the PyTorch optimizer step with my manual data parallelism. Currently, the only way I can make this work is to keep a separate optimizer for each of the data-parallel model replicas. Here's a simple reduction of that code:

# One optimizer per model replica, so that stepping all of them updates every copy's weights
optimizers = [torch.optim.SGD(model.parameters(), lr=0.1) for model in models]

# Step each replica's optimizer (after the gradients have been averaged)
for optimizer in optimizers:
    optimizer.step()

I have a suspicion this is not how people implement the optimizer step, since it does not seem very memory-efficient. How is this done in practice? Is there a way to attach one optimizer to multiple model copies? I'd like to understand the proper way of doing this!
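
For what it's worth, the closest thing I can think of to "one optimizer for multiple copies" is building a single optimizer over all the replicas' parameters, but that still keeps a full set of parameters (and gradients) per replica, so I doubt it is what's done in practice:

import itertools

# A single optimizer whose parameter list spans every replica
# (just to illustrate the question; I'm not sure this is the intended pattern)
all_params = itertools.chain.from_iterable(replica.parameters() for replica in models)
shared_optimizer = torch.optim.SGD(all_params, lr=0.1)
shared_optimizer.step()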

Upvotes: 0

Views: 58

Answers (0)
