Reputation: 11
Here is a simplified version of the model I am working on:
import torch
import torch.nn as nn

class InferContextModel(nn.Module):
    def __init__(self, input_size, context_size, output_size):
        super().__init__()
        self.context_size = context_size
        self.embedding_size = input_size
        self.output_size = output_size
        # Step size for the gradient-based context update
        self.alpha = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
        # Predicts the current input from the previous input
        self.linear_layer = nn.Linear(self.embedding_size, self.embedding_size)
        # Maps the context/hidden state to an output
        self.output_layer = nn.Linear(self.context_size, self.output_size)
        self.previous_input = None

    def init_hidden(self, batch_size, device):
        return torch.zeros(batch_size, self.context_size, device=device)

    def recurrence(self, input, hidden):
        hidden.requires_grad_(True)
        if self.previous_input is not None:
            prev_prediction = self.linear_layer(self.previous_input)
            # NOTE: as simplified here, `loss` does not actually depend on
            # `hidden`; in the full model the prediction presumably does.
            loss = 0.5 * torch.sum((prev_prediction - input) ** 2)
            hidden_grad = torch.autograd.grad(loss, hidden, retain_graph=False)[0]
            # Update the context without tracking the update in autograd
            with torch.no_grad():
                hidden = hidden - self.alpha * hidden_grad.detach()
            hidden = hidden.detach()
        self.previous_input = input.detach()
        return self.output_layer(hidden), hidden

    def forward(self, input, context=None, num_steps=1):
        batch_size = input.shape[1]
        seq_len = input.size(0)
        outputs = torch.empty(
            seq_len,
            batch_size,
            self.output_size,
            device=input.device,
        )
        if context is None:
            context = self.init_hidden(batch_size, input.device)
        for i in range(seq_len):
            output = None
            for _ in range(num_steps):
                output, context = self.recurrence(input[i], context)
            outputs[i] = output
        return outputs, context
Obviously this model doesn't make much sense on its own, but I don't think the additional complexity that makes the full model work is needed to reproduce the bug I can't get past. When I run the model I get this error:
Traceback (most recent call last):
File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/project5_main.py", line 327, in <module>
losses = estimate_loss(model, eval_iters, train_data, val_data,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/utils.py", line 44, in estimate_loss
logits, loss = model(X.T, Y)
^^^^^^^^^^^^^
File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/models/Context_RNN.py", line 195, in forward
logits, context = self.rnn(x)
^^^^^^^^^^^
File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/models/Context_RNN.py", line 151, in forward
output, context = self.recurrence(input[i], context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/geode2/home/u040/joshnunl/BigRed200/projects_with_transformers/models/Context_RNN.py", line 116, in recurrence
context_grad = torch.autograd.grad(loss, context, retain_graph=False)[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/autograd/__init__.py", line 436, in grad
result = _engine_run_backward(
^^^^^^^^^^^^^^^^^^^^^
File "/N/u/joshnunl/BigRed200/.local/lib/python3.11/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
What I want to do should be relatively simple: I just want to compute the gradient of a loss with respect to the hidden state, without the hidden state being treated as a model parameter (it should never appear in model.parameters() or be touched by the optimizer). Any suggestions?
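For concreteness, here is a minimal standalone sketch of the pattern I am after; the layer, shapes, and step size are placeholders, not my real model:

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)       # stands in for the prediction layers
hidden = torch.zeros(2, 4)    # hidden state: a plain tensor, not an nn.Parameter
target = torch.randn(2, 4)    # stands in for the current input

hidden.requires_grad_(True)   # make the hidden state a differentiable leaf
prediction = layer(hidden)
loss = 0.5 * torch.sum((prediction - target) ** 2)

# torch.autograd.grad computes the gradient only for the tensors passed
# to it; it does not populate .grad on the layer's parameters.
(hidden_grad,) = torch.autograd.grad(loss, hidden)

with torch.no_grad():
    hidden = hidden - 0.1 * hidden_grad   # manual update, outside autograd
hidden = hidden.detach()

That snippet runs fine in isolation, which is why I am confused about the error above.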
I have made multiple attempts to solve this with several state-of-the-art LLMs and have not come up with a working solution.
Upvotes: 0
Views: 35