John

R gbm: why is the length of fit$trees always 3 times n.trees for the iris data set?

Regardless of which method ('cv', 'OOB' or 'test') we choose, the number of trees in the fitted gbm object is always 3*n.trees for the iris data set.

Is it because the target variable in the iris data set is categorical with 3 levels?

If so, and the target variable has values A, B and C, is the 1st tree for A, the 2nd for B, the 3rd for C, and the 4th for A again?

Or, if n.trees is set to 100, are the first 100 trees for A, the next 100 for B, and the last 100 for C?
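To make the two orderings in the question concrete, here is a small base-R sketch that maps each position in a flat list of n.trees * K trees (K = number of classes) to a (boosting iteration, class) pair under each hypothesis. The function names interleaved_layout and blocked_layout are made up for illustration; the sketch only spells out the two candidate layouts and does not show which one gbm actually uses.

```r
# Hypothetical helpers: neither layout is confirmed here to be gbm's internal order.

# Layout 1 (interleaved): A, B, C, A, B, C, ...
interleaved_layout <- function(n.trees, classes) {
  K <- length(classes)
  idx <- seq_len(n.trees * K)
  data.frame(tree = idx,
             iteration = (idx - 1) %/% K + 1,         # advances every K trees
             class = classes[(idx - 1) %% K + 1])     # cycles through the classes
}

# Layout 2 (blocked): all A trees, then all B trees, then all C trees
blocked_layout <- function(n.trees, classes) {
  K <- length(classes)
  idx <- seq_len(n.trees * K)
  data.frame(tree = idx,
             iteration = (idx - 1) %% n.trees + 1,     # restarts for each class
             class = classes[(idx - 1) %/% n.trees + 1])  # changes every n.trees trees
}

head(interleaved_layout(100, c("A", "B", "C")), 4)
head(blocked_layout(100, c("A", "B", "C")), 4)
```

Under the interleaved layout, tree 4 is class A at iteration 2; under the blocked layout, trees 1 to 100 are all class A and tree 101 is the first B tree.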

library(gbm)

fit = gbm(data=iris, Species ~., shrinkage = 0.2, n.trees = 50, cv.folds = 2)

best.iter = gbm.perf(fit, method = 'cv')

fit = gbm(data=iris, Species ~., shrinkage = 0.2, n.trees = 40, train.fraction = 0.8)

best.iter = gbm.perf(fit, method = 'test')

fit = gbm(data=iris, Species ~., shrinkage = 0.2, n.trees = 50)

best.iter = gbm.perf(fit, method = 'OOB')

J Hart

You are right: the number of trees is 3 times n.trees because of the factor on the left-hand side of the formula you are fitting. With a 3-level factor response, gbm fits a multinomial model and grows a separate set of trees for each of the 3 levels, so you effectively get 3 separate fits.

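If you filter the data down to only two levels of the factor, you end up with 2*n.trees instead. The code below sets distribution = "multinomial" explicitly so that the two-level response is still treated as multinomial.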

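library(gbm)
# Keep only two species by dropping "setosa"
iris.sub <- iris[iris$Species != "setosa", ]
# Rebuild the factor so the now-empty "setosa" level is removed
iris.sub$Species <- factor(as.character(iris.sub$Species))
levels(iris.sub$Species)  # "versicolor" "virginica"
fit = gbm(data=iris.sub, Species ~., shrinkage = 0.2, n.trees = 50, cv.folds = 2, distribution = "multinomial")
length(fit$trees)  # number of trees stored in the fitted object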
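The rule in this answer (number of classes times n.trees) can be written as a small base-R check that does not need gbm. expected_tree_count is a hypothetical helper, not part of the gbm package, and it assumes one set of trees per class as described above. It drops unused factor levels first, which is why the answer rebuilds the factor after filtering: subsetting alone keeps "setosa" as an empty level.

```r
# Hypothetical helper: expected length of fit$trees for a multinomial fit,
# assuming one tree per class per boosting iteration.
expected_tree_count <- function(y, n.trees) {
  nlevels(droplevels(factor(y))) * n.trees  # count only levels that occur in y
}

iris.sub <- iris[iris$Species != "setosa", ]
nlevels(iris.sub$Species)                  # 3: "setosa" is still a (now empty) level

expected_tree_count(iris$Species, 50)      # 150
expected_tree_count(iris.sub$Species, 50)  # 100
```

After fitting with gbm, you could compare length(fit$trees) against expected_tree_count(iris.sub$Species, 50).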

As for how the trees are organized within the gbm object, I was unable to figure out which order they are stored in.
