Reputation: 23
Here is my code. The packages imported are not shown. I am trying to feed the CIFAR-10 test data into alexnet. The dictionary at the end needs to be sorted so I can find the most common classification. Please help, I have tried everything!
............................................................................................................................................................................................................................................................................................................
alexnet = models.alexnet(pretrained=True)
transform = transforms.Compose([ #[1]
transforms.Resize(256), #[2]
transforms.CenterCrop(224), #[3]
transforms.ToTensor(), #[4]
transforms.Normalize( #[5]
mean=[0.485, 0.456, 0.406], #[6]
std=[0.229, 0.224, 0.225] #[7]
)])
# Getting the CIFAR-10 dataset
dataset = CIFAR10(root='data/', download=True, transform=transform)
test_dataset = CIFAR10(root='data/', train=False, transform=transform)
classes = dataset.classes
#print(classes)
torch.manual_seed(43)
val_size = 10000
train_size = len(dataset) - val_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])
#print(len(train_ds), len(val_ds))
batch_size=100
train_loader = DataLoader(train_ds, batch_size, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size, num_workers=8, pin_memory=True)
with open("/home/shaan/Computer Science/CS4442/Ass4/imagenet_classes.txt") as f:
classes = eval(f.read())
holder = []
dic = {}
current = ''
#data_iter = iter(test_loader)
#images,labels = data_iter.next()
#alexnet.eval()
with torch.no_grad():
for data in test_loader:
images, labels = data
out = alexnet(images)
#print(out.shape)
for j in range(0,batch_size):
sorted, indices = torch.sort(out,descending=True)
percentage = F.softmax(out,dim=1)[j]*100
results = [(classes[i.item()],percentage[i].item()) for i in indices[j][:5]]
holder.append(results[0][0])
holder.sort()
for z in holder:
if current != z:
count = 1
dic[z] = count
current = z
else:
count = count + 1
dic[z] = count
current = z
This is where im getting the error:
for w in sorted(dic, key=dic.get, reverse=True):
print(w, dic[w])
Upvotes: 1
Views: 508
Reputation: 2440
This line is the problem
sorted, indices = torch.sort(out,descending=True)
You created a variable named sorted
, which is exactly the same name as sorted
function you call when it error.
Just change this to something else like
sorted_out, indices = torch.sort(out,descending=True)
Upvotes: 1