user023049

Reputation: 11

In PyTorch, how do I compute the gradient of a matrix multiplication with respect to the hidden state inside the forward pass?

Here is a simplified version of the model I am working on:

import torch
import torch.nn as nn

class InferContextModel(nn.Module):
    def __init__(self, input_size, context_size, output_size):
        super().__init__()
        self.context_size = context_size
        self.embedding_size = input_size
        self.output_size = output_size

        # Step size for the gradient update of the hidden state.
        self.alpha = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
        # Predicts the current input from the hidden state.
        self.linear_layer = nn.Linear(self.context_size, self.embedding_size)
        # Maps the hidden state to the output.
        self.output_layer = nn.Linear(self.context_size, self.output_size)
        self.previous_input = None

    def init_hidden(self, batch_size, device):
        return torch.zeros(batch_size, self.context_size, device=device)

    def recurrence(self, input, hidden):
        # Re-leaf the hidden state so autograd can differentiate with
        # respect to it without tying it to the rest of the graph.
        hidden = hidden.detach().requires_grad_(True)

        if self.previous_input is not None:
            # The matrix multiplication in question: predict the current
            # input from the hidden state and score the prediction.
            prediction = self.linear_layer(hidden)
            loss = 0.5 * torch.sum((prediction - input) ** 2)
            hidden_grad = torch.autograd.grad(loss, hidden, retain_graph=False)[0]

            # Update the hidden state without tracking gradients.
            with torch.no_grad():
                hidden = hidden - self.alpha * hidden_grad

        self.previous_input = input.detach()
        return self.output_layer(hidden), hidden

    def forward(self, input, context=None, num_steps=1):
        seq_len = input.size(0)
        batch_size = input.size(1)
        outputs = torch.empty(
            seq_len,
            batch_size,
            self.output_size,
            device=input.device,
        )
        
        if context is None:
            context = self.init_hidden(batch_size, input.device)
        
        for i in range(seq_len):
            output = None
            for _ in range(num_steps):
                output, context = self.recurrence(input[i], context)
            outputs[i] = output
        
        return outputs, context

Obviously this simplified model doesn't make much sense on its own, but I don't think the complexity that makes the real model work is needed to reproduce the bug I can't get past. When I run the model I get the following error:

Traceback (most recent call last):
  File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/project5_main.py", line 327, in <module>
    losses = estimate_loss(model, eval_iters, train_data, val_data,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/utils.py", line 44, in estimate_loss
    logits, loss = model(X.T, Y)
                   ^^^^^^^^^^^^^
  File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/models/Context_RNN.py", line 195, in forward
    logits, context = self.rnn(x)
                      ^^^^^^^^^^^
  File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/models/Context_RNN.py", line 151, in forward
    output, context = self.recurrence(input[i], context)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/models/Context_RNN.py", line 116, in recurrence
    context_grad = torch.autograd.grad(loss, context, retain_graph=False)[0]
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/autograd/__init__.py", line 436, in grad
    result = _engine_run_backward(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
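
If it matters, estimate_loss is wrapped in @torch.no_grad() (that is the decorate_context frame in the traceback above), so the failure can be reproduced without any of my training code. A minimal trigger, with shapes invented for this simplified model:

model = InferContextModel(input_size=16, context_size=32, output_size=16)
x = torch.randn(4, 2, 16)  # (seq_len, batch, input_size)

outputs, context = model(x)      # gradients enabled: runs fine
with torch.no_grad():
    outputs, context = model(x)  # raises the RuntimeError above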

What I want to do should be relatively simple: I just want the gradient of the loss with respect to the hidden state, without the hidden state being treated as one of the model's trainable parameters.
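
For comparison, the standalone pattern I am trying to reproduce inside forward works fine on its own (a minimal sketch with made-up shapes):

import torch

# h stands in for the hidden state: it is not a parameter, but I still
# want the gradient of a matmul-based loss with respect to it.
W = torch.randn(32, 16)
h = torch.zeros(2, 32, requires_grad=True)
x = torch.randn(2, 16)

loss = 0.5 * torch.sum((h @ W - x) ** 2)
h_grad = torch.autograd.grad(loss, h)[0]

with torch.no_grad():
    h = h - 0.1 * h_grad  # update the hidden state, not the weights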

I have made multiple attempts to solve this with several state-of-the-art LLMs and have not come up with a working solution. Any suggestions?

Upvotes: 0

Views: 35

Answers (0)
