C-3PO

Reputation: 1213

PyTorch Boolean - Stop Backpropagation?

I need to create a neural network in which I use binary gates to zero out certain tensors, which are the outputs of disabled circuits.

To improve runtime speed, I was hoping to use torch.bool binary gates to stop backpropagation along the disabled circuits in the network. However, I ran a small experiment based on the official PyTorch CIFAR-10 example, and the runtime is exactly the same for any values of gate_A and gate_B, which means the idea is not working:

import torch
import torch.nn as nn
import torch.nn.functional as F
from random import randint

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1a = nn.Conv2d(3, 6, 5)
        self.conv2a = nn.Conv2d(6, 16, 5)
        self.conv1b = nn.Conv2d(3, 6, 5)
        self.conv2b = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        
    def forward(self, x):
        # Only one gate is supposed to be enabled at random
        # However, for the experiment, I fixed the values to [1,0] and [1,1]
        choice = randint(0, 1)
        gate_A = torch.tensor(choice, dtype=torch.bool)
        gate_B = torch.tensor(1 - choice, dtype=torch.bool)
        
        a = self.pool(F.relu(self.conv1a(x)))
        a = self.pool(F.relu(self.conv2a(a)))
        
        b = self.pool(F.relu(self.conv1b(x)))
        b = self.pool(F.relu(self.conv2b(b)))
        
        a *= gate_A
        b *= gate_B
        x = torch.cat([a, b], dim=1)
        
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

How can I define gate_A and gate_B in such a way that backpropagation effectively stops when they are zero?

PS: Changing the concatenation dynamically at runtime would also change which of fc1's weights are applied to each branch's output (for example, the weights associated with a could end up applied to b on another pass, disrupting how the network operates).
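For what it's worth, here is a tiny standalone check (separate from the model above) which seems to confirm that multiplying by a bool zero only zeroes the values: autograd still records and traverses the branch, and the gradients are computed as zeros rather than skipped, which would explain why the runtime does not change:

import torch

x = torch.randn(4, 3, requires_grad=True)
gate = torch.tensor(0, dtype=torch.bool)

y = (x * 2) * gate        # output is zeroed, but the multiplication is still recorded
print(y.grad_fn)          # a MulBackward0 node: the branch stays in the autograd graph
y.sum().backward()
print(x.grad)             # all zeros: the backward pass still runs through the branch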

Upvotes: 1

Views: 957

Answers (2)

C-3PO

Reputation: 1213

Easy solution: simply define a tensor of zeros when a or b is disabled :) Since the disabled branch's layers are never even called, both the forward and the backward pass skip it entirely.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1a = nn.Conv2d(3, 6, 5)
        self.conv2a = nn.Conv2d(6, 16, 5)
        self.conv1b = nn.Conv2d(3, 6, 5)
        self.conv2b = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        
    def forward(self, x):
        
        if randint(0, 1):
            a = self.pool(F.relu(self.conv1a(x)))
            a = self.pool(F.relu(self.conv2a(a)))
            b = torch.zeros_like(a)  # zeros are outside the autograd graph: no gradients flow into branch b
        else:
            b = self.pool(F.relu(self.conv1b(x)))
            b = self.pool(F.relu(self.conv2b(b)))
            a = torch.zeros_like(b)  # zeros are outside the autograd graph: no gradients flow into branch a
        
        x = torch.cat([a, b], dim=1)
        
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

PS. I thought about this while I was having a coffee.

Upvotes: 1

GoodDeeds

Reputation: 8517

You could use torch.no_grad (the code below can probably be made more concise):

def forward(self, x):
        # Only one gate is supposed to be enabled at random
        # However, for the experiment, I fixed the values to [1,0] and [1,1]
        choice = randint(0, 1)
        gate_A = torch.tensor(choice, dtype=torch.bool)
        gate_B = torch.tensor(1 - choice, dtype=torch.bool)
        
        if choice:
            a = self.pool(F.relu(self.conv1a(x)))
            a = self.pool(F.relu(self.conv2a(a)))
            a *= gate_A
            
            with torch.no_grad(): # disable gradient computation
                b = self.pool(F.relu(self.conv1b(x)))
                b = self.pool(F.relu(self.conv2b(b)))
                b *= gate_B
        else:
            with torch.no_grad(): # disable gradient computation
                a = self.pool(F.relu(self.conv1a(x)))
                a = self.pool(F.relu(self.conv2a(a)))
                a *= gate_A
            
            b = self.pool(F.relu(self.conv1b(x)))
            b = self.pool(F.relu(self.conv2b(b)))
            b *= gate_B

        x = torch.cat([a, b], dim=1)
        
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
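As a quick sanity check (a minimal sketch of my own, using two stand-in linear layers rather than the full Net above), parameters used only under torch.no_grad receive no gradients at all after backward:

import torch
import torch.nn as nn

lin_a = nn.Linear(4, 4)  # stands in for the enabled branch
lin_b = nn.Linear(4, 4)  # stands in for the disabled branch
x = torch.randn(2, 4)

a = lin_a(x)
with torch.no_grad():    # disabled branch: forward still runs, but nothing is recorded
    b = lin_b(x)

torch.cat([a, b], dim=1).sum().backward()

print(lin_a.weight.grad is None)  # False: the enabled branch gets gradients
print(lin_b.weight.grad is None)  # True: the disabled branch gets none

Note that torch.no_grad only saves work in the backward pass; the forward computation of the disabled branch still runs.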

On a second look, I think the following is a simpler solution to the specific problem:

def forward(self, x):
        # Only one gate is supposed to be enabled at random
        # However, for the experiment, I fixed the values to [1,0] and [1,1]
        choice = randint(0, 1)

        if choice:
            a = self.pool(F.relu(self.conv1a(x)))
            a = self.pool(F.relu(self.conv2a(a)))
            b = torch.zeros(shape_of_conv_output) # replace shape of conv output here
        else:
            b = self.pool(F.relu(self.conv1b(x)))
            b = self.pool(F.relu(self.conv2b(b)))
            a = torch.zeros(shape_of_conv_output) # replace shape of conv output here
       
        x = torch.cat([a, b], dim=1)
        
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
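One caveat with the zeros approach: torch.zeros defaults to a CPU float32 tensor, so if the model runs on the GPU the zeros would also need to be created on the matching device and dtype, something like:

a = torch.zeros(shape_of_conv_output, device=x.device, dtype=x.dtype) # shape placeholder as above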

Upvotes: 1
