youtube

Reputation: 504

How to save single Random Forest model with cross validation?

I am using 10-fold cross-validation to predict binary labels (y) from embedding inputs (X). I want to save one of the fold models (perhaps the one with the highest ROC AUC), but I'm not sure how: the per-fold ROC AUCs are not stored, so I don't know how to pick out the corresponding model.

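import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, auc, precision_recall_curve, roc_curve
from sklearn.model_selection import GroupKFold

# df, n_folds, n_trees and max_depth are defined earlier (not shown)
X = np.array([np.array(x) for x in df['embeddings'].values])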
y = df['label'].values
groups = df['chromosome'].values
group_kfold = GroupKFold(n_splits=n_folds)

# Initialize figure for plotting

fig, axes = plt.subplots(1, 2, figsize=(15, 6))

all_fpr = []
all_tpr = []
all_accuracy = []
all_pr_auc = []
# Perform cross-validation and plot ROC and PR curves for each fold
for i, (train_idx, val_idx) in enumerate(group_kfold.split(X, y, groups)):
    X_train_fold, X_val_fold = X[train_idx], X[val_idx]
    y_train_fold, y_val_fold = y[train_idx], y[val_idx]
    
    # Initialize classifier
    rf_classifier = RandomForestClassifier(n_estimators=n_trees, random_state=42, max_depth=max_depth, n_jobs=-1)
    
    # Train the classifier on this fold
    rf_classifier.fit(X_train_fold, y_train_fold)
    
    # Make predictions on the validation set
    y_pred_proba = rf_classifier.predict_proba(X_val_fold)[:, 1]
    
    # Calculate ROC curve
    fpr, tpr, _ = roc_curve(y_val_fold, y_pred_proba)
    
    all_fpr.append(fpr)
    all_tpr.append(tpr)

    # Calculate AUC
    roc_auc = auc(fpr, tpr)

    # Plot ROC curve for this fold
    axes[0].plot(fpr, tpr, lw=1, alpha=0.7, label=f'ROC Fold {i+1} (AUC = {roc_auc:.2f})')
    
    # Calculate precision-recall curve
    precision, recall, _ = precision_recall_curve(y_val_fold, y_pred_proba)
    
    # Calculate PR AUC
    pr_auc = auc(recall, precision)
    all_pr_auc.append(pr_auc)

    # Plot PR curve for this fold
    axes[1].plot(recall, precision, lw=1, alpha=0.7, label=f'PR Curve Fold {i+1} (AUC = {pr_auc:.2f})')
    
    # Calculate accuracy
    accuracy = accuracy_score(y_val_fold, rf_classifier.predict(X_val_fold))
    all_accuracy.append(accuracy)

# Initialize empty arrays to store interpolated TPR values
interpolated_tpr = []

# Define a common grid of false positive rates
mean_fpr = np.linspace(0, 1, 100)

# Interpolate each fold's TPR values onto the common FPR grid
for fpr, tpr in zip(all_fpr, all_tpr):
    interpolated_tpr.append(np.interp(mean_fpr, fpr, tpr))

# Calculate the mean and standard deviation of interpolated TPR values
mean_tpr = np.mean(interpolated_tpr, axis=0)
std_tpr = np.std(interpolated_tpr, axis=0)

# Plot the mean ROC curve with shaded area representing the standard deviation
axes[0].plot(mean_fpr, mean_tpr, color='black', linestyle='--', lw=2, label=f'Average ROC curve ({np.round(auc(mean_fpr, mean_tpr), 2)})')
axes[0].fill_between(mean_fpr, mean_tpr - std_tpr, mean_tpr + std_tpr, color='grey', alpha=0.2)

# Plot ROC for random classifier
axes[0].plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', alpha=0.8)

Upvotes: 1

Views: 55

Answers (1)

Shaido

Reputation: 28367

You can keep track of the best-performing model inside the cross-validation loop itself. For example:

best_score = 0
best_model = None
for i, (train_idx, val_idx) in enumerate(group_kfold.split(X, y, groups)):
    ...
    rf_classifier = ...
    ...
    pr_auc = ...
    
    # Keep the model from the best fold seen so far (update the score too)
    if pr_auc > best_score:
        best_score = pr_auc
        best_model = rf_classifier
    ...
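A self-contained version of the loop above, tracking the best fold by ROC AUC (the metric the question mentions) and then saving that model with joblib. This is a minimal sketch under stated assumptions: the data comes from make_classification, with ten made-up groups standing in for the chromosomes, and the file name best_rf_model.joblib is hypothetical.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import GroupKFold

# Synthetic stand-ins for the question's X, y and chromosome groups (hypothetical data)
X, y = make_classification(n_samples=300, n_features=16, random_state=0)
groups = np.repeat(np.arange(10), 30)  # 10 made-up "chromosomes", 30 samples each

best_score = -np.inf
best_model = None
fold_scores = []  # keep every fold's ROC AUC so they can be inspected later

for train_idx, val_idx in GroupKFold(n_splits=10).split(X, y, groups):
    rf_classifier = RandomForestClassifier(n_estimators=50, random_state=42, n_jobs=-1)
    rf_classifier.fit(X[train_idx], y[train_idx])
    y_pred_proba = rf_classifier.predict_proba(X[val_idx])[:, 1]
    roc_auc = roc_auc_score(y[val_idx], y_pred_proba)  # equals auc(fpr, tpr) from roc_curve
    fold_scores.append(roc_auc)

    # Update the score AND the model whenever a fold beats the best so far
    if roc_auc > best_score:
        best_score = roc_auc
        best_model = rf_classifier

# Persist the winning fold's model (hypothetical file name in the temp directory)
model_path = os.path.join(tempfile.gettempdir(), "best_rf_model.joblib")
joblib.dump(best_model, model_path)
restored = joblib.load(model_path)
print(f"Best fold: {fold_scores.index(best_score) + 1}, ROC AUC = {best_score:.3f}")
```

joblib.load returns the fitted classifier, so restored.predict_proba(...) works exactly as before saving.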

However, consider whether this model is actually useful. Each fold model is trained on only a slice of the data (9/10 of it with 10 folds), so a model trained on all of the data could well perform better.
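A minimal sketch of that alternative, assuming synthetic data and groups in place of the embeddings and chromosomes: grouped CV is used only to estimate performance, and the model that gets saved is the same configuration refit on all of the data. max_depth=5 and the file name are arbitrary placeholders.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GroupKFold, cross_val_score

# Synthetic stand-ins for the embeddings, labels and chromosome groups (hypothetical data)
X, y = make_classification(n_samples=300, n_features=16, random_state=0)
groups = np.repeat(np.arange(10), 30)

# One fixed configuration; max_depth=5 is an arbitrary placeholder
rf_classifier = RandomForestClassifier(n_estimators=50, max_depth=5, random_state=42, n_jobs=-1)

# Step 1: grouped CV only estimates how well this configuration generalizes
cv_scores = cross_val_score(
    rf_classifier, X, y, groups=groups, cv=GroupKFold(n_splits=10), scoring="roc_auc"
)
print(f"ROC AUC: {cv_scores.mean():.3f} +/- {cv_scores.std():.3f}")

# Step 2: the model you keep is refit on ALL of the data
# (cross_val_score fits clones, so rf_classifier is still unfitted here)
final_model = rf_classifier.fit(X, y)
model_path = os.path.join(tempfile.gettempdir(), "final_rf_model.joblib")  # hypothetical name
joblib.dump(final_model, model_path)
```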

Usually, cross-validation is used to choose the model's hyperparameters, for example which maximum depth works best for your data. In that case you would compare the average score (accuracy, F1-score, AUC, etc.) across all CV folds. See, for example, this previous question about why CV is useful: What is the purpose of cross-validation if the model is thrown away each iteration
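For the max-depth example, scikit-learn's GridSearchCV can run the grouped CV and pick the depth with the best mean ROC AUC. With refit=True (the default) it then retrains that configuration on the full dataset and exposes it as best_estimator_, which is the model you would save. This is a sketch with synthetic data; the candidate depths are an arbitrary assumption.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, GroupKFold

# Synthetic stand-ins for the embeddings, labels and chromosome groups (hypothetical data)
X, y = make_classification(n_samples=300, n_features=16, random_state=0)
groups = np.repeat(np.arange(10), 30)

search = GridSearchCV(
    RandomForestClassifier(n_estimators=50, random_state=42),
    param_grid={"max_depth": [3, 5, 10, None]},  # arbitrary candidate depths
    scoring="roc_auc",
    cv=GroupKFold(n_splits=10),
    refit=True,  # retrain the best configuration on all of X, y
)
search.fit(X, y, groups=groups)  # groups are forwarded to GroupKFold.split

print(search.best_params_, round(search.best_score_, 3))  # best_score_ is a mean over folds
best_model = search.best_estimator_
```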

Another option is to keep all of the fold models and combine them into an ensemble.
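A sketch of such an ensemble, assuming simple soft voting (averaging the fold models' positive-class probabilities, which is one choice among several) and synthetic data. ensemble_predict_proba is a hypothetical helper, not a scikit-learn function.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GroupKFold

# Synthetic stand-ins for the embeddings, labels and chromosome groups (hypothetical data)
X, y = make_classification(n_samples=300, n_features=16, random_state=0)
groups = np.repeat(np.arange(10), 30)

# Keep every fold's fitted model instead of discarding it
fold_models = []
for train_idx, _ in GroupKFold(n_splits=10).split(X, y, groups):
    clf = RandomForestClassifier(n_estimators=50, random_state=42)
    fold_models.append(clf.fit(X[train_idx], y[train_idx]))

def ensemble_predict_proba(models, X_new):
    """Hypothetical helper: average positive-class probabilities over all models (soft voting)."""
    return np.mean([m.predict_proba(X_new)[:, 1] for m in models], axis=0)

X_new = X[:5]  # stand-in for unseen samples
proba = ensemble_predict_proba(fold_models, X_new)
labels = (proba >= 0.5).astype(int)  # 0.5 threshold is an assumption
```

The whole list can be saved with joblib.dump(fold_models, path), just like a single model.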

Upvotes: 0
