Shaikh Shehbaz

Reputation: 101

PyTorch UserWarning: Failed to initialize NumPy: _ARRAY_API not found and BertModel weight initialization issue

I am working with PyTorch and the Hugging Face Transformers library to fine-tune a BERT model (UFNLP/gatortron-base) for a downstream task.

I received a warning related to NumPy initialization:

C:\Users\user\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\storage.py:321: UserWarning: Failed to initialize NumPy: _ARRAY_API not found (Triggered internally at ..\torch\csrc\utils\tensor_numpy.cpp:84.)
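This message usually means that an extension compiled against NumPy 1.x (here, torch) is being loaded alongside NumPy 2.x, whose C API changed. That is why the answers below either pin NumPy below 2 or upgrade torch. A quick way to see which combination is installed is sketched below. It is a stdlib-only sketch, and the helper names (`installed_version`, `major_version`, `numpy_is_2x`) are hypothetical, not part of torch or NumPy.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def major_version(version):
    """Leading integer of a version string, e.g. '1.26.4' -> 1, '2.1.0+cpu' -> 2."""
    return int(version.split(".")[0])


def numpy_is_2x(version):
    """True when a NumPy version string belongs to the 2.x (or later) series."""
    return version is not None and major_version(version) >= 2


if __name__ == "__main__":
    numpy_version = installed_version("numpy")
    print("numpy:", numpy_version)
    print("torch:", installed_version("torch"))
    if numpy_is_2x(numpy_version):
        print("NumPy 2.x is installed; a torch build compiled against NumPy 1.x "
              "may report '_ARRAY_API not found'.")
```

Reading versions through `importlib.metadata` avoids importing NumPy or torch, so the check runs even when importing them is what produces the warning.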

My code:

import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('UFNLP/gatortron-base')
model = BertModel.from_pretrained('UFNLP/gatortron-base')

model.eval()  # inference mode: disables dropout

def prepare_input(text):
    # Tokenize, add special tokens, and truncate to BERT's 512-token limit
    tokens = tokenizer(text, return_tensors='pt', add_special_tokens=True, max_length=512, truncation=True)
    return tokens['input_ids'], tokens['attention_mask']

def get_response(input_ids, attention_mask):
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        if 'logits' in outputs:
            predictions = torch.argmax(outputs['logits'], dim=-1)
        else:
            # BertModel has no prediction head: outputs[0] is last_hidden_state with
            # shape (batch, seq_len, hidden_size), so this argmax returns hidden-unit
            # indices rather than vocabulary token IDs, and decoding them is not meaningful
            predictions = torch.argmax(outputs[0], dim=-1)

        # predictions = torch.argmax(outputs.logits, dim=-1)
        return tokenizer.decode(predictions[0], skip_special_tokens=True)

input_text = "Hello, how are you?"
input_ids, attention_mask = prepare_input(input_text)
response = get_response(input_ids, attention_mask)
print("Response from the model:", response)

Upvotes: 10

Views: 17246

Answers (3)

Roscoe - ROSCODE

Reputation: 39

I had the same problem and solved it by upgrading torch from 1.x.x to 2.x.x.

Upvotes: 0

user3380108

Reputation: 17

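pip uninstall numpy
pip install "numpy<2"

Quote the specifier: unquoted, the shell treats < as an input redirection instead of passing it to pip.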
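The same pin can also be applied from Python. The sketch below is not part of the original answer, and its helper names (`pip_command`, `run_pip`) are hypothetical. Passing the arguments as a list with no shell means numpy<2 reaches pip literally, and running `python -m pip` through `sys.executable` targets the interpreter that will later import torch. On Windows, a bare `pip` on PATH can belong to a different Python install.

```python
import subprocess
import sys


def pip_command(*pip_args):
    """Build a pip invocation bound to the interpreter running this script."""
    return [sys.executable, "-m", "pip", *pip_args]


def run_pip(*pip_args):
    """Run pip without a shell, so a specifier like 'numpy<2' is passed literally."""
    # check=True raises CalledProcessError if pip exits with a non-zero status
    return subprocess.run(pip_command(*pip_args), check=True)


# Usage (performs a real install, so it is left commented out):
# run_pip("install", "numpy<2")
```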

Upvotes: -2

Farhan Hai Khan

Reputation: 828

pip install --force-reinstall -v "numpy==1.25.2"

Fixed the issue for me.

This came from the following GitHub discussion: https://github.com/stitionai/devika/issues/606

All thanks to @HOBE for the comment in that thread.
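After a --force-reinstall, any Python session that had already imported NumPy keeps the old module in memory. The check is therefore more reliable in a fresh interpreter. The sketch below asks a new process (the same `sys.executable`) which version it actually imports. The helper `version_in_fresh_interpreter` is hypothetical and assumes a plain, trusted module name. One caveat: the question's path shows Python 3.12, and per NumPy's release notes 1.26.0 is the first 1.x release with Python 3.12 support. On that interpreter, a 1.26.x pin is likely the one that installs cleanly.

```python
import subprocess
import sys


def version_in_fresh_interpreter(module_name):
    """Import module_name in a new Python process and return its __version__.

    Returns None if the import fails or the module defines no __version__.
    module_name is assumed to be a plain, trusted module name such as "numpy".
    """
    code = f"import {module_name}; print(getattr({module_name}, '__version__', ''))"
    result = subprocess.run(
        [sys.executable, "-c", code], capture_output=True, text=True
    )
    if result.returncode != 0:
        return None
    return result.stdout.strip() or None


# Usage: print(version_in_fresh_interpreter("numpy"))
```

Warnings such as the _ARRAY_API message go to stderr, so they do not end up in the captured version string.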

Upvotes: 13
