LudvigH

Reputation: 4801

How to make shap.plots.scatter with xgboost.DMatrix holding missing data?

I have a dataset with missing values, encoded as NaN. XGBoost handles this fine during model fitting. But when I try to understand the model by looking at SHAP scatter plots, I am not sure what the correct usage is.

Consider the synthetic example below:

import numpy as np, scipy.special, xgboost as xgb, shap

rng = np.random.default_rng(0)
def gendata(n):
    X = rng.normal(size=(n,1))
    y = np.sin(X[:,0]) + rng.normal(size=n)
    X[:n//2,0] = np.nan
    y = (rng.random(size=n) < scipy.special.expit(y)).astype(int)
    dmatrix = xgb.DMatrix(X,label=y,feature_names=['X0'])
    return X,y, dmatrix

X,y,dmat = gendata(10)
model = xgb.train({'objective':'reg:squarederror','booster':'gbtree'}, dmat)
explainer = shap.Explainer(model,feature_names=dmat.feature_names)
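# Explain from the DMatrix: missing values end up plotted as zeros
explanation = explainer(dmat)
shap.plots.scatter(explanation)
# Explain from the raw numpy array: missing values show up as rug markers
explanation = explainer(X, y)
shap.plots.scatter(explanation)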

It produces the following two scatter plots. With the raw numpy arrays, the plot shows the missing data as rug plot markers, which seems correct. With the xgb.DMatrix, the missing values are zero-imputed instead. The explanation object holds the source data correctly in both cases (a sparse matrix in the dmat case and numpy arrays in the X,y case), so I suppose there is a to_dense call somewhere in the scatter function that messes everything up.

How should I make the scatter plot if I only have an xgb.DMatrix available?

[Two images: the scatter plot from the xgb.DMatrix explanation and the scatter plot from the numpy explanation]


Upvotes: 0

Views: 103

Answers (1)

SHAP's scatter plot appears to mishandle missing data when the explanation is built from an xgb.DMatrix: the sparse matrix is likely converted to a dense array along the way, which turns missing entries into zeros. To get missing values displayed correctly (as rug plot markers), compute the SHAP values from the raw input data (a numpy array or pandas.DataFrame) rather than from the DMatrix. The model can still be trained on the DMatrix; passing the original X to the SHAP explainer preserves the NaN values and gives an accurate scatter plot.
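If only the DMatrix is at hand, one possible workaround is to rebuild a dense array with NaN in the missing positions and pass that to the explainer instead. The sketch below is not taken from SHAP or XGBoost. It assumes the sparse data behind the DMatrix is a SciPy CSR matrix in which missing entries are simply not stored (for example what dmat.get_data() returns in recent XGBoost versions, which is worth checking for your version). The helper name csr_to_dense_with_nan and the toy matrix are made up for illustration.

```python
import numpy as np
import scipy.sparse as sp


def csr_to_dense_with_nan(csr):
    """Densify a sparse matrix, using NaN (not 0) for entries that are not stored."""
    coo = csr.tocoo()  # COO conversion keeps explicitly stored zeros
    dense = np.full(coo.shape, np.nan, dtype=float)
    dense[coo.row, coo.col] = coo.data
    return dense


# Toy stand-in for the sparse data behind a DMatrix (hypothetical values):
# 4 rows, 1 feature; rows 0 and 1 are missing (not stored),
# row 2 holds an explicit 0.0 and row 3 holds 1.5.
data = np.array([0.0, 1.5])
indices = np.array([0, 0])
indptr = np.array([0, 0, 0, 1, 2])
sparse_X = sp.csr_matrix((data, indices, indptr), shape=(4, 1))

zero_filled = sparse_X.toarray()               # missing rows silently become 0.0
nan_filled = csr_to_dense_with_nan(sparse_X)   # missing rows stay NaN
```

The NaN-filled array can then be used like the raw X in the question, for example explainer(nan_filled). This only works if real zeros are stored explicitly in the sparse matrix; if they were dropped when the DMatrix was built, they cannot be told apart from missing values.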

Upvotes: -1
