stats-hb

Reputation: 988

Plotting Partial Dependence Plots in R for binary target (mlr)

I'm having trouble getting partial dependence plots to work properly with mlr. Instead of the predicted probability, only the class label seems to be plotted. I suspect that the target gets lost while the partial dependence data is being created.

Any ideas?

library(mlr)
library(dplyr)
library(ranger)

# select subset
iris_bin <- iris %>% 
  filter(Species != "virginica") %>% 
  mutate(bin_target = ifelse(Species == "setosa", TRUE, FALSE)) %>% 
  select(-Species)

# fit model
task_bin <- makeClassifTask(data = iris_bin, target = "bin_target")
lrn_bin  <- makeLearner("classif.ranger", predict.type = "prob")
fit_bin <- train(lrn_bin, task_bin)

# create partial dependence plot
pd <- generatePartialDependenceData(fit_bin, task_bin, "Sepal.Length")

pd  # is the target correct?
#> PartialDependenceData
#> Task: iris_bin
#> Features: Sepal.Length
#> Target: FALSE
#> Derivative: FALSE
#> Interaction: FALSE
#> Individual: FALSE
#>        FALSE Sepal.Length
#> 1: 0.4920347          4.3
#> 2: 0.4920347          4.6
#> 3: 0.4935947          4.9
#> 4: 0.4945947          5.2
#> 5: 0.5104600          5.5
#> 6: 0.5107800          5.8
#> ... (#rows: 10, #cols: 2)
plotPartialDependence(pd)

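One possible reading of the `Target: FALSE` line above, offered as an assumption rather than a confirmed diagnosis: `bin_target` is logical, and when a logical vector is turned into a factor its levels are sorted, so `FALSE` comes first. If mlr converts the logical target this way and, as its documentation describes for binary tasks, treats the first factor level as the positive class, then the `FALSE` column would hold the averaged predicted probability of class `FALSE` (versicolor), not a class label. The base-R snippet below only demonstrates the level ordering:

```r
# Same logical coding as in the question: TRUE = setosa, FALSE = versicolor
bin_target <- c(TRUE, FALSE, FALSE, TRUE)

# factor() sorts the unique values, and FALSE sorts before TRUE,
# so "FALSE" becomes the first level
lev <- levels(factor(bin_target))
lev
```

If that is the cause, `makeClassifTask()` accepts a `positive` argument (for example `positive = "TRUE"`), which should make setosa the class whose probability is reported.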

Here are the details of my current session, in case they help:

Session info ---------------------------------------------
 setting  value                       
 version  R version 3.4.2 (2017-09-28)
 system   x86_64, mingw32             
 ui       RStudio (1.1.383)           
 language (EN)                        
 collate  German_Germany.1252         
 tz       Europe/Berlin               
 date     2018-03-29                  

Packages ----------------------------------------------------
 package      * version    date       source                                   
 assertthat     0.2.0      2017-04-11 CRAN (R 3.4.3)                           
 backports      1.1.2      2017-12-13 CRAN (R 3.4.3)                           
 base         * 3.4.2      2017-09-28 local                                    
 BBmisc         1.11       2017-03-10 CRAN (R 3.4.3)                           
 bindr          0.1.1      2018-03-13 CRAN (R 3.4.2)                           
 bindrcpp     * 0.2        2017-06-17 CRAN (R 3.4.3)                           
 checkmate      1.8.5      2017-10-24 CRAN (R 3.4.3)                           
 colorspace     1.3-2      2016-12-14 CRAN (R 3.4.3)                           
 compiler       3.4.2      2017-09-28 local                                    
 data.table     1.10.4-3   2017-10-27 CRAN (R 3.4.3)                           
 datasets     * 3.4.2      2017-09-28 local                                    
 devtools       1.13.5     2018-02-18 CRAN (R 3.4.3)                           
 digest         0.6.15     2018-01-28 CRAN (R 3.4.3)                           
 dplyr        * 0.7.4      2017-09-28 CRAN (R 3.4.3)                           
 ggplot2        2.2.1.9000 2018-03-26 Github (tidyverse/ggplot2@3c9c504)       
 glue           1.2.0      2017-10-29 CRAN (R 3.4.3)                           
 graphics     * 3.4.2      2017-09-28 local                                    
 grDevices    * 3.4.2      2017-09-28 local                                    
 grid           3.4.2      2017-09-28 local                                    
 gtable         0.2.0      2016-02-26 CRAN (R 3.4.3)                           
 labeling       0.3        2014-08-23 CRAN (R 3.4.1)                           
 lattice        0.20-35    2017-03-25 CRAN (R 3.4.2)                           
 lazyeval       0.2.1      2017-10-29 CRAN (R 3.4.3)                           
 magrittr       1.5        2014-11-22 CRAN (R 3.4.3)                           
 Matrix         1.2-11     2017-08-21 CRAN (R 3.4.2)                           
 memoise        1.1.0      2017-04-21 CRAN (R 3.4.3)                           
 methods      * 3.4.2      2017-09-28 local                                    
 mlr          * 2.13       2018-03-28 Github (mlr-org/mlr@a9036e3)             
 mmpf         * 0.0.4      2017-12-05 CRAN (R 3.4.4)                           
 munsell        0.4.3      2016-02-13 CRAN (R 3.4.3)                           
 parallel       3.4.2      2017-09-28 local                                    
 parallelMap    1.3        2015-06-10 CRAN (R 3.4.3)                           
 ParamHelpers * 1.11       2018-02-19 Github (berndbischl/ParamHelpers@59c649e)
 pillar         1.2.1      2018-02-27 CRAN (R 3.4.3)                           
 pkgconfig      2.0.1      2017-03-21 CRAN (R 3.4.3)                           
 plyr           1.8.4      2016-06-08 CRAN (R 3.4.3)                           
 R6             2.2.2      2017-06-17 CRAN (R 3.4.3)                           
 ranger       * 0.9.0      2018-01-09 CRAN (R 3.4.3)                           
 Rcpp           0.12.16    2018-03-13 CRAN (R 3.4.2)                           
 rlang          0.2.0.9001 2018-03-26 Github (r-lib/rlang@49d7a34)             
 rstudioapi     0.7        2017-09-07 CRAN (R 3.4.3)                           
 scales         0.5.0      2017-08-24 CRAN (R 3.4.4)                           
 splines        3.4.2      2017-09-28 local                                    
 stats        * 3.4.2      2017-09-28 local                                    
 stringi        1.1.7      2018-03-12 CRAN (R 3.4.2)                           
 survival       2.41-3     2017-04-04 CRAN (R 3.4.2)                           
 tibble         1.4.2      2018-01-22 CRAN (R 3.4.3)                           
 tools          3.4.2      2017-09-28 local                                    
 utils        * 3.4.2      2017-09-28 local                                    
 withr          2.1.2      2018-03-26 Github (jimhester/withr@79d7b0d)         
 XML            3.98-1.10  2018-02-19 CRAN (R 3.4.3)                           
 yaml           2.1.18     2018-03-08 CRAN (R 3.4.3)

Upvotes: 0

Views: 1463

Answers (1)

bgreenwell

Reputation: 413

Hopefully the mlr package maintainers can help (I don't use that package). In the meantime, you can fit the model directly with ranger and use the pdp package:

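library(ranger)
library(ggplot2)
library(pdp)

# fit a probability forest directly (iris_bin as created in the question)
fit <- ranger(as.factor(bin_target) ~ ., data = iris_bin, 
              probability = TRUE)
# partial dependence on the probability scale
pd <- partial(fit, pred.var = "Sepal.Length", prob = TRUE)
autoplot(pd)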

Note the use of prob = TRUE in the call to partial(). Also, ggplot2 is not necessary: you can use plotPartial(pd) instead, which relies on lattice graphics.
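Why prob = TRUE matters: as I understand the pdp documentation, with the default prob = FALSE, partial() reports classification partial dependence on a centered-logit scale (the log-probability of a class minus the mean log-probability across all classes) rather than as probabilities. The sketch below uses a hypothetical helper, `centered_logit()`, written in base R (it is not pdp's own code) to show that transformation for two classes. In the two-class case it reduces to half the ordinary logit, so 0 corresponds to a probability of 0.5.

```r
# Hypothetical helper: centered logit of a class-probability matrix
# (one column per class, each row summing to 1)
centered_logit <- function(probs) {
  logp <- log(probs)
  logp - rowMeans(logp)  # subtract each row's mean log-probability
}

p_true <- c(0.1, 0.5, 0.9)                          # example P(TRUE) values
probs  <- cbind("FALSE" = 1 - p_true, "TRUE" = p_true)
cl     <- centered_logit(probs)

# For two classes the centered logit equals half the ordinary logit
cl[, "TRUE"]
0.5 * qlogis(p_true)
```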

You can also still fit the model with mlr and then use partial() on the extracted model; for instance:

library(mlr)
library(dplyr)
library(ranger)
library(pdp)

# select subset
iris_bin <- iris %>% 
  filter(Species != "virginica") %>% 
  mutate(bin_target = ifelse(Species == "setosa", TRUE, FALSE)) %>% 
  select(-Species)

# fit model
task_bin <- makeClassifTask(data = iris_bin, target = "bin_target")
lrn_bin  <- makeLearner("classif.ranger", predict.type = "prob")
fit_bin <- train(lrn_bin, task_bin)

# partial dependence plot
mod <- getLearnerModel(fit_bin)  # EXTRACT THE MODEL!!  <<--
partial(mod, pred.var = "Sepal.Length", prob = TRUE, 
        plot = TRUE, train = iris_bin)

Note, however, the need to supply the original training data via the train argument.
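Some training data is needed because partial dependence is an average over a reference dataset. For each grid value of the feature, that value is substituted into every row of the data, the model predicts, and the predicted probabilities are averaged. Below is a minimal brute-force sketch of that computation in base R. Two assumptions apply: it swaps in a logistic regression (`glm`) on Sepal.Length alone for the ranger forest so it runs without extra packages, and `manual_pd()` is a hypothetical helper, not a pdp or mlr function.

```r
# Rebuild the two-class subset from the question using base R only
iris_bin <- subset(iris, Species != "virginica")
iris_bin$bin_target <- iris_bin$Species == "setosa"   # TRUE = setosa
iris_bin$Species <- NULL

# Stand-in model (assumption): logistic regression instead of ranger
fit_glm <- glm(bin_target ~ Sepal.Length, data = iris_bin, family = binomial)

# Hypothetical helper: brute-force partial dependence on the probability scale
manual_pd <- function(model, data, feature, grid) {
  yhat <- vapply(grid, function(v) {
    tmp <- data
    tmp[[feature]] <- v                                  # fix the feature at v in every row
    mean(predict(model, newdata = tmp, type = "response"))  # average P(TRUE)
  }, numeric(1))
  data.frame(x = grid, yhat = yhat)
}

# Grid spanning the observed range of the feature in the training data
grid <- seq(min(iris_bin$Sepal.Length), max(iris_bin$Sepal.Length),
            length.out = 10)
pd <- manual_pd(fit_glm, iris_bin, "Sepal.Length", grid)
pd
```

With a single-feature model every row becomes identical once Sepal.Length is fixed, so the average simply equals the fitted curve. With the four-feature forest in the question, the rows still differ in the other features, and that averaging is what turns the predictions into a partial dependence.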

Upvotes: 3
