user3056186
user3056186

Reputation: 869

using R to plot interaction plot

I have created a model from the following data:

      age    hrs  charges
 530.6071 792.10  3474.60
 408.6071 489.70  1247.06
 108.0357 463.00  1697.07
 106.6071 404.15  1676.33
 669.4643 384.65  1701.13
 556.4643 358.15  1630.30
 665.4643 343.85  2468.83
 508.4643 342.35  3366.44
 106.0357 335.25  2876.82

# Robust linear model (MASS::rlm) of charges on age, hrs and their
# interaction. `age + hrs + age * hrs` expands to the same terms as
# `age * hrs`: age, hrs and age:hrs.
# NOTE(review): assumes MASS is attached and that the data frame
# age_vs_hrs_charges_cleaned holds the columns shown above.
interaction_model <- rlm(
  charges ~ age + hrs + age * hrs,
  data = age_vs_hrs_charges_cleaned
)

Any idea how I can plot this in 3D?

I already plotted using

# Effect display of the age:hrs interaction from the model above.
# multiline = TRUE draws the effect as several lines in one panel instead of
# one panel per level. default.levels = 20 requests 20 values for the focal
# numeric predictors.
# NOTE(review): newer versions of the effects package may expect
# `xlevels = 20` instead of `default.levels` -- confirm for the installed version.
library(effects)
plot(
  effect(
    term = "age:hrs",
    mod = interaction_model,
    default.levels = 20
  ),
  multiline = TRUE
)

but this is not a very clear visualization.

Any help?

Upvotes: 3

Views: 4700

Answers (2)

Tom Wenseleers
Tom Wenseleers

Reputation: 8019

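A little while ago I wrote a couple of functions to display the results of a (general) linear model, together with colour-coded data points, either in 3D (interactive, using rgl) or in 2D (as an image/contour plot). Both need the rockchalk and colorspace packages; `plotImage` additionally uses aqfig (for the colour legend) and `plotPlaneFancy` uses rgl: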

# Helper for plotImage(): widen a numeric range by `frac` of the magnitude of
# each end. For positive data this equals c(min * 0.95, max * 1.05). Unlike
# that form, it also widens (rather than shrinks) the range at a negative end.
.padded_limits <- function(v, frac = 0.05) {
  r <- range(v)
  r + c(-1, 1) * frac * abs(r)
}

# Helper for plotImage(): map numeric `values` to colours from `palette` on
# the linear scale `lim`, so that equal values always get equal colours
# whatever the palette length. Values outside `lim` are clamped to the
# first/last colour. A zero-width scale maps every value to the first colour.
.palette_colours <- function(values, lim, palette) {
  n_col <- length(palette)
  span <- diff(lim)
  if (!is.finite(span) || span <= 0) {
    return(rep(palette[1L], length(values)))
  }
  idx <- floor((values - lim[1L]) * (n_col - 1) / span) + 1
  palette[pmin(pmax(idx, 1), n_col)]
}

# plot predictions of a (general) linear model as a function of two
# explanatory variables as an image / contour plot, together with the actual
# data points.
# mean value is used for any other variables in the model (predictions come
# from rockchalk::predictOMatic with its default central values).
#
# Arguments:
#   model          fitted model, e.g. from lm() or MASS::rlm().
#   plotx, ploty   names (strings) of the predictors on the x / y axes.
#   plotPoints     draw the observations, coloured on the image's scale.
#   plotContours   overlay contour lines of the predictions.
#   plotLegend     add a colour legend (aqfig::vertical.image.legend).
#   npp            grid resolution: npp x npp predictions (1e6 by default).
#   xlab, ylab     axis labels (default: the variable names). zlab is
#                  computed but not drawn yet (see TO DO below).
#   xlim, ylim     axis ranges; default pads the data range by 5% per end.
#   pch, cex, lwd  point symbol (a filled pch such as 15-17), size and
#                  outline width.
#   col.palette    colours for the image, points and legend; default
#                  rev(colorspace::rainbow_hcl(1000, c = 100)).
# Returns NULL invisibly; called for its plot. par(mai, fin) is set and
# left in place.
plotImage <- function(model = NULL, plotx = NULL, ploty = NULL,
                      plotPoints = TRUE, plotContours = TRUE,
                      plotLegend = FALSE, npp = 1000, xlab = NULL,
                      ylab = NULL, zlab = NULL, xlim = NULL, ylim = NULL,
                      pch = 16, cex = 1.2, lwd = 0.1, col.palette = NULL) {
  for (pkg in c("rockchalk", "aqfig", "colorspace")) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      stop("plotImage() requires the '", pkg, "' package.", call. = FALSE)
    }
  }
  mf <- model.frame(model)
  emf <- rockchalk::model.data(model)
  if (is.null(xlab)) xlab <- plotx
  if (is.null(ylab)) ylab <- ploty
  if (is.null(zlab)) zlab <- names(mf)[[1]]
  if (is.null(col.palette)) {
    col.palette <- rev(colorspace::rainbow_hcl(1000, c = 100))
  }
  x <- emf[[plotx]]
  y <- emf[[ploty]]
  z <- mf[[1]]
  if (is.null(xlim)) xlim <- .padded_limits(x)
  if (is.null(ylim)) ylim <- .padded_limits(y)

  # Predict on exactly the grid that image() draws, so each cell's colour is
  # the model prediction at that cell's (x, y) position.
  xs <- seq(xlim[1], xlim[2], length.out = npp)
  ys <- seq(ylim[1], ylim[2], length.out = npp)
  preds <- rockchalk::predictOMatic(
    model,
    predVals = stats::setNames(list(xs, ys), c(plotx, ploty))
  )
  # Place each prediction by its grid coordinates, not by row order.
  zpred <- matrix(NA_real_, npp, npp)
  zpred[cbind(match(preds[[plotx]], xs), match(preds[[ploty]], ys))] <-
    preds[["fit"]]

  # One colour scale shared by the image, the points and the legend.
  zlim <- range(c(zpred, z), na.rm = TRUE)
  graphics::par(mai = c(1.2, 1.2, 0.5, 1.2), fin = c(6.5, 6))
  graphics::image(x = xs, y = ys, z = zpred, zlim = zlim, col = col.palette,
                  xlab = xlab, ylab = ylab, useRaster = TRUE,
                  xaxs = "i", yaxs = "i")
  if (plotContours) {
    graphics::contour(x = xs, y = ys, z = zpred, add = TRUE, method = "edge")
  }
  if (plotPoints) {
    fill_cols <- .palette_colours(z, zlim, col.palette)
    # Outline colours: each RGB channel darkened by 0.3, alpha pushed to opaque.
    edge_cols <- grDevices::adjustcolor(fill_cols,
                                        offset = c(-0.3, -0.3, -0.3, 1))
    fill_pch <- rep_len(pch, length(x))
    # Open counterpart of a filled symbol, e.g. pch 16 -> 1, 15 -> 0.
    edge_pch <- fill_pch - 15
    # Interleave so each outline is drawn directly over its own fill.
    graphics::points(c(rbind(x, x)), c(rbind(y, y)), cex = cex,
                     col = c(rbind(fill_cols, edge_cols)),
                     pch = c(rbind(fill_pch, edge_pch)), lwd = lwd)
  }
  graphics::box()
  if (plotLegend) {
    # TO DO: add z axis label (zlab), maybe make legend a bit smaller?
    aqfig::vertical.image.legend(zlim = zlim, col = col.palette)
  }
  invisible(NULL)
}

# plot predictions of a (general) linear model as a function of two
# explanatory variables as an interactive 3D (rgl) plot. Optionally adds the
# data points, drop lines from each point to its fitted value, and lower/upper
# surfaces when interval = "confidence" or "prediction".
# mean value is used for any other variables in the model (predictions come
# from rockchalk::predictOMatic with its default central values).
#
# Arguments (selection):
#   plotx1, plotx2  names (strings) of the predictors on the x / y axes.
#   npp             the surface is evaluated on an npp x npp grid.
#   x1lim, x2lim    extent of that grid; default is the data range.
#   col.palette     colours for points and surface, on one shared scale;
#                   default rev(colorspace::rainbow_hcl(1000, c = 100))
#                   (or library(colorRamps); col.palette <- matlab.like(1000)).
#   outfile         PNG file written with rgl.snapshot(); NULL to skip.
# Opens a new rgl device with open3d().
# NOTE(review): drop lines end at fitted(model), which uses each observation's
# own covariate values. With extra covariates they need not meet the plotted
# plane.
plotPlaneFancy <- function(model = NULL, plotx1 = NULL, plotx2 = NULL,
                           plotPoints = TRUE, plotDroplines = TRUE, npp = 50,
                           x1lab = NULL, x2lab = NULL, ylab = NULL,
                           x1lim = NULL, x2lim = NULL, cex = 1.5,
                           col.palette = NULL, segcol = "black",
                           segalpha = 0.5, interval = "none",
                           confcol = "lightgrey", confalpha = 0.4,
                           pointsalpha = 1, lit = TRUE,
                           outfile = "graph.png", aspect = c(1, 1, 0.3),
                           zoom = 1,
                           userMatrix = matrix(c(0.80, -0.60, 0.022, 0,
                                                 0.23, 0.34, 0.91, 0,
                                                 -0.55, -0.72, 0.41, 0,
                                                 0, 0, 0, 1),
                                               ncol = 4, byrow = TRUE),
                           windowRect = c(0, 29, 1920, 1032)) {
  for (pkg in c("rockchalk", "rgl", "colorspace")) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      stop("plotPlaneFancy() requires the '", pkg, "' package.",
           call. = FALSE)
    }
  }
  mf <- model.frame(model)
  emf <- rockchalk::model.data(model)
  if (is.null(x1lab)) x1lab <- plotx1
  if (is.null(x2lab)) x2lab <- plotx2
  if (is.null(ylab)) ylab <- names(mf)[[1]]
  if (is.null(col.palette)) {
    col.palette <- rev(colorspace::rainbow_hcl(1000, c = 100))
  }
  x1 <- emf[[plotx1]]
  x2 <- emf[[plotx2]]
  y <- mf[[1]]
  if (is.null(x1lim)) x1lim <- range(x1)
  if (is.null(x2lim)) x2lim <- range(x2)

  # Evaluate the model on a grid spanning x1lim / x2lim, so user-supplied
  # limits actually change the plotted surface.
  g1 <- seq(x1lim[1], x1lim[2], length.out = npp)
  g2 <- seq(x2lim[1], x2lim[2], length.out = npp)
  preds <- rockchalk::predictOMatic(
    model,
    predVals = stats::setNames(list(g1, g2), c(plotx1, plotx2)),
    interval = interval
  )
  # Reshape a prediction column into an npp x npp matrix (rows follow g1,
  # columns follow g2), placing values by their grid coordinates.
  cells <- cbind(match(preds[[plotx1]], g1), match(preds[[plotx2]], g2))
  as_grid <- function(column) {
    out <- matrix(NA_real_, npp, npp)
    out[cells] <- preds[[column]]
    out
  }

  # One colour scale for points and surface, spanning both.
  resp_lim <- range(c(preds[["fit"]], y))
  n_col <- length(col.palette)
  to_colour <- function(v) {
    span <- diff(resp_lim)
    if (!is.finite(span) || span <= 0) {
      return(rep(col.palette[1L], length(v)))
    }
    idx <- floor((v - resp_lim[1]) * (n_col - 1) / span) + 1
    col.palette[pmin(pmax(idx, 1), n_col)]
  }

  rgl::open3d(zoom = zoom, userMatrix = userMatrix, windowRect = windowRect)
  point_cols <- to_colour(y)
  if (plotPoints) {
    rgl::plot3d(x = x1, y = x2, z = y, type = "s", col = point_cols,
                size = cex, aspect = aspect, xlab = x1lab, ylab = x2lab,
                zlab = ylab, lit = lit, alpha = pointsalpha)
  } else {
    # type = "n": set up the axes and aspect without drawing the points.
    rgl::plot3d(x = x1, y = x2, z = y, type = "n", col = point_cols,
                size = cex, aspect = aspect, xlab = x1lab, ylab = x2lab,
                zlab = ylab)
  }
  if ("lwr" %in% names(preds)) {
    rgl::persp3d(x = g1, y = g2, z = as_grid("lwr"), color = confcol,
                 alpha = confalpha, lit = lit, back = "lines", add = TRUE)
  }
  ypred <- as_grid("fit")
  rgl::persp3d(x = g1, y = g2, z = ypred, color = to_colour(ypred),
               alpha = 0.7, lit = lit, back = "lines", add = TRUE)
  if ("upr" %in% names(preds)) {
    rgl::persp3d(x = g1, y = g2, z = as_grid("upr"), color = confcol,
                 alpha = confalpha, lit = lit, back = "lines", add = TRUE)
  }
  if (plotDroplines) {
    # Interleave (observed, fitted) pairs: each pair becomes one segment.
    # NOTE(review): lty = 2 is kept from the original; confirm rgl honours it.
    rgl::segments3d(x = rep(x1, each = 2), y = rep(x2, each = 2),
                    z = c(rbind(y, fitted(model))), col = segcol, lty = 2,
                    alpha = segalpha)
  }
  if (!is.null(outfile)) rgl::rgl.snapshot(outfile, fmt = "png", top = TRUE)
}

Here is the output for your model:

# Example data from the question, refitted with a robust MASS::rlm() model
# containing both main effects and the age:hrs interaction.
library(MASS)
data <- data.frame(
  age = c(530.6071, 408.6071, 108.0357, 106.6071, 669.4643,
          556.4643, 665.4643, 508.4643, 106.0357),
  hrs = c(792.10, 489.70, 463.00, 404.15, 384.65,
          358.15, 343.85, 342.35, 335.25),
  charges = c(3474.60, 1247.06, 1697.07, 1676.33, 1701.13,
              1630.30, 2468.83, 3366.44, 2876.82)
)
fit1 <- rlm(charges ~ age + hrs + age * hrs, data = data)

# Interactive 3D plot of the fitted surface over age and hrs.
plotPlaneFancy(
  model = fit1,
  plotx1 = "age",
  plotx2 = "hrs"
)

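*(Screenshot: interactive 3D plot of the fitted surface, with the data points and drop lines from each point to its fitted value.)*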

# Same plot, adding lower/upper confidence surfaces around the fit.
plotPlaneFancy(
  model = fit1,
  plotx1 = "age",
  plotx2 = "hrs",
  interval = "confidence"
)

*(Screenshot: the same 3D plot with lower and upper 95% confidence surfaces added.)*

(Use interval="prediction" instead to show 95% prediction intervals.)

# 2D image plot of the predictions with contour lines and a colour legend.
plotImage(
  model = fit1,
  plotx = "age",
  ploty = "hrs",
  plotContours = TRUE,
  plotLegend = TRUE
)

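*(Screenshot: image/contour plot of the predicted charges over age and hrs, with colour-coded data points and a colour legend.)*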

Upvotes: 4

jlhoward
jlhoward

Reputation: 59415

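There are several ways to do this. In the code below, `df` is your data frame (`age_vs_hrs_charges_cleaned` in the question). The model is fitted with `lm()` rather than `rlm()`, but the same steps work for either.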

# Fit the interaction model, then predict it on a regular 51 x 51 grid
# covering 0-1000 (step 20) for both age and hrs.
# NOTE(review): `df` must be the question's data frame; otherwise the name
# resolves to stats::df (a function) and lm() fails.
model <- lm(charges ~ age + hrs + age * hrs, data = df)
age <- seq(from = 0, to = 1000, by = 20)
hrs <- seq(from = 0, to = 1000, by = 20)
gg <- expand.grid(age = age, hrs = hrs)
gg$charges <- predict(model, newdata = gg)

# Contour plot of the predicted surface. Each contour is coloured by its
# level using a "jet"-style ramp built from 9 matlab.like() colours.
# NOTE(review): ggplot2 >= 3.4 deprecates `..level..` (use after_stat(level))
# and `size` for lines (use linewidth); kept here for older versions.
library(ggplot2)
library(colorRamps)
library(grDevices)
jet.colors <- colorRampPalette(matlab.like(9))
ggplot(data = gg, mapping = aes(x = age, y = hrs, z = charges)) +
  stat_contour(aes(color = ..level..), binwidth = 200, size = 2) +
  scale_color_gradientn(colours = jet.colors(8))

# Static 3D scatterplot of the predicted grid points.
library(scatterplot3d)
scatterplot3d(x = gg$age, y = gg$hrs, z = gg$charges)

# Interactive rgl version of the same scatterplot (shown as a screenshot in
# the answer).
library(rgl)
plot3d(x = gg$age, y = gg$hrs, z = gg$charges)

# Interactive 3D surface of the predicted grid, shaded with a 100-step jet
# palette (shown as a screenshot in the answer).
# rgl.surface() takes x and z as the grid axes and y as the heights; the
# heights here are gg$charges scaled by 0.05. Because expand.grid() varies age
# fastest, gg$charges lines up with x = age, z = hrs.
# NOTE(review): newer rgl versions deprecate rgl.surface() in favour of
# surface3d(), which puts heights on z -- confirm for the installed version.
colorjet <- jet.colors(100)
open3d()
rgl.surface(
  x = age,
  z = hrs,
  y = 0.05 * gg$charges,
  # Bin each prediction into one of 100 equal-width bands between the
  # minimum and maximum charges, then take that band's colour.
  color = colorjet[findInterval(
    gg$charges,
    seq(min(gg$charges), max(gg$charges), length.out = 100)
  )]
)
axes3d()

Upvotes: 5

Related Questions