Dario Federici

Reputation: 1258

PyTorch: passing architecture type with argparse

I am using PyTorch and pass the architecture name on the command line with argparse:

parser.add_argument('-arch', action='store',
                    dest='arch',
                    default='vgg16')

I then use the architecture name to look up the model constructor:

model = models.__dict__['{!r}'.format(results.arch)](pretrained=True)

I get the following error:

model = models.__dict__['{!r}'.format(results.arch)](pretrained=True)
KeyError: "'vgg16'"

What am I doing wrong?

Upvotes: 1

Views: 133

Answers (2)

Dario Federici

Reputation: 1258

model = models.__dict__[results.arch](pretrained=True)

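This is the solution. The {!r} conversion in the original lookup applies repr() to the string, which wraps it in quotes, so the key becomes "'vgg16'" (quote characters included) instead of vgg16. Indexing with results.arch directly uses the plain name.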
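To make the difference visible, here is a minimal sketch that runs without torchvision. The registry dict is a hypothetical stand-in for models.__dict__ (an assumption for illustration, not the real torchvision contents); the argparse setup mirrors the question.

```python
import argparse

# Hypothetical stand-in for models.__dict__: maps a name to a constructor.
registry = {
    "vgg16": lambda pretrained=False: "vgg16(pretrained={})".format(pretrained),
}

parser = argparse.ArgumentParser()
parser.add_argument('-arch', action='store', dest='arch', default='vgg16')
# Parse an empty argument list so the default is used instead of sys.argv.
results = parser.parse_args([])

# {!r} calls repr(), which adds quote characters to the string itself.
quoted_key = '{!r}'.format(results.arch)
# {} calls str(), which leaves the name unchanged.
plain_key = '{}'.format(results.arch)

print(len(results.arch), len(quoted_key))  # the quoted key is two characters longer
print(quoted_key in registry)              # the quoted key is not in the dict
print(plain_key in registry)               # the plain key is

# The lookup works once the plain name is used as the key.
model = registry[results.arch](pretrained=True)
```

Since results.arch is already a string, no formatting step is needed before the lookup.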

Upvotes: 0

Shai

Reputation: 114886

You got a KeyError, which means the key you looked up is not in models.__dict__, so as far as the lookup is concerned 'vgg16' is not one of the known models.
Check which models you do have by printing:

print(models.__dict__.keys())

This shows which models you have imported and which are missing. You can then go through your imports and see where 'vgg16' got lost.
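The check above can be turned into a small validation helper. This is only a sketch: the models object below is a hypothetical stand-in module built with types.ModuleType so that the example runs without torchvision, and available_archs and build are illustrative names, not part of any library.

```python
import types

# Hypothetical stand-in for the imported `models` module.
models = types.ModuleType("models")
models.vgg16 = lambda pretrained=False: ("vgg16", pretrained)
models.resnet18 = lambda pretrained=False: ("resnet18", pretrained)


def available_archs(module):
    """Return the public callable names a module exposes, sorted."""
    return sorted(
        name for name, obj in vars(module).items()
        if callable(obj) and not name.startswith("_")
    )


def build(module, arch, **kwargs):
    """Look up `arch` by name and call it, with a readable error if it is missing."""
    if arch not in available_archs(module):
        raise ValueError(
            "unknown architecture {!r}; choose from {}".format(
                arch, available_archs(module)
            )
        )
    return getattr(module, arch)(**kwargs)


names = available_archs(models)
model = build(models, "vgg16", pretrained=True)

# A key with stray quote characters is rejected with a clear message.
try:
    build(models, "'vgg16'")
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

The error message prints the offending key with repr(), so stray quote characters in the name show up immediately.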

Upvotes: 1
