dongyokim

What is the LightGBM equivalent of .fit() for scikit-learn's DecisionTreeRegressor/DecisionTreeClassifier?

I am trying to replicate scikit-learn's .fit() with the lightgbm library in Python, but there seem to be several different ways to train a LightGBM Booster:

  1. .update()
  2. .refit()
  3. lgb.train()

I have tried all three to no avail.

from sklearn.tree import DecisionTreeRegressor

# self.max_features comes from the enclosing class
tree = DecisionTreeRegressor(criterion='friedman_mse', max_depth=3,
                             max_features=self.max_features, max_leaf_nodes=None,
                             min_impurity_decrease=0.0,
                             min_samples_leaf=1, min_samples_split=5,
                             min_weight_fraction_leaf=0.0, random_state=0)

tree.fit(X, gradient)  # fit one regression tree to the gradients

This works.

This, however, doesn't work.

import lightgbm as lgb

tree = lgb.Booster(model_file='lgbm_model.txt')  # load a previously saved model

train_data = lgb.Dataset(X, label=gradient, free_raw_data=False)
valid_data = lgb.Dataset(Xtest, label=gradient_t, free_raw_data=False)

solution 1

tree.update(train_data) # gives me this error: 

AttributeError: 'Booster' object has no attribute 'train_set'

solution 2

tree.refit(X, gradient, predict_disable_shape_check=True)

This runs, but it doesn't seem to update the tree very much.
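One possible explanation: according to the LightGBM documentation for Booster.refit, refit keeps the existing trees and blends each leaf's value as leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output, with decay_rate defaulting to 0.9. The pure-Python sketch below only evaluates that formula; refit_leaf is a hypothetical helper, not part of lightgbm, and the leaf values are made up.

```python
def refit_leaf(old_leaf_output, new_leaf_output, decay_rate=0.9):
    """Hypothetical helper: the leaf blend documented for Booster.refit."""
    return decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output


# Made-up leaf values: the saved tree predicts 1.0, the new data suggests 5.0
default_blend = refit_leaf(1.0, 5.0)               # about 1.4, still close to the old value
full_refit = refit_leaf(1.0, 5.0, decay_rate=0.0)  # 5.0, the new value only
print(default_blend, full_refit)
```

With the default, each leaf keeps 90% of its old value, so passing decay_rate=0.0 to tree.refit(...) should make the new data dominate. As far as the documentation describes it, refit does not change the split structure, only the leaf values.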

solution 3

tree = lgb.train(self.params,
                 train_data,
                 valid_sets=valid_data,
                 num_boost_round=10,
                 keep_training_booster=True,
                 init_model=tree)

This doesn't run.

Answers (1)

afsharov

The LightGBM package for Python has several APIs. If you are using the Training API, you should use the train function (lgb.train), whose documentation reads:

Perform the training with given parameters

However, if you want to stick to scikit-learn conventions, you can simply use the scikit-learn API instead, for example LGBMClassifier, which offers a fit method:

import lightgbm as lgb


clf = lgb.LGBMClassifier()
clf.fit(X, y)
