jean

Reputation: 141

Get probabilities from xgb.train()

I am new to Python and machine learning. I have searched the internet for an answer to my question and tried the solutions people suggested, but I still cannot get it to work. I would really appreciate any help.

I am working on my first XGBoost model. I tuned the parameters using xgb.XGBClassifier and would now like to enforce monotonicity on the model variables. It seems I have to use xgb.train() to enforce monotonicity, as shown in my code below.

The Booster returned by xgb.train() has a predict() method but NOT a predict_proba() method. So how can I get probabilities from a model trained with xgb.train()?

I have tried using 'objective':'multi:softprob' instead of 'objective':'binary:logistic', followed by score = bst_constr.predict(dtrain), but the scores do not look right to me.

Thank you so much.

params_constr={
    'base_score':0.5, 
    'learning_rate':0.1, 
    'max_depth':5,
    'min_child_weight':100, 
    'n_estimators':200, 
    'nthread':-1,
    'objective':'binary:logistic', 
    'seed':2018, 
    'eval_metric':'auc' 
}

params_constr['monotone_constraints'] = "(1,1,0,1,-1,-1,0,0,1,-1,1,0,1,0,-1,0,0,0,0,0,0,0,0,0,0,0,0,0,)" 

dtrain = xgb.DMatrix(X_train, label = y_train)

bst_constr = xgb.train(params_constr, dtrain)


X_test['score']=bst_constr.predict_proba(X_test)[:,1]

AttributeError: 'Booster' object has no attribute 'predict_proba'

Upvotes: 5

Views: 8908

Answers (1)

Charles Kuo

Reputation: 71

As I understand it, you want the probability of each class at prediction time. There are two options.

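  1. It seems you are using the native XGBoost API. With 'objective':'binary:logistic', bst_constr.predict() already returns the probability of the positive class, so call it instead of bst_constr.predict_proba() and pass it a DMatrix, e.g. bst_constr.predict(xgb.DMatrix(X_test)). If you want one probability per class, set 'objective':'multi:softprob' together with 'num_class' and again use bst_constr.predict().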

  2. XGBoost also provides a scikit-learn API. In that case, create the model with bst_constr = xgb.XGBClassifier(**params_constr) and train it with bst_constr.fit(). You can then call bst_constr.predict_proba() to get the class probabilities. See the "Scikit-Learn API" section of the XGBoost documentation for more details.
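
A minimal sketch of option 1, assuming the model was trained with 'objective':'binary:logistic' as in the question. The helper names stack_binary_proba and score_with_booster are made up for illustration; the only XGBoost calls are xgb.DMatrix and Booster.predict. It mimics the two-column layout of predict_proba so that [:,1]-style indexing carries over.

```python
def stack_binary_proba(p_pos):
    """Turn positive-class probabilities into rows of [P(class 0), P(class 1)]."""
    rows = []
    for p in p_pos:
        p = float(p)
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"not a probability: {p}")
        rows.append([1.0 - p, p])
    return rows


def score_with_booster(bst, X):
    """Hypothetical helper: score a feature table with a trained Booster.

    Assumes bst was trained with 'objective': 'binary:logistic'.
    """
    import xgboost as xgb  # imported lazily so the helper above has no dependencies

    dmat = xgb.DMatrix(X)      # Booster.predict expects a DMatrix, not a DataFrame
    p_pos = bst.predict(dmat)  # with 'binary:logistic' this is P(label == 1)
    return stack_binary_proba(p_pos)
```

If you only need the positive-class score, bst_constr.predict(xgb.DMatrix(X_test)) is enough on its own. Build the DMatrix before adding a 'score' column to X_test, otherwise the new column would be passed to the model as an extra feature.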

Upvotes: 7
