DevLoverUmar

Reputation: 14011

Error loading PyTorch model checkpoint: _pickle.UnpicklingError: invalid load key, '\x1f'

I'm trying to load the weights of a PyTorch model, but I get this error: _pickle.UnpicklingError: invalid load key, '\x1f'.

Here is the weight-loading code:

import os
import torch
import numpy as np
# from data_loader import VideoDataset
import timm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device being used:', device)

mname = os.path.join('./CDF2_0.pth')
checkpoints = torch.load(mname, map_location=device)
print("Checkpoint loaded successfully.")
model = timm.create_model('legacy_xception', pretrained=True, num_classes=2).to(device)
model.load_state_dict(checkpoints['state_dict'])
model.eval()

I have tried different PyTorch versions. I have also tried to inspect the weights by changing the extension to .zip and opening the file with an archive manager, but I couldn't fix the issue. Here is a public link to the .pth weights file I'm trying to load. Any help is highly appreciated, as I have around 40 trained models that took about a month to train!

Upvotes: 3

Views: 401

Answers (1)

Marcin Wrochna

Reputation: 519

The error is typical when trying to open a gzip file as if it were a pickle or PyTorch file, because gzip files start with the byte 1f. But this is not a valid gzip file: it looks like a corrupted PyTorch file.

Indeed, looking at hexdump -C file.pt | head (shown below), most of it looks like a PyTorch file, which should be a ZIP archive (not gzip) containing a Python pickle file named data.pkl. But the first few bytes are wrong: instead of starting like a ZIP file (bytes 50 4b, ASCII PK), it starts like a gzip file (1f 8b 08 08). In fact, it is exactly as if the first 31 bytes had been replaced with a valid, empty gzip file (with a timestamp, ff 35 29 67, pointing to November 4, 2024 9:00:47 PM GMT).

Your file:

00000000  1f 8b 08 08 ff 35 29 67  02 ff 43 44 46 32 5f 30  |.....5)g..CDF2_0|
00000010  2e 70 74 68 00 03 00 00  00 00 00 00 00 00 00 44  |.pth...........D|
00000020  46 32 5f 30 2f 64 61 74  61 2e 70 6b 6c 46 42 0f  |F2_0/data.pklFB.|
00000030  00 5a 5a 5a 5a 5a 5a 5a  5a 5a 5a 5a 5a 5a 5a 5a  |.ZZZZZZZZZZZZZZZ|
00000040  80 02 7d 71 00 28 58 08  00 00 00 62 65 73 74 5f  |..}q.(X....best_|
00000050  61 63 63 71 01 63 6e 75  6d 70 79 2e 63 6f 72 65  |accq.cnumpy.core|
...

(Inspecting the pickle data, we can see a dictionary {"best_acc": ..., "state_dict": ...} with the typical contents of a PyTorch model checkpoint.)
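Because the symptom is visible in the very first bytes, a whole folder of checkpoints can be screened without a hex editor. The following is a minimal sketch and not part of the original analysis: the helper name sniff_container and its return labels are made up for illustration, and it looks only at the leading magic bytes (PK\x03\x04 for a ZIP local file header, 1f 8b for gzip).

```python
import sys
from pathlib import Path

ZIP_MAGIC = b"PK\x03\x04"   # start of a ZIP local file header
GZIP_MAGIC = b"\x1f\x8b"    # start of a gzip member


def sniff_container(path: Path) -> str:
    """Classify a file by its leading magic bytes (hypothetical helper)."""
    with open(path, "rb") as f:
        head = f.read(4)
    if head.startswith(ZIP_MAGIC):
        return "zip"
    if head.startswith(GZIP_MAGIC):
        return "gzip"
    return "unknown"


if __name__ == "__main__":
    # Usage: python sniff.py CDF2_0.pth other_checkpoint.pth ...
    for name in sys.argv[1:]:
        print(f"{name}: {sniff_container(Path(name))}")
```

A .pth file reported as gzip that still fails to load is a candidate for the same kind of damage as the file discussed here.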

A valid zipped pickle produced by torch.save({"best_acc": np.array([1]), "state_dict": ...}, "CDF2_0.pth"):

00000000  50 4b 03 04 00 00 08 08  00 00 00 00 00 00 00 00  |PK..............|
00000010  00 00 00 00 00 00 00 00  00 00 0f 00 13 00 43 44  |..............CD|
00000020  46 32 5f 30 2f 64 61 74  61 2e 70 6b 6c 46 42 0f  |F2_0/data.pklFB.|
00000030  00 5a 5a 5a 5a 5a 5a 5a  5a 5a 5a 5a 5a 5a 5a 5a  |.ZZZZZZZZZZZZZZZ|
00000040  80 02 7d 71 00 28 58 08  00 00 00 62 65 73 74 5f  |..}q.(X....best_|
00000050  61 63 63 71 01 63 6e 75  6d 70 79 2e 63 6f 72 65  |accq.cnumpy.core|
...

A gzip file containing an empty file with the same name and timestamp (created with gzip --best) is 31 bytes long and matches your file's prefix, except for the 'Operating System' byte (ff in your file, 03 here):

00000000  1f 8b 08 08 ff 35 29 67  02 03 43 44 46 32 5f 30  |.....5)g..CDF2_0|
00000010  2e 70 74 68 00 03 00 00  00 00 00 00 00 00 00     |.pth...........|
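The fields of that 31-byte prefix can also be decoded programmatically. The sketch below is an illustration only: it copies the bytes from the dump of the broken file above and reads them according to the gzip header layout of RFC 1952. The function name parse_gzip_header and the dictionary keys are hypothetical, and optional header parts other than FEXTRA and FNAME are not handled.

```python
import struct
from datetime import datetime, timezone

# First 31 bytes of the broken file, copied from the hexdump above.
prefix = bytes.fromhex(
    "1f 8b 08 08 ff 35 29 67 02 ff 43 44 46 32 5f 30 "
    "2e 70 74 68 00 03 00 00 00 00 00 00 00 00 00"
)

FEXTRA = 0x04  # flag bit: an extra field follows the fixed header
FNAME = 0x08   # flag bit: a zero-terminated original file name follows


def parse_gzip_header(data: bytes) -> dict:
    """Decode the fixed gzip header and the optional file name (illustrative)."""
    magic, method, flags, mtime, xfl, os_id = struct.unpack("<2sBBIBB", data[:10])
    if magic != b"\x1f\x8b":
        raise ValueError("not a gzip header")
    pos = 10
    if flags & FEXTRA:
        # Skip the extra field: 2-byte little-endian length, then the payload.
        (xlen,) = struct.unpack_from("<H", data, pos)
        pos += 2 + xlen
    name = None
    if flags & FNAME:
        end = data.index(b"\x00", pos)
        name = data[pos:end].decode("latin-1")
    return {
        "method": method,  # 8 means deflate
        "mtime": datetime.fromtimestamp(mtime, tz=timezone.utc),
        "xfl": xfl,        # 2 means maximum compression
        "os": os_id,       # 3 means Unix, 255 means unknown
        "name": name,
    }


header = parse_gzip_header(prefix)
print(len(prefix), header)
```

The decoded modification time is the timestamp quoted earlier in this answer, and the stored name is the checkpoint's own file name, which is what makes the prefix look like an empty gzip of the same file.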

Edit: Here's a script that might fix such files in general:

#!/usr/bin/env python3
import os
import sys
from pathlib import Path
from shutil import copy2
from tempfile import TemporaryDirectory

import numpy as np
import torch

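# Compare the two files in small chunks; the damaged prefix is short.
CHUNK_SIZE = 4

def main(orig_path: Path) -> None:
    # Work on a copy so the broken original is never modified.
    fixed_path = orig_path.with_suffix(".fixed.pth")
    copy2(orig_path, fixed_path)

    with TemporaryDirectory() as temp_dir:
        # Save a small reference checkpoint under the same file name: as the
        # dumps above show, the archive's internal path (CDF2_0/data.pkl) is
        # derived from it, so the reference header matches the lost one.
        temp_path = Path(temp_dir) / orig_path.name
        torch.save({"best_acc": np.array([1]), "state_dict": {}}, temp_path)

        with open(temp_path, "rb") as f_temp, open(fixed_path, "rb+") as f_fixed:
            while True:
                content = f_fixed.read(CHUNK_SIZE)
                replacement = f_temp.read(CHUNK_SIZE)
                # Stop at the first chunk that already matches, or if either
                # file ends before a full chunk could be read.
                if content == replacement or min(len(content), len(replacement)) < CHUNK_SIZE:
                    break
                print(f"Replacing {content!r} with {replacement!r}")
                # Step back over the chunk just read and overwrite it in place.
                f_fixed.seek(-CHUNK_SIZE, os.SEEK_CUR)
                f_fixed.write(replacement)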


if __name__ == "__main__":
    assert len(sys.argv) == 2, "Expected exactly one argument (the path to the broken .pth file)."
    main(Path(sys.argv[1]))
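Before trusting a repaired file, it may be worth checking that it reads as a ZIP archive again. The sketch below is one possible check using only the standard library, not something the script above does: the function name checkpoint_archive_ok is hypothetical, and the test for an entry ending in /data.pkl is an assumption based on the layout seen in the dumps above.

```python
import zipfile
from pathlib import Path


def checkpoint_archive_ok(path: Path) -> bool:
    """Return True if the file reads as a ZIP archive with a data.pkl entry."""
    try:
        with zipfile.ZipFile(path) as zf:
            has_pickle = any(n.endswith("/data.pkl") for n in zf.namelist())
            # testzip() re-reads every member, including its local header,
            # and returns the name of the first bad one (None if all are fine).
            return has_pickle and zf.testzip() is None
    except zipfile.BadZipFile:
        return False
```

Note that zipfile.is_zipfile alone would not be enough here, since it only looks for the end-of-archive record and the damage is at the start of the file. The final check is still torch.load on the repaired file. On recent PyTorch releases where torch.load defaults to weights_only=True, a checkpoint holding NumPy objects such as this one may additionally need weights_only=False, which should only be used for files you trust.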

Upvotes: 6
