Reputation: 11
I'm working on a Billion-scale semi-supervised learning project for image classification. The general process is as follows:
The first step is to train an initial teacher model A using labeled data;
The second step is to use teacher model A to predict on the unlabeled data, rank the images within each class by score, and select the top K to build a new training set, i.e. a pseudo-labeled dataset;
There are more steps after this, but the error below is raised while performing the second step. The select_top_k function takes the per-label lists loaded from the JSON file, sorts each list in descending order of score (the second element of each item), and selects the first k elements; the program fails at that point.
The specific error information is as follows:
Load Model Accuracy: 75.64 Load Model end epoch: 100
class name: apple
image data count: 210
class name: aquarium_fish
image data count: 203
class name: baby
image data count: 151
class name: bear
image data count: 189
...
class name: wolf
image data count: 212
class name: woman
image data count: 199
class name: worm
image data count: 173
Saving.. sampling_dict
label: 39
each label item count: 18779
label: 82
each label item count: 18779
label: 20
each label item count: 18626
label: 0
each label item count: 18340
label: 9
each label item count: 13232
label: 6
each label item count: 5547
label: 3
each label item count: 17344
label: 16
each label item count: 4228
label: 2
each label item count: 17960
label: 24
each label item count: 1880
label: 10
each label item count: 1581
label: 5
each label item count: 14524
label: 4
each label item count: 16767
label: 61
each label item count: 799
Traceback (most recent call last):
File "make_sample_data_1.py", line 155, in <module>
main(args)
File "make_sample_data_1.py", line 148, in main
select_top_k(args.k)
File "make_sample_data_1.py", line 128, in select_top_k
sampled_image_dict["all"].append([all_items[index][0], int(key)])
IndexError: index 799 is out of bounds for axis 0 with size 799
Related code snippets:
def select_top_k(k=1000):
    sampled_image_dict = {}
    sampled_image_dict["all"] = []
    with codecs.open("./sampling_dict.json", "r", encoding="utf-8", errors="ignore") as f:
        load_data = json.load(f)

    for key in load_data.keys():
        print("label: ", key)
        all_items = load_data[key]
        all_items.sort(key=lambda x: x[1], reverse=True)
        all_items = np.array(all_items)
        print("each label item count: ", len(all_items))
        for index in range(0, k):
            sampled_image_dict["all"].append([all_items[index][0], int(key)])  # <-- error raised here

    print("Saving.. selected image json")
    j = json.dumps(sampled_image_dict)
    with open("selected_image.json", "w") as f:
        f.write(j)
[Note] The error is raised at the `sampled_image_dict["all"].append(...)` line inside the inner loop.
Upvotes: 1
Views: 173
Reputation: 771
In your code you loop from 0 to k, and k defaults to 1000. As soon as a label has fewer than k records, the loop indexes past the end of all_items and fails (in your log, label 61 has only 799 items, so index 799 is out of bounds).
You could either set k to a lower number like 500, or change your code as below so that the loop never runs past the length of all_items.
Another option is to make k a percentage of all_items instead of always taking a fixed top 1000.
Code below makes sure the loop never exceeds the length of your all_items:
def select_top_k(k=1000):
    sampled_image_dict = {}
    sampled_image_dict["all"] = []
    with codecs.open("./sampling_dict.json", "r", encoding="utf-8", errors="ignore") as f:
        load_data = json.load(f)

    for key in load_data.keys():
        print("label: ", key)
        all_items = load_data[key]
        all_items.sort(key=lambda x: x[1], reverse=True)
        all_items = np.array(all_items)
        print("each label item count: ", len(all_items))
        # If a label has fewer than k items, only take what is there.
        # Use a per-label limit instead of reassigning k itself, otherwise one
        # small label would permanently shrink k for every label after it.
        limit = min(k, len(all_items))
        for index in range(0, limit):
            sampled_image_dict["all"].append([all_items[index][0], int(key)])

    print("Saving.. selected image json")
    j = json.dumps(sampled_image_dict)
    with open("selected_image.json", "w") as f:
        f.write(j)
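The percentage option mentioned above could be sketched like this. This is a minimal, self-contained illustration, not a drop-in replacement for the original script: the `fraction` parameter, the `select_top_fraction` name, and the toy `load_data` dict are all made up here for demonstration.

def select_top_fraction(load_data, fraction=0.1):
    # load_data stands in for the dict loaded from sampling_dict.json:
    # {label: [[image_path, score], ...]}
    selected = {"all": []}
    for key, items in load_data.items():
        # Sort by score, descending, same as the original select_top_k.
        items = sorted(items, key=lambda x: x[1], reverse=True)
        # Take a fraction of each label, but at least one item so that
        # very small labels are not dropped entirely.
        limit = max(1, int(len(items) * fraction))
        for path, _score in items[:limit]:
            selected["all"].append([path, int(key)])
    return selected

# Toy example: label "0" has 3 items, label "1" has 1.
load_data = {
    "0": [["a.jpg", 0.9], ["b.jpg", 0.8], ["c.jpg", 0.1]],
    "1": [["d.jpg", 0.7]],
}
result = select_top_fraction(load_data, fraction=0.5)
# result["all"] is [["a.jpg", 0], ["d.jpg", 1]]

Note that slicing with `items[:limit]` can never go out of bounds, which is another way to avoid the IndexError even in the fixed-k version.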
Upvotes: 1