Reputation: 6207
I have a Python function call like so:
import torchvision
model = torchvision.models.resnet18(pretrained=configs.use_trained_models)
Which works fine.
If I attempt to make it dynamic:
import torchvision
model_name = 'resnet18'
model = torchvision.models[model_name](pretrained=configs.use_trained_models)
then it fails with:
TypeError: 'module' object is not subscriptable
Which makes sense, since torchvision.models is a module that exports a bunch of things, including the resnet functions:
# __init__.py for the "models" module
...
from .resnet import *
...
How can I call this function dynamically when its name is not known ahead of time and I only receive it as a string?
Upvotes: 1
Views: 656
Reputation: 68858
The new APIs since Aug 2022 are as follows:
# returns list of available model names
torchvision.models.list_models()
# returns the specified model with its default pretrained weights
torchvision.models.get_model("alexnet", weights="DEFAULT")
# returns the specified model without pretrained weights (the equivalent of the old pretrained=False)
torchvision.models.get_model("alexnet", weights=None)
# returns the specified model with specific pretrained weights (the weights enum must match the model)
torchvision.models.get_model("resnet50", weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2)
Reference: https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/
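If the model name comes from user input or a config file, it may be worth checking it against list_models() before calling get_model(). The following is only a minimal sketch: build_model and its models_module parameter are hypothetical names, not part of torchvision. The module is passed in as an argument so the sketch runs even without torchvision installed, and it assumes a torchvision version that provides both functions.

```python
def build_model(models_module, name, **config):
    """Build a model by name, failing early with a readable error.

    `models_module` is assumed to expose list_models() and get_model(),
    as torchvision.models does in the API shown above. This helper is
    hypothetical and not part of torchvision itself.
    """
    available = models_module.list_models()
    if name not in available:
        raise ValueError(
            f"Unknown model {name!r}; expected one of {sorted(available)}"
        )
    # Extra keyword arguments (for example weights=None) are forwarded as-is.
    return models_module.get_model(name, **config)


# Intended usage (assumes torchvision is installed):
#   import torchvision
#   model = build_model(torchvision.models, "resnet18", weights=None)
```

Checking the name first gives an error message that lists the valid choices, which is usually easier to act on than a failure deep inside the library.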
Upvotes: 2
Reputation: 27515
You can use the getattr function:
import torchvision
model_name = 'resnet18'
model = getattr(torchvision.models, model_name)(pretrained=configs.use_trained_models)
This is essentially the same as dot notation, just in function form: getattr accepts a string and retrieves the attribute or method with that name.
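Since getattr raises AttributeError for unknown names and will happily return non-callable attributes, a small wrapper can make the lookup safer. This is a rough sketch rather than a definitive implementation: resolve_callable is a hypothetical helper name, and the standard library math module stands in for torchvision.models so that the example runs without torchvision installed.

```python
import math  # stand-in module so the sketch runs without torchvision


def resolve_callable(module, name):
    """Return the callable called `name` in `module`, or raise ValueError."""
    # A default of None avoids an AttributeError for unknown names.
    obj = getattr(module, name, None)
    if not callable(obj):
        raise ValueError(f"{module.__name__} has no callable named {name!r}")
    return obj


sqrt = resolve_callable(math, "sqrt")  # same object as math.sqrt
print(sqrt(16.0))  # prints 4.0

# With torchvision installed, the same helper could be used as:
#   import torchvision
#   model = resolve_callable(torchvision.models, "resnet18")()
```

The callable check also rejects names that exist but are not functions or classes, such as submodules or constants.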
Upvotes: 3