Feng Chen

How to use the Python multiprocessing package in Metaflow?

I am trying to use the multiprocessing package in a Metaflow flow in which a fastText model predicts some results. Here is my code:

import pickle
import os
import boto3
import multiprocessing
from functools import partial
from multiprocessing import Manager
import time


from metaflow import batch, conda, FlowSpec, step, conda_base, Flow, Step
from util import pip_install_module
 

@conda_base(libraries={'scikit-learn': '0.23.1', 'numpy': '1.22.4', 'pandas': '1.5.1', 'fasttext': '0.9.2'}) 
class BatchInference(FlowSpec):
    pip_install_module("python-dev-tools", "2023.3.24")

    @batch(cpu=10, memory=120000)
    @step
    def start(self):
        import pandas as pd
        import numpy as np

        self.df_input = ['af', 'febrt' ,'fefv fd we' ,'fe hth dw hytht' ,' dfegrtg hg df reg']

        self.next(self.predict)



    @batch(cpu=10, memory=120000)
    @step
    def predict(self):
        import fasttext
        fasttext.FastText.eprint = lambda x: None

        print('model reading started')
        
        #download the fasttext model from aws s3.

        manager = Manager()
        model_abn = manager.list([fasttext.load_model('fasttext_model.bin')])

        
        print('model reading finished')

    
        time_start = time.time()

        pool = multiprocessing.Pool()
        #results = pool.map(self.predict_abn, self.df_input)
        results = pool.map(partial(self.predict_abn, model_abn=model_abn), self.df_input)

        pool.close()
        pool.join()

        time_end = time.time()
        print(f"Time elapsed: {round(time_end - time_start, 2)}s")

        self.next(self.end)


    @step
    def end(self):
        print("Predictions evaluated successfully")


    def predict_abn(self,text, model_abn):
        model = model_abn[0]
        return model.predict(text,k=1)


if __name__ == '__main__':
    BatchInference()

The error message is:

TypeError: cannot pickle 'fasttext_pybind.fasttext' object

I was told this is because the fastText model cannot be serialized. I also tried other methods, for example:

self.model_bytes_abn = pickle.dumps(model_abn)

to convert the model to bytes, but that does not work either.

Please tell me what is wrong with the code and how to fix it.

Answers (1)

Nopileos

As your error says, the pybind object behind fastText can't be pickled:

TypeError: cannot pickle 'fasttext_pybind.fasttext' object

This is a general problem with Python bindings built with pybind: their objects normally cannot be pickled.
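To reproduce the failure without fastText installed, here is a minimal sketch that uses a `threading.Lock` as a stand-in for the unpicklable pybind object. `FakeModel` and `try_pickle` are hypothetical names for illustration only; the lock has nothing to do with fastText, it just shares the property of having no pickle support. Wrapping the object in a list, as `model_abn` does, doesn't help, because pickling a list pickles each of its elements. A plain file path, by contrast, pickles fine.

```python
import pickle
import threading


class FakeModel:
    """Hypothetical stand-in for a fastText model. It holds a lock,
    which, like the pybind object, cannot be pickled."""

    def __init__(self):
        self._handle = threading.Lock()


def try_pickle(obj):
    """Return None if obj can be pickled, otherwise the TypeError message."""
    try:
        pickle.dumps(obj)
        return None
    except TypeError as exc:
        return str(exc)


print(try_pickle(FakeModel()))           # the object itself fails
print(try_pickle([FakeModel()]))         # a list containing it fails too
print(try_pickle("fasttext_model.bin"))  # a path string pickles fine: None
```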

So your model_abn is a list containing an object from the pybind library, and thus it can't be pickled. In general, you can solve this by initializing whatever is not serializable inside the function that multiprocessing calls, so that every process creates its own objects and nothing has to be pickled.

In your case this is probably not feasible, since each task handed to multiprocessing is just a single call to the model, so the model would be reloaded for every input and the loading would cost far more than the prediction itself.
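A minimal sketch of the per-call approach, assuming a hypothetical `FakeModel` in place of `fasttext.load_model` (it holds a lock so it cannot be pickled, like the real model). Only the input text and the picklable result cross the process boundary, so the pickling error goes away, but `predict_per_call` builds a new model for every item. The `fork` context is an assumption made only so the sketch runs as a plain script on Linux.

```python
import multiprocessing
import threading


class FakeModel:
    """Hypothetical stand-in for a fastText model; the lock makes it unpicklable."""

    def __init__(self, path):
        self.path = path
        self._handle = threading.Lock()

    def predict(self, text, k=1):
        # Dummy prediction: one label derived from the text length, one score
        return (("__label__len%d" % len(text),), (1.0,))


def predict_per_call(text):
    # The model is created inside the worker function, so nothing
    # unpicklable is sent between processes. The cost: one load per item.
    model = FakeModel("fasttext_model.bin")  # real flow: fasttext.load_model(...)
    return model.predict(text, k=1)


if __name__ == "__main__":
    texts = ['af', 'febrt', 'fefv fd we']
    ctx = multiprocessing.get_context("fork")
    with ctx.Pool(2) as pool:
        print(pool.map(predict_per_call, texts))
```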

It is a bit of a design question where to put things, how to separate them, and whether you even want multiprocessing under these circumstances. One option that keeps most of your code the same is to use the initializer argument of Pool.

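import fasttext
from multiprocessing import Pool

model = None  # each worker process sets its own copy in init_worker


def predict_model(input_data):
    # Runs in a worker and uses that worker's own model
    return model.predict(input_data)


def init_worker():
    # Runs once per worker process when the pool starts, so the model is
    # loaded inside the worker and never has to be pickled
    global model
    model = fasttext.load_model('fasttext_model.bin')


def some_func(some_list, num_worker):
    ...
    with Pool(num_worker, initializer=init_worker) as pool:
        res = pool.map(predict_model, some_list)
    ...
    return res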

So when the pool is created, every worker runs init_worker once and stores its own model in a global variable, which predict_model (the function you execute via map) then uses.
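Here is a runnable sketch of the initializer pattern, assuming a hypothetical `FakeModel` (unpicklable because it holds a lock) in place of `fasttext.load_model`. It also passes the model path through Pool's `initargs`: a string pickles fine, so it can be sent to each worker, which then loads its own copy. `predict_all` and the `load_count` global are hypothetical additions that let you check the model was loaded once per worker rather than once per item. The `fork` context is an assumption made only so the sketch runs as a single script on Linux; the initializer pattern itself does not depend on it.

```python
import multiprocessing
import os
import threading

model = None     # per-process model, set by init_worker
load_count = 0   # how many times this process has loaded a model


class FakeModel:
    """Hypothetical stand-in for a fastText model; the lock makes it unpicklable."""

    def __init__(self, path):
        self.path = path
        self._handle = threading.Lock()

    def predict(self, text, k=1):
        # Dummy prediction: one label derived from the text length, one score
        return (("__label__len%d" % len(text),), (1.0,))


def init_worker(model_path):
    # Runs once in each worker when the pool starts. Only model_path (a str)
    # is pickled; the model itself is built inside the worker.
    global model, load_count
    model = FakeModel(model_path)  # real flow: fasttext.load_model(model_path)
    load_count += 1


def predict_model(text):
    # Uses this worker's own model; also returns the worker's PID and
    # load count so the behaviour can be inspected
    return model.predict(text, k=1), os.getpid(), load_count


def predict_all(texts, model_path, processes=2):
    ctx = multiprocessing.get_context("fork")
    with ctx.Pool(processes, initializer=init_worker,
                  initargs=(model_path,)) as pool:
        return pool.map(predict_model, texts)


if __name__ == "__main__":
    texts = ['af', 'febrt', 'fefv fd we', 'fe hth dw hytht', ' dfegrtg hg df reg']
    for prediction, pid, loads in predict_all(texts, "fasttext_model.bin"):
        print(pid, loads, prediction)
```

Every result should report a load count of 1, and there are at most as many distinct PIDs as pool processes.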

No matter what you do, if you want to use the model with multiprocessing, it has to exist in each process and be initialized by that process, since you can't serialize it and distribute it.
