Bob Wakefield

Reputation: 4009

How do you access the model parameters in ml_decision_tree in the Sparklyr package?

I have some sample code that only works on one machine. After some testing, I discovered that the working machine was running R 3.4.2, while all the others were running 3.4.3.

After some work I discovered that the way you access the parameters of ml_decision_tree has somehow changed. I'm trying to get the labels. This was the old way of doing it:

model_iris$model.parameters$labels

That doesn't work anymore. Run in the context of the rest of the script, it returns NULL. I've inspected the list object to find where in the hierarchy the labels are stored, and I can see them, but no matter what I do, I can't seem to drill down to them.

Here is a version of the whole script:

library(tidyverse)
library(sparklyr)
library(Rcpp)
sc <- spark_connect(master = "local")
iris_tbl <- copy_to(sc, iris)

partition_iris <- sdf_partition(iris_tbl, training=0.5, testing=0.5)

sdf_register(partition_iris, c("spark_iris_training", "spark_iris_test"))

tidy_iris <- tbl(sc, "spark_iris_training") %>%
  select(Species, Petal_Length, Petal_Width)

model_iris <- tidy_iris %>%
  ml_decision_tree(response="Species", features=c("Petal_Length", "Petal_Width"))

test_iris <- tbl(sc, "spark_iris_test")

pred_iris <- sdf_predict(model_iris, test_iris) %>%
  collect

library(ggplot2)

pred_iris %>%
  inner_join(data.frame(prediction=0:2, lab=model_iris$model.parameters$labels)) %>%
  ggplot(aes(Petal_Length, Petal_Width, col=lab)) + geom_point()

EDIT: There appears to be a difference in the versions of the packages that I'm running. The working code runs sparklyr 0.6.3. The broken version is 0.7.0-9004.

Upvotes: 0

Views: 217

Answers (2)

Javier Luraschi

Reputation: 912

The labels that used to be at model_iris$model.parameters$labels are now accessible as model_iris$.index_labels.

You can instead run:

pred_iris %>%
  inner_join(data.frame(prediction=0:2, lab=model_iris$.index_labels)) %>%
  ggplot(aes(Petal_Length, Petal_Width, col=lab)) + geom_point()

However, since model_iris$.index_labels is internal, to prevent code from breaking in the future we should get the labels from the original dataset or predicted data frame:

pred_iris %>%
  inner_join(data.frame(prediction=0:2, lab=unique(iris$Species))) %>%
  ggplot(aes(Petal_Length, Petal_Width, col=lab)) + geom_point()

or,

pred_iris %>%
  inner_join(data.frame(prediction=0:2, lab=unique(pred_iris$predicted_label))) %>%
  ggplot(aes(Petal_Length, Petal_Width, col=lab)) + geom_point()

Upvotes: 1

kevinykuo

Reputation: 4762

pred_iris should have a predicted_label column with what you need. Are there other use cases you have that require getting labels from the model object?
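For the plot in the question, that could look roughly like the sketch below, which colours the points by predicted_label directly and skips the join. The data frame is a hypothetical stand-in for the collected pred_iris so the snippet runs without Spark; the column names are taken from the question and the values are made up.

```r
library(ggplot2)

# Hypothetical stand-in for the collected `pred_iris`.
pred_iris <- data.frame(
  Petal_Length    = c(1.4, 4.7, 6.0, 1.3, 5.1),
  Petal_Width     = c(0.2, 1.4, 2.5, 0.2, 1.9),
  prediction      = c(0, 1, 2, 0, 2),
  predicted_label = c("setosa", "versicolor", "virginica", "setosa", "virginica"),
  stringsAsFactors = FALSE
)

# Colour by the label column that is already in the predictions; no lookup
# table or model internals are needed.
p <- ggplot(pred_iris, aes(Petal_Length, Petal_Width, col = predicted_label)) +
  geom_point()

p
```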

Upvotes: 0
