Carl Cheung

Reputation: 558

How to concurrently run multiple branches in pytorch?

I was trying to build a network with multiple branches in PyTorch. How can I run the branches in parallel instead of running them one by one?

Unlike TensorFlow or Keras, PyTorch uses a dynamic graph, so I can't define the concurrent processing beforehand.

I looked at official PyTorch implementations of networks with multiple branches, such as Inception, only to find that the branches run one after another.

from inception.py

def _forward(self, x):
    branch1x1 = self.branch1x1(x)

    branch5x5 = self.branch5x5_1(x)
    branch5x5 = self.branch5x5_2(branch5x5)

    branch3x3dbl = self.branch3x3dbl_1(x)
    branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
    branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

    branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
    branch_pool = self.branch_pool(branch_pool)

    outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
    return outputs

The four branches run one by one: first branch1x1, then branch5x5, branch3x3dbl, and branch_pool. outputs stores their results, which are concatenated later.

Wouldn't that be a waste of performance? And how can we deal with it?

Upvotes: 4

Views: 1602

Answers (1)

Nopileos

Reputation: 2117

In general you don't have to worry about the performance of network execution as long as you use the operations provided by PyTorch.

As pointed out in the comments, all calls to the GPU are asynchronous, and a call is executed as soon as its inputs are ready. So in your case with multiple branches, PyTorch schedules all operations and executes them according to their data dependencies. Since the branches don't share any data, they can be executed in parallel.
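You can observe this asynchrony directly. A minimal sketch (tensor sizes and the timing approach are just for illustration, and it assumes a CUDA device is available):

import time
import torch

# Kernel launches return to Python almost immediately; the host only
# blocks when it actually needs the result (here via synchronize()).
x = torch.randn(4096, 4096, device="cuda")
w = torch.randn(4096, 4096, device="cuda")

torch.cuda.synchronize()
t0 = time.time()
y = x @ w                                 # launch only, does not block the CPU
launch = time.time() - t0
torch.cuda.synchronize()                  # wait for the GPU to finish
total = time.time() - t0

print(f"launch: {launch * 1e3:.3f} ms, finished: {total * 1e3:.3f} ms")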

So in your case

branch1x1 = self.branch1x1(x)
branch5x5 = self.branch5x5_1(x)
branch3x3dbl = self.branch3x3dbl_1(x)

are probably executed more or less at the same time. Same thing for all the following layers.
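If you want to make the overlap explicit, you can also put independent branches on separate CUDA streams. This is usually unnecessary, but here is a rough sketch (the conv layers and tensor sizes are placeholders, not taken from the question):

import torch

# Run two independent branches on separate CUDA streams. PyTorch usually
# overlaps independent work on its own; this just makes it explicit.
branch_a = torch.nn.Conv2d(64, 64, kernel_size=1).cuda()
branch_b = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1).cuda()
x = torch.randn(8, 64, 56, 56, device="cuda")

s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()

torch.cuda.synchronize()
with torch.cuda.stream(s1):
    out_a = branch_a(x)
with torch.cuda.stream(s2):
    out_b = branch_b(x)
torch.cuda.synchronize()   # join both streams before using the results

out = torch.cat([out_a, out_b], dim=1)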

Upvotes: 3
