mamafoku

Reputation: 1139

How to pass weights when using Sklearn GridSearchCV with Pipeline

I am working on a text classification model, using a Pipeline together with GridSearchCV cross-validation. Code snippets below:

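from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Stopwords_X, stopwords and scoring are defined elsewhere (not shown)
count_vec=CountVectorizer(ngram_range=(1,2),stop_words=Stopwords_X,min_df=0.01)
TFIDF_Transformer=TfidfTransformer(sublinear_tf=True,norm='l2')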

my_pipeline=Pipeline([('Count_Vectorizer',count_vec),
                    ('TF_IDF',TFIDF_Transformer),
                    ('MultiNomial_NB',MultinomialNB())])

param_grid={'Count_Vectorizer__ngram_range':[(1,1),(1,2),(2,2)],
               'Count_Vectorizer__stop_words':[Stopwords_X,stopwords],
               'Count_Vectorizer__min_df':[0.001,0.005,0.01],
               'TF_IDF__sublinear_tf':[True,False],
               'TF_IDF__norm':['l2'],
               'TF_IDF__smooth_idf':[True,False],
               'MultiNomial_NB__alpha':[0.2,0.4,0.5,0.6],
               'MultiNomial_NB__fit_prior':[True,False]}

# Grid Search CV with pipeline
model=GridSearchCV(estimator=my_pipeline,param_grid=param_grid,
                   scoring=scoring,cv=4,verbose=1,refit=False)

However, as the data is highly imbalanced, I want to pass weights to the MultinomialNB classifier in the pipeline. I know that I can pass weights to individual steps of the pipeline, as shown below:

model.fit(Data_Labeled['Clean-Merged-Final'], 
          Data_Labeled['Labels'],MultiNomial_NB__sample_weight=weights)
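For reference, here is a minimal sketch of the `<step name>__<param>` convention on its own, outside GridSearchCV. The four-document corpus, labels and weights are invented for illustration; the step names match the pipeline above. Because MultinomialNB accumulates `class_count_` from the sample weights, inspecting it shows that the weights actually reached the classifier.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Toy corpus invented for illustration: class 1 is the minority class.
docs = ["good movie", "great film", "fine movie", "bad film"]
labels = np.array([0, 0, 0, 1])
weights = np.array([1.0, 1.0, 1.0, 3.0])  # up-weight the single minority sample

pipe = Pipeline([('Count_Vectorizer', CountVectorizer()),
                 ('TF_IDF', TfidfTransformer()),
                 ('MultiNomial_NB', MultinomialNB())])

# The "MultiNomial_NB__" prefix makes Pipeline.fit forward sample_weight only
# to the MultinomialNB step's fit(); the two transformers never receive it.
pipe.fit(docs, labels, MultiNomial_NB__sample_weight=weights)

# class_count_ is the weighted number of samples seen per class.
print(pipe.named_steps['MultiNomial_NB'].class_count_)  # [3. 3.]
```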

My question is: how does this run without a shape error? The weights are passed only to the final element of the pipeline (the MultiNomial_NB classifier), while cross-validation partitions the X/y data entering the pipeline.

Upvotes: 3

Views: 2498

Answers (1)

Vivek Kumar

Reputation: 36619

GridSearchCV splits sample_weight appropriately according to the cross-validation iterator, so each fold's classifier only receives the weights of that fold's training samples.
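A minimal end-to-end sketch of this, assuming a small invented corpus and a one-parameter grid: the weights are computed once for the full dataset with `compute_sample_weight("balanced", ...)` (a swapped-in way of producing imbalance weights, not something from the question) and handed to `GridSearchCV.fit`, which slices them per fold.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.utils.class_weight import compute_sample_weight

# Toy, imbalanced corpus invented for illustration (6 vs 2 samples).
docs = ["good movie", "great film", "fine movie", "nice plot",
        "good acting", "great cast", "bad film", "awful movie"]
labels = np.array([0, 0, 0, 0, 0, 0, 1, 1])

# One weight per sample for the FULL dataset; "balanced" gives each sample
# n_samples / (n_classes * count_of_its_class), so the minority gets more.
weights = compute_sample_weight(class_weight="balanced", y=labels)

pipe = Pipeline([('Count_Vectorizer', CountVectorizer()),
                 ('TF_IDF', TfidfTransformer()),
                 ('MultiNomial_NB', MultinomialNB())])
param_grid = {'MultiNomial_NB__alpha': [0.5, 1.0]}

# With a classifier, cv=2 means StratifiedKFold(2). GridSearchCV indexes
# `weights` with each fold's training indices before calling the pipeline.
model = GridSearchCV(pipe, param_grid, cv=2, refit=False)
model.fit(docs, labels, MultiNomial_NB__sample_weight=weights)
print(model.cv_results_['mean_test_score'])
```

No shape error occurs even though `weights` has 8 entries and each training fold has only 4 samples.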

Internally, GridSearchCV calls the _fit_and_score() function on the data and passes it the indices of the training samples. Up to that point, the fit_params still cover the whole dataset. _fit_and_score() in turn calls _index_param_value, which slices sample_weight (or any other fit_params) down to the training indices in this line:

     ...
     return safe_indexing(v, indices)
     ...
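The idea behind that line can be sketched in plain NumPy. The helper `index_fit_params` below is hypothetical and only mirrors the behaviour described above (slice any per-sample array, pass everything else through unchanged); it is not scikit-learn's actual code, and the weights and fold indices are invented.

```python
import numpy as np

def index_fit_params(fit_params, indices, n_samples):
    """Hypothetical helper: slice per-sample fit params to one fold's indices."""
    out = {}
    for key, value in fit_params.items():
        # Only arrays with one entry per sample are split; scalars, flags,
        # etc. are passed through as they are.
        if isinstance(value, (list, np.ndarray)) and len(value) == n_samples:
            out[key] = np.asarray(value)[indices]
        else:
            out[key] = value
    return out

weights = np.array([0.5, 0.5, 0.5, 2.0, 2.0, 0.5])
fit_params = {"MultiNomial_NB__sample_weight": weights, "some_flag": True}
train_idx = np.array([0, 2, 3, 5])  # e.g. one fold's training indices

fold_params = index_fit_params(fit_params, train_idx, n_samples=len(weights))
print(fold_params["MultiNomial_NB__sample_weight"].tolist())  # [0.5, 0.5, 2.0, 0.5]
```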

This has been discussed in issues here:

Upvotes: 3
