balbok

Reputation: 426

Taking a derivative through torch.ge, or how to explicitly define a derivative in pytorch

I am trying to set up a network in which one layer maps from real numbers to {0, 1} (i.e., makes the output binary).

What I tried

While I found that torch.ge provides such functionality, whenever I want to train any parameter occurring before that layer in the network, PyTorch raises an error during backpropagation (see below).

I have also been trying to find out whether there is any way in PyTorch/autograd to override the derivative of a module by hand. More specifically, in this case I would just like to pass the gradient through torch.ge without changing it.
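
Ideally, I would like something like the following sketch using torch.autograd.Function (GEFunction is just a name I made up): the forward pass does the thresholding, and the backward pass returns the incoming gradient unchanged (a straight-through estimator, if I understand the term correctly).

import torch


class GEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Same result as torch.ge(x, 0), cast to float so it can feed
        # into the rest of the (otherwise differentiable) graph
        return (x >= 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the incoming gradient through unchanged instead of
        # zeroing it out
        return grad_output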

Minimal Example

Here is a minimal example I produced, which uses a typical neural network training structure in PyTorch.

import torch
import torch.nn as nn
import torch.optim as optim


class LinearGE(nn.Module):
    def __init__(self, features_in, features_out):
        super().__init__()
        self.fc = nn.Linear(features_in, features_out)

    def forward(self, x):
        return torch.ge(self.fc(x), 0)


x = torch.randn(size=(10, 30))
y = torch.randint(2, size=(10, 10))

# Define Model
m1 = LinearGE(30, 10)

opt = optim.SGD(m1.parameters(), lr=0.01)

crit = nn.MSELoss()

# Train Model
for x_batch, y_batch in zip(x, y):
    # zero the parameter gradients
    opt.zero_grad()

    # forward + backward + optimize
    pred = m1(x_batch)
    loss = crit(pred.float(), y_batch.float())
    loss.backward()
    opt.step()

What I encountered

When I run the above code the following error occurs:

File "__minimal.py", line 33, in <module>
    loss.backward()
...
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

This error makes sense, since the torch.ge function is not differentiable. However, since MaxPool2d is not differentiable everywhere either, I believe there must be ways of mitigating non-differentiability in PyTorch.

It would be great if someone could point me to a source that explains either how to implement my own backprop for a custom module, or some other way of avoiding this error message.
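
For reference, this is roughly how I imagine such a custom function would plug into my module (a sketch, reusing the hypothetical GEFunction from above):

class LinearGECustom(nn.Module):
    def __init__(self, features_in, features_out):
        super().__init__()
        self.fc = nn.Linear(features_in, features_out)

    def forward(self, x):
        # Function.apply invokes the custom forward/backward pair
        return GEFunction.apply(self.fc(x))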

Thanks!

Upvotes: 1

Views: 1365

Answers (1)

Mr_U4913

Reputation: 1354

Two things I noticed

  1. If your input x is 10x30 (10 examples, 30 features) and the number of output nodes is 10, then the parameter matrix is 30x10, and the expected output matrix is 10x10 (10 examples, 10 output nodes); see the quick shape check after the code below.

  2. ge means greater than or equal to. As the code indicates, it computes x >= 0 element-wise. We can use ReLU instead:

class LinearGE(nn.Module):
    def __init__(self, features_in, features_out):
        super().__init__()
        self.fc = nn.Linear(features_in, features_out)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.fc(x))
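
As a quick shape check for point 1 (a small sketch; note that nn.Linear stores its weight as out_features x in_features, i.e. the transpose of the 30x10 matrix in the usual math convention):

m = LinearGE(30, 10)
out = m(torch.randn(10, 30))
print(m.fc.weight.shape)  # torch.Size([10, 30])
print(out.shape)          # torch.Size([10, 10])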

or torch.max with a zero tensor (note that torch.max(self.fc(x), 0)[0] would instead take the maximum along dimension 0, which is not what we want here):

torch.max(self.fc(x), torch.zeros_like(self.fc(x)))
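
Equivalently, torch.clamp gives the same elementwise max(x, 0) as ReLU:

torch.clamp(self.fc(x), min=0)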

Upvotes: 1
