Problems tuning pipeline parameters with OptunaSearchCV

I have code that searches over the hyperparameters of both the model itself and the whole pipeline:

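from catboost import CatBoostClassifier
from category_encoders import CatBoostEncoder
from imblearn.combine import SMOTETomek
from imblearn.over_sampling import ADASYN, SMOTE
from imblearn.pipeline import Pipeline as ImbPipeline
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import (SelectFromModel, SelectKBest, SelectPercentile,
                                       chi2, f_classif, mutual_info_classif)
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import (MinMaxScaler, Normalizer, OneHotEncoder, PolynomialFeatures,
                                   PowerTransformer, QuantileTransformer, RobustScaler, StandardScaler)

# X, y_enc, binary_cols, non_binary_cat_cols, num_cols and RANDOM_STATE are defined elsewhere

preprocessor = ColumnTransformer(
    [
        ('OneHotEncoder', OneHotEncoder(drop='if_binary', sparse_output=False), binary_cols),
        ('CatBoostEncoder', CatBoostEncoder(random_state=RANDOM_STATE), non_binary_cat_cols),
        ('StandardScaler', StandardScaler(), num_cols)
    ],
    verbose_feature_names_out=False,
    remainder='drop'
)

pipe_final = ImbPipeline([
    ('preprocessor', preprocessor),
    ('target_imbalance', ADASYN()),
    ('selection', PCA()),
    ('models', CatBoostClassifier(random_state=RANDOM_STATE))
])

# Hyperparameters for CatBoostClassifier
param_grid = {
    'models__iterations': [1000, 2000, 3000],
    # 'Balanced' is a value of CatBoost's auto_class_weights; class_weights expects a list or dict
    'models__auto_class_weights': ['Balanced', None],
    'target_imbalance': [ADASYN(random_state=RANDOM_STATE), SMOTETomek(random_state=RANDOM_STATE),
                         SMOTE(random_state=RANDOM_STATE, k_neighbors=10), 'passthrough'],
    'preprocessor__StandardScaler': [StandardScaler(), RobustScaler(), MinMaxScaler(),
                                     PowerTransformer(), QuantileTransformer(),
                                     Normalizer(), PolynomialFeatures(degree=2, include_bias=False), 'passthrough'],
    'selection': [PCA(random_state=RANDOM_STATE, n_components="mle", svd_solver="full"),
                  SelectKBest(mutual_info_classif, k=40),
                  SelectKBest(f_classif, k=40),
                  SelectKBest(chi2, k=40),
                  SelectPercentile(mutual_info_classif, percentile=10),
                  SelectPercentile(f_classif, percentile=10),
                  SelectFromModel(CatBoostClassifier(random_state=RANDOM_STATE)),
                  SelectFromModel(LogisticRegression(random_state=RANDOM_STATE)),
                  SelectFromModel(RandomForestClassifier(random_state=RANDOM_STATE)),
                  'passthrough'],
}

gs = GridSearchCV(
    pipe_final,
    param_grid,
    cv=5,
    scoring='roc_auc',
    n_jobs=-1
)

# Run the hyperparameter search
gs.fit(X, y_enc)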

It takes a very long time to run, and I want to speed it up with OptunaSearchCV. Do I understand correctly that with OptunaSearchCV I can search over the hyperparameters of the model itself, but not over the whole pipeline, because there is no Distribution with which I can specify options such as SelectKBest(f_classif, k=40), RobustScaler(), etc.?

Sorry if my wording is inaccurate in places; I use Google Translate because I am not a native speaker.

Answers (1)

Ben Reiniger

According to its documentation, OptunaSearchCV requires an Optuna distribution object for each hyperparameter's "range". But there is CategoricalDistribution for providing a discrete list of options. There is a note in the documentation, though, which might indicate a problem in some settings:

Not all types are guaranteed to be compatible with all storages. It is recommended to restrict the types of the choices to None, bool, int, float and str.


I first interpreted your question to be about setting the hyperparameters of steps in a pipeline, i.e. the naming convention like preprocessor__StandardScaler. I no longer think that's what you meant, but in case others come here with that confusion:

In the documentation of OptunaSearchCV, the estimator parameter is explained as "Object to use to fit the data. This is assumed to implement the scikit-learn estimator interface. [...]", and I would assume that means they use clone and set_params to create an estimator for each hyperparameter setting; then it should be fine to use a pipeline with the usual naming convention for steps' hyperparameters.
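To illustrate that assumption, here is a minimal sketch in plain scikit-learn (no Optuna involved) of what a search would do for one candidate if it works via clone and set_params: whole steps are swapped by step name, and nested hyperparameters are addressed with `step__param`. That OptunaSearchCV does exactly this internally is an assumption; the pipeline, LogisticRegression model and synthetic data are stand-ins.

```python
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import RobustScaler, StandardScaler

pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("selection", PCA()),
    ("models", LogisticRegression(max_iter=1000)),
])

# One candidate's hyperparameters, using the pipeline naming convention:
# a bare step name replaces the whole step, "step__param" sets a nested parameter.
params = {
    "scaler": RobustScaler(),
    "selection": "passthrough",
    "models__C": 0.5,
}

# clone() gives an unfitted copy, so the original pipeline is left untouched.
candidate = clone(pipe).set_params(**params)

X, y = make_classification(n_samples=200, n_features=10, random_state=0)
candidate.fit(X, y)

print(candidate.named_steps["scaler"], candidate.get_params()["models__C"])
```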
