Reputation: 1213
I need to create a neural network where I use binary gates to zero out certain tensors, which are the outputs of disabled circuits.
To improve runtime speed, I was looking to use torch.bool binary gates to stop backpropagation along disabled circuits in the network. However, I set up a small experiment based on the official PyTorch example for the CIFAR-10 dataset, and the runtime speed is exactly the same for any values of gate_A and gate_B, which means the idea is not working:
from random import randint

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1a = nn.Conv2d(3, 6, 5)
        self.conv2a = nn.Conv2d(6, 16, 5)
        self.conv1b = nn.Conv2d(3, 6, 5)
        self.conv2b = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Only one gate is supposed to be enabled at random
        # However, for the experiment, I fixed the values to [1,0] and [1,1]
        choice = randint(0, 1)
        gate_A = torch.tensor(choice, dtype=torch.bool)
        gate_B = torch.tensor(1 - choice, dtype=torch.bool)

        a = self.pool(F.relu(self.conv1a(x)))
        a = self.pool(F.relu(self.conv2a(a)))

        b = self.pool(F.relu(self.conv1b(x)))
        b = self.pool(F.relu(self.conv2b(b)))

        a *= gate_A
        b *= gate_B

        x = torch.cat([a, b], dim=1)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
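A minimal standalone check (a hypothetical single convolution behind a disabled gate, not the full network) shows what is happening: multiplying by a zero-valued bool gate still records the full autograd graph, so the disabled convolutions are computed and backpropagated through anyway:

import torch
import torch.nn as nn

# Toy check: one convolution behind a disabled (zero) gate.
conv = nn.Conv2d(3, 6, 5)
gate = torch.tensor(False)           # disabled gate
x = torch.randn(1, 3, 32, 32)

out = conv(x) * gate                 # output is all zeros, but the graph is still built
out.sum().backward()

print(out.grad_fn)                   # MulBackward0 -> the branch stays in the graph
print(conv.weight.grad.abs().sum())  # tensor(0.)   -> a gradient is still computed, it is just zero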
How can I define gate_A and gate_B in such a way that backpropagation effectively stops if they are zero?
PS. Changing the concatenation dynamically at runtime would also change which weights are assigned to each branch (for example, the weights associated with a could end up assigned to b on another pass, disrupting how the network operates).
Upvotes: 1
Views: 957
Reputation: 1213
Easy solution: simply define a tensor of zeros when a or b is disabled :)
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1a = nn.Conv2d(3, 6, 5)
        self.conv2a = nn.Conv2d(6, 16, 5)
        self.conv1b = nn.Conv2d(3, 6, 5)
        self.conv2b = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        if randint(0, 1):
            a = self.pool(F.relu(self.conv1a(x)))
            a = self.pool(F.relu(self.conv2a(a)))
            b = torch.zeros_like(a)
        else:
            b = self.pool(F.relu(self.conv1b(x)))
            b = self.pool(F.relu(self.conv2b(b)))
            a = torch.zeros_like(b)

        x = torch.cat([a, b], dim=1)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
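A quick sanity check (assuming the same imports as in the question and a dummy CIFAR-sized batch) confirms that the parameters of the branch that was not taken never enter the graph, so their .grad stays None:

net = Net()
x = torch.randn(4, 3, 32, 32)  # dummy CIFAR-10-sized batch
net(x).sum().backward()

# Exactly one of these prints True: the branch skipped on this pass
# received no gradient at all (its .grad was never populated).
print(net.conv1a.weight.grad is None)
print(net.conv1b.weight.grad is None)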
PS. I thought about this while I was having a coffee.
Upvotes: 1
Reputation: 8517
You could use torch.no_grad (the code below can probably be made more concise):
def forward(self, x):
    # Only one gate is supposed to be enabled at random
    # However, for the experiment, I fixed the values to [1,0] and [1,1]
    choice = randint(0, 1)
    gate_A = torch.tensor(choice, dtype=torch.bool)
    gate_B = torch.tensor(1 - choice, dtype=torch.bool)

    if choice:
        a = self.pool(F.relu(self.conv1a(x)))
        a = self.pool(F.relu(self.conv2a(a)))
        a *= gate_A
        with torch.no_grad():  # disable gradient computation
            b = self.pool(F.relu(self.conv1b(x)))
            b = self.pool(F.relu(self.conv2b(b)))
            b *= gate_B
    else:
        with torch.no_grad():  # disable gradient computation
            a = self.pool(F.relu(self.conv1a(x)))
            a = self.pool(F.relu(self.conv2a(a)))
            a *= gate_A
        b = self.pool(F.relu(self.conv1b(x)))
        b = self.pool(F.relu(self.conv2b(b)))
        b *= gate_B

    x = torch.cat([a, b], dim=1)
    x = torch.flatten(x, 1)  # flatten all dimensions except batch
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
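The point of the no_grad block is that the disabled branch is still computed in the forward pass, but nothing is recorded for it, so backward never reaches its parameters. A minimal illustration with a hypothetical standalone convolution:

import torch
import torch.nn as nn

conv_b = nn.Conv2d(3, 6, 5)
x = torch.randn(1, 3, 32, 32)

with torch.no_grad():
    b = conv_b(x)        # forward still runs, so this still costs compute time

print(b.requires_grad)   # False -> no graph was recorded for this branch
print(b.grad_fn)         # None  -> backward can never reach conv_b's weights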
On a second look, I think the following is a simpler solution to the specific problem:
def forward(self, x):
    # Only one gate is supposed to be enabled at random
    # However, for the experiment, I fixed the values to [1,0] and [1,1]
    choice = randint(0, 1)
    if choice:
        a = self.pool(F.relu(self.conv1a(x)))
        a = self.pool(F.relu(self.conv2a(a)))
        b = torch.zeros(shape_of_conv_output)  # replace shape of conv output here
    else:
        b = self.pool(F.relu(self.conv1b(x)))
        b = self.pool(F.relu(self.conv2b(b)))
        a = torch.zeros(shape_of_conv_output)  # replace shape of conv output here

    x = torch.cat([a, b], dim=1)
    x = torch.flatten(x, 1)  # flatten all dimensions except batch
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
Upvotes: 1