Tobias Uhmann

Reputation: 3057

Where to implement pre-processing in PyTorch Lightning (e.g. tokenizing input text)

Is there a convention to implement some kind of predict() method in PyTorch Lightning that does pre-processing before performing the actual prediction using forward()?

In my case, I have a text classifier consisting of an embedding layer and a few fully connected layers. The text needs to be tokenized before being passed to the embedding layer. During training and evaluation the LightningDataModule's setup() method does the job.

Now, I'm wondering what the best practice for inference during production is. I could add a predict() method to my LightningModule where I could write the same pre-processing code as in LightningDataModule.setup(). But, of course, I do not want to duplicate the code.

In this community example project linked in the official PyTorch Lightning docs, the authors define a prepare_sample() function in the LightningModule that is used by their predict() function and is also passed to the LightningDataModule.

Is this the right way to handle pre-processing? Also, why is there no prepare_sample() or predict() in LightningModule? To me, this seems like a common use case, for example:

model = load_model('data/model.ckpt')  # load pre-trained model that analyzes user reviews
user_input = input('Your movie review > ')
predicted_rating = model.predict(user_input)  # e.g. "I liked the movie pretty much." -> 4 stars
print('Predicted rating: %s/5 stars' % predicted_rating)

Now that I think about it, predict() should also post-process the result of forward() the same way the evaluation code does, for example selecting the class with the highest output or all classes with outputs above some threshold. That is more code that should not be duplicated.
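To make the idea concrete, here is a minimal, framework-agnostic sketch of the shared-preprocessing pattern. The names prepare_sample and RatingModel are hypothetical, not part of the Lightning API; the point is only that setup(), predict(), and the evaluation code would all call the same helper instead of duplicating tokenization and post-processing.

```python
# Hypothetical sketch: one shared pre-processing helper, used both by the
# data pipeline during training and by predict() at inference time.

def prepare_sample(text, vocab):
    """Tokenize and index a raw string; reused by the DataModule's setup()
    and by the model's predict()."""
    return [vocab.get(token, 0) for token in text.lower().split()]  # 0 = <unk>

class RatingModel:
    def __init__(self, vocab):
        self.vocab = vocab

    def forward(self, indices):
        # Stand-in for the real embedding + fully connected layers:
        # produce one dummy score per star class (1..5).
        return [float(i == sum(indices) % 5) for i in range(5)]

    def predict(self, text):
        indices = prepare_sample(text, self.vocab)  # same code path as setup()
        scores = self.forward(indices)
        return scores.index(max(scores)) + 1        # post-process: argmax -> stars
```

The same argmax (or thresholding) logic can then be reused by the evaluation step, so neither the pre- nor the post-processing exists twice.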

Upvotes: 2

Views: 1823

Answers (1)

Fredrik

Reputation: 497

Why do you use a LightningModule if the code is meant for production? If the model is finished, you only need to load the saved model and define the preprocessing steps.

The repository you refer to has implemented predict() and prepare_sample() on top of the LightningModule.

In my opinion, pytorch-lightning is for training and evaluating the model, not for production. We would not want to keep the analytics and debugging when sending a model to production, so instead we create a slimmed-down version that only handles model loading, preprocessing, and prediction.
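A minimal sketch of such a slimmed-down serving wrapper, written without torch so it stays self-contained. In a real deployment the vocab and weights would come out of the saved checkpoint (e.g. via torch.load('data/model.ckpt')) and the forward pass would be the bare model; the class name and its members here are illustrative assumptions.

```python
# Hypothetical production wrapper: no Trainer, no loggers, no callbacks --
# just artifact loading, preprocessing, and prediction.

class ProductionClassifier:
    def __init__(self, vocab, weights):
        self.vocab = vocab      # token -> index, saved at training time
        self.weights = weights  # index -> score, stand-in for model weights

    def preprocess(self, text):
        """Same tokenization the training pipeline used."""
        return [self.vocab.get(token, 0) for token in text.lower().split()]

    def predict(self, text):
        indices = self.preprocess(text)
        # Stand-in for the exported model's forward pass:
        score = sum(self.weights.get(i, 0.0) for i in indices)
        return max(1, min(5, round(score)))  # clamp to a 1..5 star rating
```

This keeps the serving surface tiny while reusing exactly the pre- and post-processing that training established.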

Towards Data Science has a small code example: https://towardsdatascience.com/how-to-deploy-pytorch-lightning-models-to-production-7e887d69109f

Upvotes: 1
