Reputation: 1
I'm working with Mistral Inference on Windows using PyTorch and xformers. To fix a reshape error in the attention layers, I implemented a monkey patch that replaces the fixed tensor reshaping calls in the attention modules with a dynamic alternative. I tried two approaches:
A recursive patch on all modules that have the attributes n_heads and head_dim to override their forward methods.
A specific patch targeting the Attention class in transformer_layers.
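For reference, the recursive patch follows the general pattern below. This is a minimal, framework-free sketch: `DummyAttention`, `DummyModel`, and the wrapper are illustrative stand-ins, not the actual Mistral Inference classes; only the attribute-based recursive patching pattern mirrors what the question describes.

```python
# Sketch of attribute-based recursive monkey patching.
# DummyAttention stands in for the real attention module.

class DummyAttention:
    def __init__(self):
        self.n_heads = 32
        self.head_dim = 128

    def forward(self, x):
        return ("original", x)

class DummyModel:
    def __init__(self):
        self.attn = DummyAttention()

    def modules(self):
        # Stand-in for torch.nn.Module.modules(): yield self and children.
        yield self
        yield self.attn

def patch_attention_modules(model):
    """Override forward() on every module that looks like attention."""
    patched = 0
    for module in model.modules():
        if hasattr(module, "n_heads") and hasattr(module, "head_dim"):
            original_forward = module.forward

            def patched_forward(x, _orig=original_forward):
                # The real patch would reshape dynamically here, e.g.
                # x.reshape(-1, n_heads * head_dim) instead of a fixed view.
                return ("patched", _orig(x)[1])

            module.forward = patched_forward
            patched += 1
    return patched

model = DummyModel()
print(patch_attention_modules(model))  # number of modules patched
print(model.attn.forward("t"))         # now routed through the wrapper
```

Note that assigning to `module.forward` sets an instance attribute that shadows the bound method, which is why the wrapper is picked up for subsequent calls on that instance only.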
My debug logs confirm that the patched methods are being called and that the computed token count is correct. However, during generation I still get an error stating that a tensor cannot be reshaped as expected. The error message is:
RuntimeError: shape '[1, 4096]' is invalid for input of size 753664
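The size in the error message is itself informative: 753664 is evenly divisible by 4096, so the flattened tensor holds 184 rows of hidden size 4096, and no shape with a leading dimension of 1 can hold it:

```python
total = 753664          # input size reported in the RuntimeError
hidden = 4096           # trailing dimension in the requested shape

print(total % hidden)   # 0  -> divisible, so a (-1, 4096) reshape would succeed
print(total // hidden)  # 184 -> the token/batch dimension the data implies
# A fixed shape of (1, 4096) holds only 4096 elements, hence the error.
```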
It appears that somewhere in the code there is still a fixed reshape call that my patch does not reach. Is there a way to force a complete override of this reshape operation without touching the installed source files, or is editing the source directly the only option?
I tried two approaches to replace the fixed reshape operation in the attention layers:
I applied a recursive patch to all modules that have the attributes n_heads and head_dim, overriding their forward methods to use a dynamic reshaping with .reshape(-1, computed_size) instead of a fixed .view().
I implemented a specific patch targeting the Attention class in the transformer layers to override its forward method in the same way.
I expected that, after applying these patches, the output tensor would be reshaped dynamically to the correct dimensions (for example, [34, 4096] in my case) instead of being forced into [1, 4096]. However, although the debug logs confirm that the computed token count is correct, the error still occurs, which indicates that a fixed .view() call is still being executed somewhere in the code.
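The difference between the fixed and the dynamic reshape can be reproduced in isolation. The demo below uses NumPy, whose reshape semantics match PyTorch's for this case (NumPy raises ValueError where torch raises RuntimeError); the element count is taken from the error message above:

```python
import numpy as np

x = np.zeros(753664, dtype=np.float32)  # same element count as in the error

# Dynamic reshape: infer the leading dimension from the data.
dynamic = x.reshape(-1, 4096)
print(dynamic.shape)  # (184, 4096)

# Fixed reshape: fails the same way as the fixed .view() in the traceback.
try:
    x.reshape(1, 4096)
except ValueError as e:
    print("reshape failed:", e)
```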
Upvotes: 0
Views: 28