kgavras

Reputation: 31

Plot predictions with simultaneous interval from gam over the range of the smoothed variable

I am running a gam model using the mgcv package with one smoothing spline and two factor variables as additional controls. I want to show a plot of predictions over the whole range of the independent smoothed variable (with simultaneous intervals):

library(mgcv)

mod <- gam(dv_value ~ age_grps + period.f + s(born_adult), data = dat, contrasts = list(age_grps = contr.sum, period.f = contr.sum))

I first calculate the predicted values over the whole range of the born_adult variable with simultaneous interval, which seems to work quite well:

rmvn <- function(n, mu, sig) { 
  L <- mroot(sig)
  m <- ncol(L)
  t(mu + L %*% matrix(rnorm(m*n), m, n))
}

Vb <- vcov(mod)

pred <- predict(mod, se.fit = TRUE)

se.fit <- pred$se.fit

N <- 10000

BUdiff <- rmvn(N, mu = rep(0, nrow(Vb)), sig = Vb)

Cg <- predict(mod, type = "lpmatrix")
simDev <- Cg %*% t(BUdiff)

absDev <- abs(sweep(simDev, 1, se.fit, FUN = "/"))

masd <- apply(absDev, 2L, max)

crit <- quantile(masd, prob = 0.95, type = 8)

predData <- transform(cbind(data.frame(pred), dat),
                      uprP = fit + (crit * se.fit),
                      lwrP = fit - (crit * se.fit))

However, when trying to plot the results, I get a really weird plot:

library(ggplot2)

ggplot() +
  geom_ribbon(aes(x = born_adult, ymin = lwrP, ymax = uprP), data = predData, alpha = 0.2, fill = "red")

https://www.dropbox.com/s/uskj9oyq8ud3zx2/plot1.png?dl=0

But, when faceting by my control variables, I get proper predictions for the separate "slices" of my data:

ggplot() +
  geom_ribbon(aes(x = born_adult, ymin = lwrP, ymax = uprP), data = predData, alpha = 0.2, fill = "red") + 
  facet_wrap(vars(period.f, age_grps))

https://www.dropbox.com/s/yju68yl8kes8mp1/plot2.png?dl=0

I have also tried predicting on a new simulated data set with the same structure as my data, but the problem remained. Is there a way to show the "average" predictions over the whole range of my smoothed independent variable without having to facet by the control variables? I believe it could work by taking the mean prediction for each value of born_adult: predData %>% group_by(born_adult) %>% summarize(fit = mean(fit)). However, I have no idea how to average the simultaneous intervals of the individual predictions.

Last but not least, here is a small subset of the data I am using:

dat <- structure(list(dv_value = c(0.8, 0.8, 0.4, 0.8, 1, 0.6, 0.6, 
1, 0.8, 1, 1, 1, 1, 0.4, 0.8, 0.8, 1, 0.4, 1, 0.6, 1, 0.8, 0.6, 
0, 0.6, 0.8, 0.8, 1, 0.8, 0.8, 0.8, 1, 1, 1, 0.8, 1, 0.6, 1, 
0.6, 0.8, 0.8, 0.8, 0.6, 1, 1, 1, 0.6, 1, 1, 1, 0.8, 1, 0.6, 
0.6, 1, 1, 0.8, 0.6, 0.8, 0.6, 1, 0.8, 0.8, 0.6, 0.8, 0.8, 1, 
1, 0.8, 0.8, 0.8, 1, 1, 0.6, 1, 1, 1, 1, 1, 1, 0.6, 0.8, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 0.6, 1, 0.6, 0.6, 0.6, 0.8, 0.8, 0.8, 0.8, 
1, 0.4, 0.8, 1, 1, 1, 1, 0.4, 1, 1, 0.6, 1, 1, 0.4, 0.6, 0.8, 
1, 1, 0.6, 1, 1, 0.6, 1, 0.8, 0.8, 1, 0.8, 1, 0.8, 1, 0.6, 0.8, 
1, 0.8, 0.6, 0.6, 1, 0.8, 0.6, 1, 0.6, 1, 0.6, 0.8, 1, 0.6, 1, 
0.8, 0.8, 0.8, 1, 1, 1, 1, 0.2, 1, 0.6, 1, 0.8, 0.8, 1, 0.6, 
1, 0.4, 1, 0.8, 0.8, 0.4, 1, 1, 0.8, 0.8, 0.8, 1, 0.8, 0.6, 0.6, 
0.4, 0.2, 1, 0.8, 0.4, 1, 1, 0.8, 1, 0.8, 0.6, 1, 1, 1, 0.8, 
1, 0.6, 0.8, 0.8, 1, 1, 0.8, 1), age_grps = structure(c(1L, 3L, 
3L, 3L, 1L, 2L, 3L, 3L, 2L, 2L, 2L, 3L, 2L, 2L, 3L, 3L, 3L, 2L, 
3L, 3L, 3L, 2L, 3L, 2L, 2L, 2L, 3L, 3L, 3L, 1L, 3L, 2L, 3L, 3L, 
2L, 1L, 2L, 2L, 2L, 3L, 3L, 2L, 1L, 2L, 3L, 1L, 2L, 3L, 2L, 3L, 
2L, 2L, 3L, 3L, 2L, 3L, 3L, 2L, 3L, 1L, 3L, 1L, 2L, 3L, 2L, 3L, 
2L, 3L, 2L, 3L, 3L, 2L, 2L, 1L, 2L, 2L, 2L, 2L, 3L, 2L, 2L, 1L, 
2L, 3L, 2L, 3L, 3L, 2L, 3L, 1L, 3L, 2L, 2L, 2L, 1L, 2L, 2L, 1L, 
3L, 3L, 3L, 3L, 2L, 2L, 2L, 2L, 2L, 3L, 2L, 2L, 3L, 3L, 2L, 2L, 
3L, 3L, 1L, 2L, 1L, 3L, 2L, 2L, 3L, 2L, 2L, 2L, 2L, 2L, 3L, 3L, 
3L, 3L, 3L, 2L, 3L, 2L, 3L, 3L, 2L, 1L, 2L, 1L, 2L, 2L, 1L, 2L, 
2L, 2L, 3L, 2L, 1L, 2L, 2L, 2L, 1L, 1L, 2L, 2L, 2L, 3L, 1L, 1L, 
2L, 3L, 2L, 3L, 3L, 3L, 2L, 1L, 2L, 2L, 2L, 2L, 2L, 2L, 1L, 2L, 
2L, 1L, 3L, 3L, 3L, 3L, 2L, 1L, 2L, 3L, 3L, 3L, 3L, 2L, 2L, 2L, 
2L, 3L, 2L, 3L, 3L, 2L), .Label = c("1", "2", "3"), class = "factor"), 
    period.f = structure(c(9L, 9L, 6L, 5L, 10L, 2L, 3L, 6L, 13L, 
    5L, 2L, 2L, 13L, 6L, 7L, 13L, 3L, 7L, 5L, 9L, 5L, 7L, 9L, 
    10L, 7L, 13L, 3L, 13L, 6L, 2L, 10L, 6L, 9L, 9L, 9L, 13L, 
    6L, 7L, 5L, 13L, 3L, 13L, 6L, 10L, 13L, 3L, 7L, 2L, 3L, 9L, 
    10L, 2L, 6L, 6L, 2L, 7L, 6L, 5L, 13L, 2L, 13L, 2L, 3L, 9L, 
    13L, 9L, 7L, 10L, 2L, 13L, 2L, 13L, 10L, 7L, 7L, 9L, 3L, 
    6L, 5L, 5L, 9L, 7L, 13L, 2L, 3L, 6L, 6L, 2L, 13L, 10L, 13L, 
    13L, 10L, 13L, 6L, 5L, 2L, 5L, 6L, 6L, 13L, 7L, 13L, 7L, 
    13L, 13L, 13L, 9L, 13L, 3L, 13L, 13L, 10L, 3L, 10L, 7L, 13L, 
    7L, 5L, 3L, 13L, 9L, 5L, 10L, 2L, 6L, 6L, 2L, 13L, 13L, 13L, 
    9L, 6L, 10L, 5L, 13L, 13L, 7L, 6L, 6L, 7L, 7L, 6L, 3L, 2L, 
    9L, 2L, 5L, 9L, 9L, 2L, 13L, 10L, 13L, 9L, 10L, 2L, 6L, 7L, 
    6L, 2L, 5L, 13L, 5L, 3L, 9L, 7L, 13L, 7L, 3L, 9L, 7L, 9L, 
    3L, 2L, 7L, 2L, 3L, 7L, 7L, 6L, 3L, 5L, 9L, 9L, 10L, 6L, 
    6L, 10L, 2L, 10L, 6L, 6L, 5L, 13L, 3L, 13L, 3L, 3L, 2L), .Label = c("1991", 
    "1992", "1993", "1994", "1995", "1996", "1998", "2000", "2002", 
    "2005", "2008", "2014", "2018"), class = "factor"), born_adult = c(1994, 
    1953, 1937, 1944, 1996, 1977, 1944, 1953, 2001, 1976, 1963, 
    1950, 1978, 1984, 1938, 1969, 1928, 1977, 1943, 1945, 1951, 
    1968, 1959, 1971, 1978, 1998, 1951, 1976, 1951, 1987, 1950, 
    1969, 1955, 1946, 1981, 2008, 1968, 1975, 1957, 1942, 1950, 
    1978, 1993, 1986, 1974, 1982, 1960, 1948, 1953, 1943, 1980, 
    1963, 1943, 1944, 1958, 1953, 1937, 1971, 1971, 1983, 1954, 
    1984, 1979, 1952, 1984, 1946, 1959, 1949, 1979, 1953, 1947, 
    1980, 1979, 1996, 1973, 1964, 1952, 1955, 1948, 1980, 1961, 
    1994, 1991, 1949, 1979, 1947, 1941, 1955, 1962, 2004, 1974, 
    1993, 1976, 1994, 1994, 1974, 1976, 1990, 1946, 1947, 1961, 
    1941, 1991, 1986, 1983, 1983, 1988, 1953, 1990, 1965, 1961, 
    1971, 1979, 1977, 1956, 1948, 2015, 1973, 1988, 1935, 2004, 
    1983, 1948, 1993, 1976, 1960, 1959, 1980, 1968, 1968, 1970, 
    1940, 1949, 1964, 1941, 2005, 1959, 1954, 1969, 1988, 1959, 
    1989, 1971, 1975, 1989, 1980, 1953, 1955, 1959, 1972, 1986, 
    1988, 1974, 1981, 1998, 2001, 1959, 1970, 1960, 1944, 1986, 
    1984, 2000, 1946, 1978, 1930, 1952, 1956, 1979, 1982, 1969, 
    1980, 1961, 1973, 1951, 1979, 1982, 1970, 1974, 1998, 1944, 
    1941, 1950, 1948, 1978, 1999, 1955, 1930, 1961, 1942, 1962, 
    1980, 1983, 1974, 1992, 1949, 2003, 1949, 1949, 1976)), row.names = c(NA, 
-200L), class = c("tbl_df", "tbl", "data.frame"))

Any help is much appreciated!

Upvotes: 1

Views: 1413

Answers (2)

kgavras

Reputation: 31

Thanks to Gavin for providing a solution for the faceted prediction plots! However, to plot the predictions for the observed values in a single panel, I believe there is a workaround based on the methods proposed by King et al. 2001, which gives predictions over the range of the independent variable in one plot.

The underlying problem is that the predictions for each cohort vary with the values of the other covariates, which is what makes the plot look so wiggly. To solve this, we can treat the overall prediction for a given cohort as unobservable, but as something that can be imputed from the individual predictions and their standard errors, which we get from mgcv::predict.gam. Using the methods explained on page 53 of the King et al. 2001 paper, we can calculate the overall predicted values and their standard errors.

Getting the overall prediction for every cohort is pretty straightforward, by just taking the mean of the predictions for each cohort. The overall standard error is a bit more complicated. It requires applying the following two formulas:

[image: first formula, not preserved]

and

[image: second formula, not preserved]
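
For reference, and as a reconstruction rather than a copy of the original formula images, the combining rules in King et al. 2001 are usually written as below. The notation is an assumption on my part: q_j is the j-th of the m predictions for a cohort, SE(q_j) its standard error, and \bar{q} the mean of the q_j.

```latex
S_q^2 = \frac{1}{m-1} \sum_{j=1}^{m} \left(q_j - \bar{q}\right)^2
\qquad
SE(q)^2 = \frac{1}{m} \sum_{j=1}^{m} SE(q_j)^2 + S_q^2 \left(1 + \frac{1}{m}\right)
```

The first term of SE(q)^2 is the average within-prediction variance and the second is the variance between the predictions, inflated by a correction for the finite number m.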

In order to implement these calculations in R, we can simply use some dplyr functions:

library(dplyr)
library(tidyr)  # for replace_na()

# Attach the data so that born_adult is available for grouping
predData <- cbind(data.frame(pred), dat)

predBornAdult <- predData %>% 
  group_by(born_adult) %>% 
  mutate(m = n(),
         mean_fit = mean(fit),
         S_sq = (fit - mean_fit)^2 / (m - 1)) %>%
  replace_na(list(S_sq = 0)) %>% # For cohorts with only one prediction (0/0 gives NaN)
  summarize(fit = mean(fit),
            S_sq = mean(S_sq),
            se.fit2 = mean(se.fit^2) + S_sq,
            se.fit = sqrt(se.fit2)) %>% 
  ungroup() %>% 
  mutate(uprP = fit + crit * se.fit,
         lwrP = fit - crit * se.fit) %>% 
  select(born_adult, fit, uprP, lwrP)
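
To make the pooling step concrete, here is a minimal sketch in base R that applies the combining rules as they are usually stated for multiple imputation to made-up numbers for a single cohort. The values in fit and se are hypothetical, and pool_predictions() is a helper name introduced only for this illustration; it is not part of mgcv or dplyr.

```r
# Hypothetical helper: pool m predictions and their standard errors
# for one cohort using the multiple-imputation combining rules.
pool_predictions <- function(fit, se) {
  m <- length(fit)
  q_bar <- mean(fit)                      # pooled point estimate
  within <- mean(se^2)                    # average within-prediction variance
  between <- if (m > 1) var(fit) else 0   # sample variance S_q^2 (0 if m == 1)
  c(fit = q_bar, se.fit = sqrt(within + between * (1 + 1 / m)))
}

# Made-up predictions for a single cohort (not taken from the data above)
fit <- c(0.70, 0.80, 0.90)
se  <- c(0.10, 0.10, 0.10)

pooled <- pool_predictions(fit, se)
print(pooled)
```

Note that the dplyr pipeline above takes the mean of the per-row squared deviations, which amounts to S_q^2 / m, and does not apply the (1 + 1/m) factor, so the two will not agree exactly; this sketch gives somewhat wider intervals whenever a cohort has more than one prediction.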

Now that we have overall predictions for every cohort, we are able to plot them. However, we have to be aware that the observed values for the cohort are only integers, making point estimates and error bars more appropriate than ribbons, which again would look wiggly due to the fact that we have calculated overall predictions for each cohort separately.

ggplot(predBornAdult, aes(x=born_adult, 
                     y=fit)) + 
geom_errorbar(aes(ymin = lwrP, ymax = uprP)) +
geom_point(size = 1)

[image: point estimates with error bars for each cohort]

To make trends in both the predictions and their uncertainty easier to see, we can now add geom_smooth to visualize potential underlying trends:

ggplot(predBornAdult, aes(x = born_adult, y = fit)) + 
  geom_errorbar(aes(ymin = lwrP, ymax = uprP), alpha = 0.2) +
  geom_point(alpha = 0.2, size = 1) + 
  geom_smooth(aes(y = fit), se = FALSE, alpha = 0.5) + 
  geom_smooth(aes(y = lwrP), se = FALSE, alpha = 0.5, linetype = "solid", size = 0.5) + 
  geom_smooth(aes(y = uprP), se = FALSE, alpha = 0.5, linetype = "solid", size = 0.5)

[image: the same plot with smoothed trend lines for the fit and both interval bounds]

Upvotes: 1

Gavin Simpson

Reputation: 174803

I think this is just an issue due to the data being all mixed up because you are predicting for the observations.

This plot is based on yours, but I plot the fitted values from mgcv::predict.gam(), and you can see the issue immediately:

[image: fitted values (black) with upper (red) and lower (blue) simultaneous interval lines]

The red and blue lines are the upper and lower simultaneous intervals respectively, while the black line is the fitted value from mgcv::predict.gam(). As the latter hasn't been manipulated at all, I'm inclined to believe the intervals here.

This is what you get if you use the pointwise/across-the-function credible intervals:

[image: the same plot using pointwise intervals]

which, apart from being narrower, exhibit the same behaviour.
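
A quick way to see this mixing without fitting any model is a toy data frame in which two groups share the same x values but have different offsets. This is only a sketch; the names toy, x, g, and fit are made up for illustration and stand in for born_adult, a factor, and the fitted values.

```r
# Two groups, "a" and "b", observed at the same five x values
toy <- data.frame(x = rep(1:5, times = 2),
                  g = rep(c("a", "b"), each = 5))

# Same slope in x, but group "b" is shifted up by 1
toy$fit <- 0.1 * toy$x + ifelse(toy$g == "a", 0, 1)

# Count the distinct fitted values at each x
n_fits <- tapply(toy$fit, toy$x, function(v) length(unique(v)))
print(n_fits)
```

Every x has more than one fitted value, so a single line or ribbon drawn over all rows has to jump between the groups, which is the zig-zag pattern seen when the factors are ignored.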

If you are just trying to get predictions and simultaneous intervals over the range of born_adult for each combination of the two factor variables, then you should create new data to predict at that repeats a sequence of born_adult values for all combinations of age_grps and period.f. Here I do that for 50 values of born_adult using expand.grid() (the fitted smooth is basically linear, so even 50 is overkill, but the intervals are smoother with larger n):

pdat <- with(dat, expand.grid(
  born_adult = seq(min(born_adult), max(born_adult), length = 50),
  age_grps = unique(age_grps),
  period.f = unique(period.f)))

Then, repeating your code but adding newdata = pdat to the simultaneous interval calculations, we get them for our prediction data rather than the original data:

Vb <- vcov(mod)
pred2 <- predict(mod, newdata = pdat, se.fit = TRUE)
N <- 10000
BUdiff <- rmvn(N, mu = rep(0, nrow(Vb)), sig = Vb)
Cg <- predict(mod, newdata = pdat, type = "lpmatrix")
simDev <- Cg %*% t(BUdiff)
absDev <- abs(sweep(simDev, 1, pred2$se.fit, FUN = "/"))
masd <- apply(absDev, 2L, max)
crit2 <- quantile(masd, prob = 0.95, type = 8)

Here I create the same as your predData, but I cbind() on pdat instead of the original data, and I add the pointwise intervals just as a check:

predData2 <- transform(cbind(data.frame(pred2), pdat),
                       uprP = fit + (crit2 * se.fit),
                       lwrP = fit - (crit2 * se.fit),
                       uprCI = fit + (2 * se.fit),
                       lwrCI = fit - (2 * se.fit))

which when plotted using

ggplot(predData2) +
  geom_ribbon(aes(x = born_adult, ymin = lwrP, ymax = uprP),
              alpha = 0.2, fill = "red") +
  geom_ribbon(aes(x = born_adult, ymin = lwrCI, ymax = uprCI),
              alpha = 0.2, fill = "red") +
  geom_line(aes(x = born_adult, y = fit)) +
  facet_wrap(vars(period.f, age_grps))

gives this:

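[image: faceted plot of the fit with simultaneous and pointwise interval ribbons, one panel per period.f and age_grps combination]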

If you only want it for the observed combinations of age_grps and period.f you'll need to create the prediction data somewhat differently, but the general idea still applies. (Or you could just do what I did and then delete all rows where the combination of age_grps and period.f is not one of the observed combinations.)
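
One way to do the row-deletion variant is to join the full grid against the observed factor combinations. The sketch below uses base R's merge() on a small made-up data frame; toy_dat and toy_pdat are hypothetical stand-ins for dat and pdat, not the real data.

```r
# Made-up stand-in for dat: only three of the four factor combinations occur
toy_dat <- data.frame(
  born_adult = c(1950, 1960, 1970, 1980),
  age_grps   = factor(c("1", "1", "2", "2")),
  period.f   = factor(c("1992", "1993", "1992", "1992"))
)

# Full grid over all combinations, built the same way as pdat
toy_pdat <- with(toy_dat, expand.grid(
  born_adult = seq(min(born_adult), max(born_adult), length.out = 5),
  age_grps   = unique(age_grps),
  period.f   = unique(period.f)))

# Observed combinations of the two factors
obs <- unique(toy_dat[c("age_grps", "period.f")])

# Keep only the grid rows whose factor combination was observed
toy_pdat_obs <- merge(toy_pdat, obs, by = c("age_grps", "period.f"))
```

merge() reorders rows and puts the key columns first, which does not matter for predict() because it matches variables by name. The filtered grid would then be passed as newdata in place of pdat.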

Upvotes: 1
