Shakir
Shakir

Reputation: 345

Different outputs of marginaleffects::avg_predictions() and ggeffects::ggeffect() in R

I am fitting a multinomial regression with nnet::multinom() on a dataset in which the response and all predictors are factors. A sample of my data:

df2 = structure(list(SOURC = c("dsint", "dsint", "circm", "dsint", 
"dsint", "circm", "dsint", "circm", "circm", "dsint", "dsint", 
"circm", "circm", "dsint", "circm", "circm", "circm", "circm", 
"dsint", "rul", "circm", "dsint", "dsint", "dsint", "dsint", 
"dsint", "dsint", "circm", "dsint", "circm", "dsint", "dsint", 
"circm", "circm", "circm", "cond", "circm", "dsint", "dsint", 
"circm", "circm", "dsint", "dsint", "circm", "dsint", "dsint", 
"circm", "circm", "circm", "circm", "dsint", "dsint", "circm", 
"rul", "circm", "dsint", "circm", "dsint", "circm", "circm", 
"dsint", "circm", "circm", "dsint", "circm", "dsint", "dsint", 
"circm", "dsint", "dsint", "dsint", "dsint", "rul", "cond", "circm", 
"cond", "circm", "dsint", "dsint", "dsint", "circm", "circm", 
"dsint", "dsint", "dsint", "dsint", "dsint", "dsint", "circm", 
"circm", "dsint", "dsint", "dsint", "dsint", "dsint", "dsint", 
"circm", "circm", "dsint", "dsint"), SITN = c("dynamic", "dynamic", 
"dynamic", "stative", "dynamic", "dynamic", "dynamic", "dynamic", 
"dynamic", "dynamic", "dynamic", "dynamic", "dynamic", "dynamic", 
"dynamic", "dynamic", "dynamic", "dynamic", "dynamic", "dynamic", 
"dynamic", "stative", "dynamic", "stative", "dynamic", "dynamic", 
"dynamic", "dynamic", "dynamic", "dynamic", "stative", "stative", 
"dynamic", "dynamic", "dynamic", "stative", "dynamic", "dynamic", 
"dynamic", "dynamic", "dynamic", "dynamic", "dynamic", "stative", 
"stative", "stative", "dynamic", "dynamic", "dynamic", "dynamic", 
"dynamic", "stative", "stative", "dynamic", "stative", "dynamic", 
"dynamic", "dynamic", "dynamic", "dynamic", "dynamic", "dynamic", 
"dynamic", "stative", "dynamic", "dynamic", "dynamic", "stative", 
"dynamic", "stative", "dynamic", "dynamic", "dynamic", "dynamic", 
"stative", "dynamic", "dynamic", "stative", "dynamic", "stative", 
"dynamic", "dynamic", "dynamic", "dynamic", "dynamic", "stative", 
"dynamic", "dynamic", "dynamic", "dynamic", "dynamic", "stative", 
"stative", "stative", "stative", "dynamic", "dynamic", "dynamic", 
"stative", "dynamic"), GENR = c("specific", "specific", "generic", 
"specific", "specific", "specific", "generic", "generic", "specific", 
"specific", "specific", "specific", "specific", "specific", "generic", 
"specific", "generic", "specific", "specific", "generic", "generic", 
"specific", "specific", "specific", "specific", "specific", "generic", 
"generic", "specific", "generic", "generic", "specific", "specific", 
"specific", "specific", "specific", "specific", "specific", "generic", 
"generic", "specific", "specific", "specific", "specific", "specific", 
"generic", "specific", "specific", "specific", "generic", "specific", 
"generic", "generic", "generic", "specific", "specific", "specific", 
"specific", "generic", "specific", "specific", "specific", "specific", 
"generic", "specific", "generic", "generic", "specific", "specific", 
"specific", "specific", "specific", "generic", "generic", "specific", 
"specific", "generic", "generic", "specific", "specific", "specific", 
"specific", "specific", "specific", "specific", "generic", "specific", 
"generic", "specific", "specific", "generic", "generic", "specific", 
"specific", "specific", "generic", "generic", "specific", "specific", 
"specific"), PRSN = c("1st", "3rd", "2nd", "1st", "1st", "3rd", 
"2nd", "2nd", "3rd", "3rd", "3rd", "1st", "3rd", "3rd", "1st", 
"3rd", "3rd", "1st", "3rd", "3rd", "2nd", "3rd", "3rd", "2nd", 
"3rd", "3rd", "3rd", "2nd", "1st", "2nd", "2nd", "3rd", "3rd", 
"1st", "1st", "3rd", "1st", "3rd", "2nd", "2nd", "1st", "1st", 
"3rd", "1st", "3rd", "3rd", "2nd", "1st", "1st", "1st", "1st", 
"3rd", "3rd", "3rd", "3rd", "1st", "3rd", "3rd", "3rd", "1st", 
"1st", "2nd", "1st", "1st", "3rd", "1st", "2nd", "3rd", "3rd", 
"1st", "3rd", "3rd", "2nd", "3rd", "1st", "2nd", "2nd", "3rd", 
"3rd", "3rd", "1st", "2nd", "2nd", "1st", "3rd", "3rd", "3rd", 
"3rd", "1st", "3rd", "3rd", "3rd", "3rd", "1st", "1st", "2nd", 
"2nd", "3rd", "1st", "2nd"), ANIM = c("animate", "inanimate", 
"animate", "animate", "animate", "animate", "animate", "animate", 
"inanimate", "animate", "inanimate", "animate", "inanimate", 
"animate", "animate", "animate", "animate", "animate", "inanimate", 
"inanimate", "animate", "animate", "animate", "animate", "animate", 
"animate", "inanimate", "animate", "animate", "animate", "animate", 
"inanimate", "animate", "animate", "animate", "inanimate", "animate", 
"animate", "animate", "animate", "animate", "animate", "animate", 
"animate", "animate", "animate", "animate", "animate", "animate", 
"animate", "animate", "inanimate", "inanimate", "animate", "animate", 
"animate", "inanimate", "animate", "animate", "animate", "animate", 
"animate", "animate", "animate", "inanimate", "animate", "animate", 
"animate", "inanimate", "animate", "animate", "animate", "animate", 
"animate", "animate", "animate", "animate", "inanimate", "inanimate", 
"animate", "animate", "animate", "animate", "animate", "inanimate", 
"animate", "animate", "animate", "animate", "inanimate", "animate", 
"animate", "inanimate", "animate", "animate", "animate", "animate", 
"animate", "animate", "animate"), AGENT = c("yes", "no", "yes", 
"yes", "yes", "yes", "yes", "yes", "no", "yes", "no", "yes", 
"no", "yes", "yes", "yes", "yes", "yes", "no", "no", "yes", "no", 
"yes", "yes", "yes", "yes", "no", "yes", "yes", "yes", "no", 
"no", "yes", "yes", "yes", "no", "yes", "yes", "yes", "yes", 
"yes", "yes", "no", "no", "no", "no", "yes", "yes", "yes", "yes", 
"yes", "no", "no", "yes", "no", "yes", "no", "yes", "yes", "yes", 
"yes", "yes", "yes", "yes", "no", "yes", "yes", "no", "no", "no", 
"no", "yes", "yes", "yes", "no", "yes", "yes", "no", "no", "no", 
"yes", "yes", "yes", "no", "no", "yes", "yes", "yes", "yes", 
"no", "yes", "no", "no", "no", "yes", "yes", "yes", "yes", "no", 
"yes"), CTGRY = c("cmt", "cmt", "wbf", "twt", "wbf", "wbf", "wbs", 
"wbf", "wbf", "cmt", "cmt", "twt", "wbf", "cmt", "wbs", "cmt", 
"cmt", "twt", "wbf", "wbs", "wbs", "cmt", "cmt", "twt", "cmt", 
"cmt", "cmt", "wbs", "cmt", "wbs", "cmt", "wbf", "wbf", "wbf", 
"wbs", "cmt", "twt", "cmt", "wbs", "wbs", "wbs", "cmt", "cmt", 
"wbf", "wbf", "cmt", "wbf", "cmt", "cmt", "wbs", "twt", "cmt", 
"wbf", "wbf", "wbf", "wbf", "wbf", "cmt", "cmt", "twt", "cmt", 
"wbf", "cmt", "cmt", "wbf", "cmt", "wbs", "cmt", "wbs", "twt", 
"cmt", "cmt", "wbs", "wbs", "twt", "wbf", "wbs", "wbf", "cmt", 
"wbf", "wbf", "twt", "twt", "cmt", "cmt", "cmt", "wbf", "cmt", 
"twt", "wbs", "wbf", "wbf", "cmt", "cmt", "cmt", "twt", "wbs", 
"cmt", "wbf", "twt"), CNTRY = c("in", "in", "us", "in", "us", 
"in", "in", "us", "us", "us", "in", "in", "in", "us", "in", "in", 
"in", "in", "in", "in", "in", "in", "in", "us", "us", "in", "in", 
"in", "in", "us", "us", "in", "us", "us", "in", "us", "us", "in", 
"in", "in", "in", "in", "us", "in", "in", "us", "in", "in", "in", 
"us", "us", "in", "us", "in", "in", "in", "us", "us", "in", "us", 
"in", "in", "us", "us", "in", "us", "in", "in", "in", "in", "us", 
"in", "us", "us", "in", "in", "in", "in", "in", "in", "in", "us", 
"in", "us", "in", "in", "us", "in", "in", "in", "us", "in", "in", 
"in", "us", "us", "in", "in", "us", "us"), VERB = c("ndt", "hvt", 
"hvt", "ndt", "hvt", "hvt", "ndt", "ndt", "ndt", "mst", "ndt", 
"ndt", "ndt", "ndt", "mst", "hvt", "hvt", "ndt", "hvt", "mst", 
"mst", "mst", "mst", "hvt", "hvt", "mst", "ndt", "ndt", "hvt", 
"ndt", "hvt", "hvt", "hvt", "hvt", "hvt", "hvt", "ndt", "ndt", 
"ndt", "ndt", "ndt", "mst", "ndt", "mst", "ndt", "ndt", "ndt", 
"mst", "mst", "ndt", "hvt", "mst", "mst", "hvt", "mst", "hvt", 
"ndt", "ndt", "hvt", "hvt", "mst", "mst", "hvt", "hvt", "mst", 
"ndt", "mst", "hvt", "ndt", "mst", "mst", "mst", "mst", "ndt", 
"mst", "hvt", "ndt", "hvt", "mst", "ndt", "hvt", "ndt", "ndt", 
"ndt", "hvt", "ndt", "ndt", "mst", "mst", "mst", "hvt", "mst", 
"mst", "ndt", "ndt", "hvt", "ndt", "ndt", "ndt", "hvt"), ASSOC = c("must", 
"need_to", "need_to", "need_to", "need_to", "have_to", "need_to", 
"need_to", "have_to", "must", "must", "need_to", "have_got_to", 
"have_to", "must", "must", "have_to", "must", "have_to", "need_to", 
"need_to", "need_to", "must", "have_to", "need_to", "have_to", 
"need_to", "must", "need_to", "need_to", "have_got_to", "must", 
"have_to", "need_to", "have_to", "must", "must", "must", "have_to", 
"have_to", "have_to", "need_to", "must", "must", "need_to", "must", 
"need_to", "have_to", "must", "have_to", "have_to", "must", "must", 
"have_to", "need_to", "must", "need_to", "need_to", "need_to", 
"have_to", "must", "need_to", "have_to", "have_to", "have_to", 
"need_to", "need_to", "must", "need_to", "need_to", "have_to", 
"have_to", "have_to", "have_to", "need_to", "must", "have_to", 
"have_to", "need_to", "must", "have_to", "have_to", "must", "need_to", 
"have_to", "have_to", "need_to", "must", "have_to", "must", "must", 
"must", "must", "have_got_to", "need_to", "have_to", "need_to", 
"need_to", "must", "have_to")), class = "data.frame", row.names = c(NA, 
-100L))

# Turn every character column of df2 into a factor so that the response
# and all predictors are treated as categorical variables.
df2[vapply(df2, is.character, logical(1))] <-
  lapply(Filter(is.character, df2), as.factor)

I have fitted the model like this:

# Multinomial model for VERB: main effects plus the CTGRY:CNTRY,
# SOURC:CTGRY and PRSN:CTGRY interactions.
# model = TRUE stores the model frame in the fitted object;
# maxit sets the optimizer's maximum number of iterations.
multi_mo <- nnet::multinom(
  VERB ~ SOURC + SITN + PRSN + ANIM + ASSOC + AGENT + CTGRY + CNTRY +
    CTGRY:CNTRY + SOURC:CTGRY + PRSN:CTGRY,
  data = df2,
  model = TRUE,
  maxit = 1000
)

Now I want to look at the marginal effects of the variables SITN and CNTRY (following Sonderegger (2022): 176). Sonderegger uses ggeffects::ggeffect(), but while searching online I also found marginaleffects::avg_predictions(). However, the two functions give me different results. For example, compare the predicted probabilities of dynamic vs. stative for hvt in the first panel of each graph. ggeffects::ggeffect() output:

library(ggeffects)
library(ggplot2)  # ggplot()/aes()/geom_*() below come from ggplot2, which library(ggeffects) does not attach

# Predicted probabilities from ggeffect() (effects::Effect() under the hood)
# for every SITN x CNTRY combination; response.level identifies the VERB outcome.
pred = ggeffect(multi_mo, terms = c("SITN", "CNTRY"))
p = ggplot(data=pred, aes(x=x, y=predicted)) +
  facet_wrap(~ response.level) +  # one panel per outcome level
  geom_point(aes(shape=group)) +
  geom_line(aes(group = group, linetype = group))
print(p)

[Figure: ggeffects::ggeffect() output]

marginaleffects::avg_predictions() output:

library(marginaleffects)
library(ggplot2)  # ggplot()/aes()/geom_*() below come from ggplot2, which library(marginaleffects) does not attach

# Average predicted probabilities within each observed SITN x CNTRY subgroup;
# the `group` column identifies the VERB outcome level.
pred = avg_predictions(multi_mo, by = c("SITN", "CNTRY"), type = "probs")
p = ggplot(data=pred, aes(x=SITN, y=estimate)) +
  facet_wrap(~ group) +  # one panel per outcome level
  geom_point(aes(shape=CNTRY)) +
  geom_line(aes(group = CNTRY, linetype = CNTRY))
print(p)

[Figure: marginaleffects::avg_predictions() output]

Am I comparing two different things here? If so, which function in the marginaleffects package corresponds to ggeffects::ggeffect()?

Upvotes: 0

Views: 174

Answers (2)

Daniel
Daniel

Reputation: 7832

As you mentioned in your comments, predict_response() is a wrapper around ggpredict() (which uses predict()), ggemmeans() (which uses emmeans::emmeans()) and ggaverage() (which uses marginaleffects::avg_predictions()). Each of these functions treats the non-focal terms differently, which is why the margin argument controls how predictions are "marginalized" over the non-focal terms.

As you can see, ggeffect() (using effects::Effect()) and the two options for predict_response() generate quite different predictions. It's similar to creating different data grids in marginaleffects.

Using avg_predictions(..., by = ...) gives different predictions than using avg_predictions(..., variables = ...). The variables argument is what ggeffects::ggaverage() (or predict_response(margin = "empirical")) uses to produce counterfactual predictions, and with it the plots from ggeffects and marginaleffects are identical (as expected).

# Multinomial model for VERB: main effects plus the CTGRY:CNTRY,
# SOURC:CTGRY and PRSN:CTGRY interactions.
# model = TRUE stores the model frame in the fitted object;
# maxit sets the optimizer's maximum number of iterations.
multi_mo <- nnet::multinom(
  VERB ~ SOURC + SITN + PRSN + ANIM + ASSOC + AGENT + CTGRY + CNTRY +
    CTGRY:CNTRY + SOURC:CTGRY + PRSN:CTGRY,
  data = df2,
  model = TRUE,
  maxit = 1000
)

library(ggeffects)
library(marginaleffects)
library(ggplot2)

# ggeffect() predictions (effects::Effect() under the hood), drawn as
# connected points without confidence bands.
pred <- ggeffect(model = multi_mo, terms = c("SITN", "CNTRY"))
plot(x = pred, connect_lines = TRUE, show_ci = FALSE)


# predict_response() with non-focal terms marginalized as "marginal means",
# drawn as connected points without confidence bands.
pred2 <- predict_response(
  model = multi_mo,
  terms = c("SITN", "CNTRY"),
  margin = "marginalmeans"
)
plot(x = pred2, connect_lines = TRUE, show_ci = FALSE)


# Counterfactual predictions averaged over the observed data
# (margin = "empirical", i.e. ggaverage()), drawn as connected points
# without confidence bands.
pred3 <- predict_response(
  model = multi_mo,
  terms = c("SITN", "CNTRY"),
  margin = "empirical"
)
plot(x = pred3, connect_lines = TRUE, show_ci = FALSE)


# Averages of the fitted probabilities within each observed SITN x CNTRY
# subgroup (`by`); one facet per outcome level (`group`).
pred <- avg_predictions(model = multi_mo, by = c("SITN", "CNTRY"), type = "probs")
ggplot(pred, mapping = aes(x = SITN, y = estimate)) +
  geom_point(mapping = aes(shape = CNTRY)) +
  geom_line(mapping = aes(group = CNTRY, linetype = CNTRY)) +
  facet_wrap(facets = ~group)


# Counterfactual averages over SITN x CNTRY (`variables`); this is the
# estimand predict_response(margin = "empirical") reports, so the plot
# matches pred3. One facet per outcome level (`group`).
pred <- avg_predictions(model = multi_mo, variables = c("SITN", "CNTRY"), type = "probs")
ggplot(pred, mapping = aes(x = SITN, y = estimate)) +
  geom_point(mapping = aes(shape = CNTRY)) +
  geom_line(mapping = aes(group = CNTRY, linetype = CNTRY)) +
  facet_wrap(facets = ~group)

Created on 2024-07-18 with reprex v2.1.1

As Vincent said, there is probably no single best or correct way to do this kind of "marginalization". In this vignette I tried to frame the choice in terms of "which question do you want to answer?" It may help you decide which estimand you want.

Upvotes: 2

Vincent
Vincent

Reputation: 17823

The ggeffects::ggeffect() function appears to have been superseded, and is no longer documented extensively on that package’s website. Therefore, I’m not exactly sure what it does, and I can only describe the estimates produced by marginaleffects::avg_predictions().

Calling avg_predictions() as you do is the equivalent of this simple process:

  1. Compute a prediction (fitted value) for every row of the original dataset, for every level of the outcome variable.
  2. Take the average of those predictions for each subset of CNTRY and SITN.

Here’s how to obtain the same results with only Base R functions:

# Base-R version of avg_predictions(multi_mo, by = c("SITN", "CNTRY")):
# fitted probabilities for every row and outcome level, placed next to
# the original predictors...
p <- data.frame(predict(multi_mo, type = 'probs'), df2)

# ...then the mean predicted probability of `hvt` in each SITN x CNTRY cell.
aggregate(hvt ~ SITN + CNTRY, data = p, FUN = mean)
>      SITN CNTRY       hvt
> 1 dynamic    in 0.2762476
> 2 stative    in 0.1712579
> 3 dynamic    us 0.3835831
> 4 stative    us 0.5324932

These values match the first four (`hvt`) rows of the following output, which lists them in a different order:

# Same subgroup averages, now for every outcome level (hvt, mst, ndt).
avg_predictions(
  model = multi_mo,
  by = c("SITN", "CNTRY")
)
> 
>  Group    SITN CNTRY Estimate Std. Error    z Pr(>|z|)    S   2.5 % 97.5 %
>    hvt dynamic    in    0.276     0.0427 6.47  < 0.001 33.2  0.1926  0.360
>    hvt dynamic    us    0.384     0.0623 6.16  < 0.001 30.3  0.2615  0.506
>    hvt stative    in    0.171     0.0591 2.90  0.00377  8.1  0.0554  0.287
>    hvt stative    us    0.532     0.0940 5.66  < 0.001 26.0  0.3482  0.717
>    mst dynamic    in    0.356     0.0590 6.04  < 0.001 29.3  0.2408  0.472
>    mst dynamic    us    0.103     0.0394 2.62  0.00887  6.8  0.0259  0.180
>    mst stative    in    0.493     0.0925 5.33  < 0.001 23.3  0.3117  0.674
>    mst stative    us    0.139     0.0837 1.66  0.09725  3.4 -0.0253  0.303
>    ndt dynamic    in    0.367     0.0557 6.60  < 0.001 34.5  0.2582  0.476
>    ndt dynamic    us    0.513     0.0684 7.50  < 0.001 43.9  0.3792  0.647
>    ndt stative    in    0.336     0.0911 3.68  < 0.001 12.1  0.1570  0.514
>    ndt stative    us    0.329     0.1125 2.92  0.00350  8.2  0.1081  0.549
> 
> Columns: group, SITN, CNTRY, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high 
> Type:  probs

In this case, the by argument specifies the subgroups within which predictions are averaged. By default, averages are taken over the empirical distribution of the data, that is, over predictions made for each actually observed combination of predictor values in the dataset.

In some cases, users may want to construct a different dataset to marginalize over. For example, some people like to create a “balanced” grid of predictors, where the dataset is built from each combination of categorical predictors and numeric variables are held at their means:

# Average over a "balanced" grid of predictor combinations rather than the
# observed rows, still reported per SITN x CNTRY subgroup.
avg_predictions(
  model = multi_mo,
  newdata = "balanced",
  by = c("SITN", "CNTRY")
)
> 
>  Group    SITN CNTRY Estimate Std. Error     z Pr(>|z|)     S  2.5 % 97.5 %
>    hvt dynamic    in   0.3939    0.01520 25.91   <0.001 489.4 0.3641 0.4237
>    hvt dynamic    us   0.5535    0.02412 22.95   <0.001 384.7 0.5063 0.6008
>    hvt stative    in   0.4007    0.02201 18.21   <0.001 243.8 0.3576 0.4439
>    hvt stative    us   0.5661    0.03243 17.46   <0.001 224.3 0.5025 0.6296
>    mst dynamic    in   0.1992    0.02662  7.48   <0.001  43.6 0.1470 0.2514
>    mst dynamic    us   0.0611    0.00931  6.56   <0.001  34.1 0.0428 0.0793
>    mst stative    in   0.2125    0.04023  5.28   <0.001  22.9 0.1337 0.2914
>    mst stative    us   0.0637    0.01078  5.91   <0.001  28.1 0.0426 0.0848
>    ndt dynamic    in   0.4069    0.02867 14.19   <0.001 149.4 0.3507 0.4631
>    ndt dynamic    us   0.3854    0.02526 15.25   <0.001 172.1 0.3359 0.4349
>    ndt stative    in   0.3867    0.03823 10.12   <0.001  77.5 0.3118 0.4617
>    ndt stative    us   0.3702    0.03298 11.23   <0.001  94.7 0.3056 0.4349
> 
> Columns: group, SITN, CNTRY, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high 
> Type:  probs

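Or you may want to estimate the predicted outcome for a unit with very specific characteristics. Predictors not named in datagrid() are held at typical values (the mean, or the mode for factors), which is why only one SITN/CNTRY combination appears in the output: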

# Prediction for one hypothetical unit with SOURC = "cond" and PRSN = "2nd".
# NOTE(review): predictors not named in datagrid() are set to typical values
# (the output shows a single SITN/CNTRY cell, dynamic/in) — see ?datagrid.
# datagrid() is called inline without a model argument, presumably relying on
# predictions() to supply it, so keep it inside the call.
predictions(multi_mo, 
    by = c("SITN", "CNTRY"), 
    newdata = datagrid(SOURC = "cond", PRSN = "2nd"))
> 
>  Group    SITN CNTRY Estimate Std. Error     z Pr(>|z|)   S     2.5 %   97.5 %
>    hvt dynamic    in 1.00e+00         NA    NA       NA  NA        NA       NA
>    mst dynamic    in 2.66e-18   4.99e-18 0.533    0.594 0.8 -7.12e-18 1.24e-17
>    ndt dynamic    in 6.20e-35   1.09e-34 0.567    0.571 0.8 -1.52e-34 2.76e-34
> 
> Columns: group, SITN, CNTRY, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high 
> Type:  probs

These are all different estimands, and the choice between them depends on subject matter expertise, with which I cannot help.

ggeffects::ggeffect() almost certainly estimates a variation on those quantities, but I cannot tell which one from the current documentation.

Upvotes: 1

Related Questions