Reputation: 345
I am running a multinomial regression using nnet::multinom()
on a dataset where the response and all predictors are factors. A sample of my data:
df2 = structure(list(SOURC = c("dsint", "dsint", "circm", "dsint",
"dsint", "circm", "dsint", "circm", "circm", "dsint", "dsint",
"circm", "circm", "dsint", "circm", "circm", "circm", "circm",
"dsint", "rul", "circm", "dsint", "dsint", "dsint", "dsint",
"dsint", "dsint", "circm", "dsint", "circm", "dsint", "dsint",
"circm", "circm", "circm", "cond", "circm", "dsint", "dsint",
"circm", "circm", "dsint", "dsint", "circm", "dsint", "dsint",
"circm", "circm", "circm", "circm", "dsint", "dsint", "circm",
"rul", "circm", "dsint", "circm", "dsint", "circm", "circm",
"dsint", "circm", "circm", "dsint", "circm", "dsint", "dsint",
"circm", "dsint", "dsint", "dsint", "dsint", "rul", "cond", "circm",
"cond", "circm", "dsint", "dsint", "dsint", "circm", "circm",
"dsint", "dsint", "dsint", "dsint", "dsint", "dsint", "circm",
"circm", "dsint", "dsint", "dsint", "dsint", "dsint", "dsint",
"circm", "circm", "dsint", "dsint"), SITN = c("dynamic", "dynamic",
"dynamic", "stative", "dynamic", "dynamic", "dynamic", "dynamic",
"dynamic", "dynamic", "dynamic", "dynamic", "dynamic", "dynamic",
"dynamic", "dynamic", "dynamic", "dynamic", "dynamic", "dynamic",
"dynamic", "stative", "dynamic", "stative", "dynamic", "dynamic",
"dynamic", "dynamic", "dynamic", "dynamic", "stative", "stative",
"dynamic", "dynamic", "dynamic", "stative", "dynamic", "dynamic",
"dynamic", "dynamic", "dynamic", "dynamic", "dynamic", "stative",
"stative", "stative", "dynamic", "dynamic", "dynamic", "dynamic",
"dynamic", "stative", "stative", "dynamic", "stative", "dynamic",
"dynamic", "dynamic", "dynamic", "dynamic", "dynamic", "dynamic",
"dynamic", "stative", "dynamic", "dynamic", "dynamic", "stative",
"dynamic", "stative", "dynamic", "dynamic", "dynamic", "dynamic",
"stative", "dynamic", "dynamic", "stative", "dynamic", "stative",
"dynamic", "dynamic", "dynamic", "dynamic", "dynamic", "stative",
"dynamic", "dynamic", "dynamic", "dynamic", "dynamic", "stative",
"stative", "stative", "stative", "dynamic", "dynamic", "dynamic",
"stative", "dynamic"), GENR = c("specific", "specific", "generic",
"specific", "specific", "specific", "generic", "generic", "specific",
"specific", "specific", "specific", "specific", "specific", "generic",
"specific", "generic", "specific", "specific", "generic", "generic",
"specific", "specific", "specific", "specific", "specific", "generic",
"generic", "specific", "generic", "generic", "specific", "specific",
"specific", "specific", "specific", "specific", "specific", "generic",
"generic", "specific", "specific", "specific", "specific", "specific",
"generic", "specific", "specific", "specific", "generic", "specific",
"generic", "generic", "generic", "specific", "specific", "specific",
"specific", "generic", "specific", "specific", "specific", "specific",
"generic", "specific", "generic", "generic", "specific", "specific",
"specific", "specific", "specific", "generic", "generic", "specific",
"specific", "generic", "generic", "specific", "specific", "specific",
"specific", "specific", "specific", "specific", "generic", "specific",
"generic", "specific", "specific", "generic", "generic", "specific",
"specific", "specific", "generic", "generic", "specific", "specific",
"specific"), PRSN = c("1st", "3rd", "2nd", "1st", "1st", "3rd",
"2nd", "2nd", "3rd", "3rd", "3rd", "1st", "3rd", "3rd", "1st",
"3rd", "3rd", "1st", "3rd", "3rd", "2nd", "3rd", "3rd", "2nd",
"3rd", "3rd", "3rd", "2nd", "1st", "2nd", "2nd", "3rd", "3rd",
"1st", "1st", "3rd", "1st", "3rd", "2nd", "2nd", "1st", "1st",
"3rd", "1st", "3rd", "3rd", "2nd", "1st", "1st", "1st", "1st",
"3rd", "3rd", "3rd", "3rd", "1st", "3rd", "3rd", "3rd", "1st",
"1st", "2nd", "1st", "1st", "3rd", "1st", "2nd", "3rd", "3rd",
"1st", "3rd", "3rd", "2nd", "3rd", "1st", "2nd", "2nd", "3rd",
"3rd", "3rd", "1st", "2nd", "2nd", "1st", "3rd", "3rd", "3rd",
"3rd", "1st", "3rd", "3rd", "3rd", "3rd", "1st", "1st", "2nd",
"2nd", "3rd", "1st", "2nd"), ANIM = c("animate", "inanimate",
"animate", "animate", "animate", "animate", "animate", "animate",
"inanimate", "animate", "inanimate", "animate", "inanimate",
"animate", "animate", "animate", "animate", "animate", "inanimate",
"inanimate", "animate", "animate", "animate", "animate", "animate",
"animate", "inanimate", "animate", "animate", "animate", "animate",
"inanimate", "animate", "animate", "animate", "inanimate", "animate",
"animate", "animate", "animate", "animate", "animate", "animate",
"animate", "animate", "animate", "animate", "animate", "animate",
"animate", "animate", "inanimate", "inanimate", "animate", "animate",
"animate", "inanimate", "animate", "animate", "animate", "animate",
"animate", "animate", "animate", "inanimate", "animate", "animate",
"animate", "inanimate", "animate", "animate", "animate", "animate",
"animate", "animate", "animate", "animate", "inanimate", "inanimate",
"animate", "animate", "animate", "animate", "animate", "inanimate",
"animate", "animate", "animate", "animate", "inanimate", "animate",
"animate", "inanimate", "animate", "animate", "animate", "animate",
"animate", "animate", "animate"), AGENT = c("yes", "no", "yes",
"yes", "yes", "yes", "yes", "yes", "no", "yes", "no", "yes",
"no", "yes", "yes", "yes", "yes", "yes", "no", "no", "yes", "no",
"yes", "yes", "yes", "yes", "no", "yes", "yes", "yes", "no",
"no", "yes", "yes", "yes", "no", "yes", "yes", "yes", "yes",
"yes", "yes", "no", "no", "no", "no", "yes", "yes", "yes", "yes",
"yes", "no", "no", "yes", "no", "yes", "no", "yes", "yes", "yes",
"yes", "yes", "yes", "yes", "no", "yes", "yes", "no", "no", "no",
"no", "yes", "yes", "yes", "no", "yes", "yes", "no", "no", "no",
"yes", "yes", "yes", "no", "no", "yes", "yes", "yes", "yes",
"no", "yes", "no", "no", "no", "yes", "yes", "yes", "yes", "no",
"yes"), CTGRY = c("cmt", "cmt", "wbf", "twt", "wbf", "wbf", "wbs",
"wbf", "wbf", "cmt", "cmt", "twt", "wbf", "cmt", "wbs", "cmt",
"cmt", "twt", "wbf", "wbs", "wbs", "cmt", "cmt", "twt", "cmt",
"cmt", "cmt", "wbs", "cmt", "wbs", "cmt", "wbf", "wbf", "wbf",
"wbs", "cmt", "twt", "cmt", "wbs", "wbs", "wbs", "cmt", "cmt",
"wbf", "wbf", "cmt", "wbf", "cmt", "cmt", "wbs", "twt", "cmt",
"wbf", "wbf", "wbf", "wbf", "wbf", "cmt", "cmt", "twt", "cmt",
"wbf", "cmt", "cmt", "wbf", "cmt", "wbs", "cmt", "wbs", "twt",
"cmt", "cmt", "wbs", "wbs", "twt", "wbf", "wbs", "wbf", "cmt",
"wbf", "wbf", "twt", "twt", "cmt", "cmt", "cmt", "wbf", "cmt",
"twt", "wbs", "wbf", "wbf", "cmt", "cmt", "cmt", "twt", "wbs",
"cmt", "wbf", "twt"), CNTRY = c("in", "in", "us", "in", "us",
"in", "in", "us", "us", "us", "in", "in", "in", "us", "in", "in",
"in", "in", "in", "in", "in", "in", "in", "us", "us", "in", "in",
"in", "in", "us", "us", "in", "us", "us", "in", "us", "us", "in",
"in", "in", "in", "in", "us", "in", "in", "us", "in", "in", "in",
"us", "us", "in", "us", "in", "in", "in", "us", "us", "in", "us",
"in", "in", "us", "us", "in", "us", "in", "in", "in", "in", "us",
"in", "us", "us", "in", "in", "in", "in", "in", "in", "in", "us",
"in", "us", "in", "in", "us", "in", "in", "in", "us", "in", "in",
"in", "us", "us", "in", "in", "us", "us"), VERB = c("ndt", "hvt",
"hvt", "ndt", "hvt", "hvt", "ndt", "ndt", "ndt", "mst", "ndt",
"ndt", "ndt", "ndt", "mst", "hvt", "hvt", "ndt", "hvt", "mst",
"mst", "mst", "mst", "hvt", "hvt", "mst", "ndt", "ndt", "hvt",
"ndt", "hvt", "hvt", "hvt", "hvt", "hvt", "hvt", "ndt", "ndt",
"ndt", "ndt", "ndt", "mst", "ndt", "mst", "ndt", "ndt", "ndt",
"mst", "mst", "ndt", "hvt", "mst", "mst", "hvt", "mst", "hvt",
"ndt", "ndt", "hvt", "hvt", "mst", "mst", "hvt", "hvt", "mst",
"ndt", "mst", "hvt", "ndt", "mst", "mst", "mst", "mst", "ndt",
"mst", "hvt", "ndt", "hvt", "mst", "ndt", "hvt", "ndt", "ndt",
"ndt", "hvt", "ndt", "ndt", "mst", "mst", "mst", "hvt", "mst",
"mst", "ndt", "ndt", "hvt", "ndt", "ndt", "ndt", "hvt"), ASSOC = c("must",
"need_to", "need_to", "need_to", "need_to", "have_to", "need_to",
"need_to", "have_to", "must", "must", "need_to", "have_got_to",
"have_to", "must", "must", "have_to", "must", "have_to", "need_to",
"need_to", "need_to", "must", "have_to", "need_to", "have_to",
"need_to", "must", "need_to", "need_to", "have_got_to", "must",
"have_to", "need_to", "have_to", "must", "must", "must", "have_to",
"have_to", "have_to", "need_to", "must", "must", "need_to", "must",
"need_to", "have_to", "must", "have_to", "have_to", "must", "must",
"have_to", "need_to", "must", "need_to", "need_to", "need_to",
"have_to", "must", "need_to", "have_to", "have_to", "have_to",
"need_to", "need_to", "must", "need_to", "need_to", "have_to",
"have_to", "have_to", "have_to", "need_to", "must", "have_to",
"have_to", "need_to", "must", "have_to", "have_to", "must", "need_to",
"have_to", "have_to", "need_to", "must", "have_to", "must", "must",
"must", "must", "have_got_to", "need_to", "have_to", "need_to",
"need_to", "must", "have_to")), class = "data.frame", row.names = c(NA,
-100L))
# Convert every character column of df2 to a factor.
# vapply() (instead of sapply()) always returns a logical vector, even when
# df2 has zero columns (sapply() would return an empty list, which is not a
# valid column subscript). The mask is computed once and reused.
chr_cols <- vapply(df2, is.character, logical(1))
df2[chr_cols] <- lapply(df2[chr_cols], as.factor)
I have fitted the model like this:
# Multinomial model for VERB with main effects plus three two-way
# interactions involving CTGRY; model = TRUE keeps the model frame and
# maxit raises the optimizer's iteration limit.
multi_mo <- nnet::multinom(
  VERB ~ SOURC + SITN + PRSN + ANIM + ASSOC + AGENT + CTGRY + CNTRY +
    CTGRY:CNTRY + SOURC:CTGRY + PRSN:CTGRY,
  data = df2,
  model = TRUE,
  maxit = 1000
)
Now I want to see the marginal effects of the variables SITN
and CNTRY
(following Sonderegger (2022): 176). Sonderegger uses ggeffects::ggeffect(),
but while searching online I also found marginaleffects::avg_predictions().
However, these two functions give me different results. For example, compare the predicted probabilities of dynamic
vs. stative
for hvt
in panel 1 of each graph.
ggeffects::ggeffect
output:
library(ggeffects)
library(ggplot2) # not attached by ggeffects; needed for ggplot() below

# Predicted probabilities of VERB from ggeffect() for each SITN x CNTRY
# combination, drawn with one facet per response level.
pred <- ggeffect(multi_mo, terms = c("SITN", "CNTRY"))
p <- ggplot(data = pred, aes(x = x, y = predicted)) +
  facet_wrap(~response.level) +
  geom_point(aes(shape = group)) +
  geom_line(aes(group = group, linetype = group))
print(p)
marginaleffects::avg_predictions()
output:
library(marginaleffects)
library(ggplot2) # not attached by marginaleffects; needed for ggplot()

# Average predicted probabilities within each SITN x CNTRY subgroup; the
# `group` column identifies the outcome level, used here for faceting.
pred <- avg_predictions(multi_mo, by = c("SITN", "CNTRY"), type = "probs")
p <- ggplot(data = pred, aes(x = SITN, y = estimate)) +
  facet_wrap(~group) +
  geom_point(aes(shape = CNTRY)) +
  geom_line(aes(group = CNTRY, linetype = CNTRY))
print(p)
Am I comparing two different things here? If so, which function in the
marginaleffects
package corresponds to
ggeffects::ggeffect()?
Upvotes: 0
Views: 174
Reputation: 7832
As you mentioned in your comments, predict_response()
is a wrapper around ggpredict()
(using predict()
), ggemmeans()
(using emmeans::emmeans()
) and ggaverage()
(using marginaleffects::avg_predictions()
). Each of those functions treats the non-focal terms differently; that's why margin
is the argument used to control how the non-focal terms are "marginalized".
As you can see, ggeffect()
(using effects::Effect()
) and the two options for predict_response()
generate quite different predictions. It's similar to creating different data grids in marginaleffects.
If you use avg_predictions(..., by = ...),
you get different predictions than if you use avg_predictions(..., variables = ...).
The variables
argument is what ggeffects::ggaverage()
(or predict_response(margin = "empirical"))
uses to produce counterfactual predictions; with it, the plots from ggeffects and marginaleffects are identical (as expected).
multi_mo <- nnet::multinom(
  VERB ~ SOURC + SITN + PRSN + ANIM + ASSOC + AGENT + CTGRY + CNTRY +
    CTGRY:CNTRY + SOURC:CTGRY + PRSN:CTGRY,
  data = df2,
  model = TRUE,
  maxit = 1000
)

library(ggeffects)
library(marginaleffects)
library(ggplot2)

# (1) ggeffect(): predictions computed via effects::Effect()
pred <- ggeffect(multi_mo, terms = c("SITN", "CNTRY"))
plot(pred, show_ci = FALSE, connect_lines = TRUE)

# (2) predict_response() with margin = "marginalmeans"
pred2 <- predict_response(
  multi_mo,
  terms = c("SITN", "CNTRY"),
  margin = "marginalmeans"
)
plot(pred2, show_ci = FALSE, connect_lines = TRUE)

# (3) predict_response() with margin = "empirical": counterfactual
#     predictions, comparable to avg_predictions(variables = ...) in (5)
pred3 <- predict_response(
  multi_mo,
  terms = c("SITN", "CNTRY"),
  margin = "empirical"
)
plot(pred3, show_ci = FALSE, connect_lines = TRUE)

# (4) avg_predictions(by = ...): averages within observed subgroups
pred <- avg_predictions(multi_mo, by = c("SITN", "CNTRY"), type = "probs")
ggplot(pred, aes(x = SITN, y = estimate)) +
  geom_point(aes(shape = CNTRY)) +
  geom_line(aes(group = CNTRY, linetype = CNTRY)) +
  facet_wrap(~group)

# (5) avg_predictions(variables = ...): counterfactual predictions
pred <- avg_predictions(
  multi_mo,
  variables = c("SITN", "CNTRY"),
  type = "probs"
)
ggplot(pred, aes(x = SITN, y = estimate)) +
  geom_point(aes(shape = CNTRY)) +
  geom_line(aes(group = CNTRY, linetype = CNTRY)) +
  facet_wrap(~group)
Created on 2024-07-18 with reprex v2.1.1
As Vincent said, there's probably no single best or correct way to do the type of "marginalization" you want. In this vignette I tried to frame the choice in terms of "what question would you like to answer?". Maybe that helps you decide which estimands you want.
Upvotes: 2
Reputation: 17823
The ggeffects::ggeffect()
function appears to have been superseded, and is no longer documented extensively on that package’s website. Therefore, I’m not exactly sure what it does, and I can only describe the estimates produced by marginaleffects::avg_predictions()
.
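Calling avg_predictions()
as you do is equivalent to this simple process:
1. Compute the predicted probabilities for every row of the original dataset.
2. Average those predictions within each subgroup defined by CNTRY
and SITN.
Here’s how to obtain the same results with only base R
functions: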
# Predicted probabilities for every observed row (columns hvt, mst, ndt),
# placed next to the original predictor columns of df2
p <- data.frame(predict(multi_mo, type = 'probs'), df2)
# Mean predicted probability of the `hvt` outcome within each SITN x CNTRY cell
aggregate(hvt ~ SITN + CNTRY, data = p, FUN = mean)
> SITN CNTRY hvt
> 1 dynamic in 0.2762476
> 2 stative in 0.1712579
> 3 dynamic us 0.3835831
> 4 stative us 0.5324932
Which is equivalent to the first few rows in this:
avg_predictions(multi_mo, by = c("SITN", "CNTRY"))
>
> Group SITN CNTRY Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 %
> hvt dynamic in 0.276 0.0427 6.47 < 0.001 33.2 0.1926 0.360
> hvt dynamic us 0.384 0.0623 6.16 < 0.001 30.3 0.2615 0.506
> hvt stative in 0.171 0.0591 2.90 0.00377 8.1 0.0554 0.287
> hvt stative us 0.532 0.0940 5.66 < 0.001 26.0 0.3482 0.717
> mst dynamic in 0.356 0.0590 6.04 < 0.001 29.3 0.2408 0.472
> mst dynamic us 0.103 0.0394 2.62 0.00887 6.8 0.0259 0.180
> mst stative in 0.493 0.0925 5.33 < 0.001 23.3 0.3117 0.674
> mst stative us 0.139 0.0837 1.66 0.09725 3.4 -0.0253 0.303
> ndt dynamic in 0.367 0.0557 6.60 < 0.001 34.5 0.2582 0.476
> ndt dynamic us 0.513 0.0684 7.50 < 0.001 43.9 0.3792 0.647
> ndt stative in 0.336 0.0911 3.68 < 0.001 12.1 0.1570 0.514
> ndt stative us 0.329 0.1125 2.92 0.00350 8.2 0.1081 0.549
>
> Columns: group, SITN, CNTRY, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high
> Type: probs
In this case, the by
argument specifies the subgroups within which predictions are averaged; all other predictors are marginalized over. By default, averages are taken across the empirical distribution of the data. That is, they are taken across the predictions made for each observed row (each combination of predictor values that actually occurs) in the dataset.
In some cases, users may want to construct a different dataset to marginalize over. For example, some people like to create a “balanced” grid of predictors, where the dataset is built from each combination of categorical predictors and numeric variables are held at their means:
# Average over a "balanced" grid (every combination of the categorical
# predictors) instead of over the observed rows
avg_predictions(multi_mo, newdata = "balanced", by = c("SITN", "CNTRY"))
>
> Group SITN CNTRY Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 %
> hvt dynamic in 0.3939 0.01520 25.91 <0.001 489.4 0.3641 0.4237
> hvt dynamic us 0.5535 0.02412 22.95 <0.001 384.7 0.5063 0.6008
> hvt stative in 0.4007 0.02201 18.21 <0.001 243.8 0.3576 0.4439
> hvt stative us 0.5661 0.03243 17.46 <0.001 224.3 0.5025 0.6296
> mst dynamic in 0.1992 0.02662 7.48 <0.001 43.6 0.1470 0.2514
> mst dynamic us 0.0611 0.00931 6.56 <0.001 34.1 0.0428 0.0793
> mst stative in 0.2125 0.04023 5.28 <0.001 22.9 0.1337 0.2914
> mst stative us 0.0637 0.01078 5.91 <0.001 28.1 0.0426 0.0848
> ndt dynamic in 0.4069 0.02867 14.19 <0.001 149.4 0.3507 0.4631
> ndt dynamic us 0.3854 0.02526 15.25 <0.001 172.1 0.3359 0.4349
> ndt stative in 0.3867 0.03823 10.12 <0.001 77.5 0.3118 0.4617
> ndt stative us 0.3702 0.03298 11.23 <0.001 94.7 0.3056 0.4349
>
> Columns: group, SITN, CNTRY, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high
> Type: probs
Or you may want to estimate the predicted outcome for a unit with very specific characteristics:
# Predictions for one hypothetical unit with SOURC = "cond" and PRSN = "2nd";
# datagrid() supplies values for the remaining predictors (see ?datagrid)
predictions(
  multi_mo,
  newdata = datagrid(SOURC = "cond", PRSN = "2nd"),
  by = c("SITN", "CNTRY")
)
>
> Group SITN CNTRY Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 %
> hvt dynamic in 1.00e+00 NA NA NA NA NA NA
> mst dynamic in 2.66e-18 4.99e-18 0.533 0.594 0.8 -7.12e-18 1.24e-17
> ndt dynamic in 6.20e-35 1.09e-34 0.567 0.571 0.8 -1.52e-34 2.76e-34
>
> Columns: group, SITN, CNTRY, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high
> Type: probs
These are all different estimands, and the choice between them depends on subject matter expertise, with which I cannot help.
ggeffects::ggeffect()
almost certainly estimates a variation on one of those quantities, but I cannot tell which one from the current documentation.
Upvotes: 1