knitr::opts_chunk$set(prompt=TRUE,warning=FALSE,message=FALSE,collapse=TRUE)
suppressPackageStartupMessages(library(RFinfer))
suppressPackageStartupMessages(library(ggplot2))
suppressPackageStartupMessages(library(nhanesA))
This vignette is an introductory overview of the RFinfer package for generating prediction variances from random forest predictions. The first section covers how to install the package in R. Next, we cover a simple example showing how to use the main function, rfPredVar(), to generate predictions with accompanying variance estimates. Finally, we cover a more in-depth example using NHANES data, comparing the different types of random forests and using plots to carry out some statistical inference. These inferences are compared to those from a standard linear model.
For a detailed introduction to the infinitesimal jackknife (IJ) and how it is applied to the random forest to generate prediction variances, see the other vignette titled
Install the package from CRAN:
Alternatively, install the latest version from GitHub:
Formerly, the package required a custom version of the
randomForest package that counted the number of times that each observation was selected in each resample, rather than only indicating if it was selected at least once. This information is necessary for the implementation of the IJ, but has since been included in the latest CRAN version of the
randomForest package (v4.6-12). If you run into problems relating to setting
keep.inbag=TRUE, then make sure you have the most recent CRAN version of the
randomForest package installed (
install.packages('randomForest')); however, this should be installed if necessary when installing the RFinfer package.
At the heart of the package is the
rfPredVar function which trains a random forest using either conditional inference trees or traditional CART trees and produces prediction estimates with accompanying estimates of the prediction variance. The function takes a random forest trained with the
keep.inbag=TRUE option in order to detect the resampling scheme (either bootstrap or subsample). This way the user can supply the same random forest to
rfPredVar and the same resamples will be used when training the forest using either type of trees, facilitating a true comparison. The function defaults to making predictions using the original training data, which is useful for making inferences about the model, but can also be supplied with new data to predict on. The function returns a data.frame with the predictions and prediction variances. If requested in the call with
CI=TRUE, the 95% confidence intervals will also be returned. Finally, since the training of forests using conditional inference trees can be computationally intensive, a progress bar can be requested using prog.bar=TRUE. See
?rfPredVar for more details, including examples of how to specify forests built with subsampling instead of bootstrap sampling, or with conditional inference trees instead of traditional CART trees.
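To make the role of the in-bag counts more concrete, here is a minimal base R sketch of the infinitesimal jackknife idea. It is not the code used inside rfPredVar: the "trees" are replaced by simple resample means, and all object names (inbag, t_b, scheme, var_ij) are made up for illustration. It shows why counts rather than in/out indicators are needed, how the resampling scheme could be guessed from them, and how the usual IJ formula for a bagged estimator (the sum over observations of the squared covariance between in-bag counts and per-resample predictions) can be evaluated.

```r
# Toy illustration (an assumption-laden sketch, not RFinfer's implementation):
# infinitesimal jackknife variance for a bagged sample mean.
set.seed(1)
n <- 50      # number of observations
B <- 2000    # number of bootstrap resamples ("trees")
y <- rnorm(n, mean = 10, sd = 2)

# inbag[i, b] = number of times observation i appears in resample b;
# this mimics the kind of count matrix kept when keep.inbag=TRUE
inbag <- replicate(B, tabulate(sample.int(n, n, replace = TRUE), nbins = n))

# "Prediction" of each resample: here simply the mean of the resampled values
t_b <- colSums(inbag * y) / n

# Guess the resampling scheme: counts above 1 imply sampling with replacement
scheme <- if (any(inbag > 1)) 'bootstrap' else 'subsample'

# IJ estimate: sum over observations of Cov_b(N_bi, t_b)^2
cov_i <- as.vector((inbag - rowMeans(inbag)) %*% (t_b - mean(t_b))) / B
var_ij <- sum(cov_i^2)

# For a plain mean, the classical variance var(y)/n serves as a sanity check
c(ij = var_ij, classical = var(y) / n)
```

In this toy setting the two numbers should be of similar size. For a real forest the per-resample predictions come from the individual trees, and the finite number of trees adds Monte Carlo noise to the estimate.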
First load the package and some example data included with R. The New York Air Quality Measurements dataset contains daily ozone concentrations in New York from May to September of 1973 and includes meteorological covariates (see
?airquality for more details). Because calls to random forest do not allow missing data, we will only use complete cases here for the sake of example. We'll also drop the very high outliers (ozone values of 100 or more).
library(RFinfer)
data('airquality')
d.aq <- na.omit(airquality)
d.aq <- d.aq[d.aq$Ozone < 100, ]
Next, train a random forest with the keep.inbag=TRUE option specified.
rf <- randomForest(Ozone ~ .,data=d.aq,keep.inbag=TRUE)
Extract the predictions and prediction variances for the training data along with their 95% confidence intervals. The values are returned in a data.frame.
rf.preds <- rfPredVar(rf,rf.data=d.aq,CI=TRUE)
str(rf.preds)
Plot the predictions with their 95% confidence intervals against the actual values.
library(ggplot2)
ggplot(rf.preds,aes(d.aq$Ozone,pred)) +
  geom_abline(intercept=0,slope=1,lty=2,color='#999999') +
  geom_point() +
  # 'height' is not an aesthetic of geom_errorbar, so it is dropped here
  geom_errorbar(aes(ymin=l.ci,ymax=u.ci)) +
  xlab('Actual') + ylab('Predicted') +
  theme_bw()
Here, we can see that the random forest is generally less confident about its less accurate predictions. This can also be visualized by plotting the prediction variance as a function of the prediction error.
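# qplot() is deprecated in recent ggplot2 releases, so ggplot() is used instead
ggplot(data.frame(error=d.aq$Ozone - rf.preds$pred,variance=rf.preds$pred.ij.var),
       aes(error,variance)) +
  geom_point() +
  xlab('prediction error') + ylab('prediction variance') +
  theme_bw()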
Here we will explore a more in-depth example using real data from NHANES, showing how the random forest prediction variance might be used for statistical inference. We will compare the inferences to those from a linear model.
We will use the
nhanesA package to retrieve information on BMI and some possibly related demographic characteristics.
library(nhanesA)
bmx_d <- nhanes('BMX_D')
demo_d <- nhanes('DEMO_D')
d_merged <- merge(bmx_d,demo_d)
d <- data.frame('bmi' = d_merged$BMXBMI,
                'age' = d_merged$RIDAGEYR,
                'race' = factor(d_merged$RIDRETH1),
                'poverty_income_ratio' = d_merged$INDFMPIR,
                'edu' = factor(d_merged$DMDHREDU,levels=1:9))
For the sake of this example, we will omit any observations with missing data and take a random subsample of size 100. Here, we have the numeric BMI value for each observation along with the subject's age in years, race and education (each defined as a categorical variable), and income-to-poverty ratio.
d <- na.omit(d)
set.seed(24512399)
d <- d[sample(1:nrow(d),size=100), ]
summary(d)
Next, we will train an initial random forest and use the
rfPredVar function to extract predictions and prediction variances for forests built using both CART and CI trees. Since we aim to plot the prediction confidence intervals, supply the CI=TRUE argument.
library(RFinfer)
# the randomForest argument for the number of trees is 'ntree' (not 'ntrees')
rf <- randomForest(bmi ~ .,data=d,keep.inbag=TRUE,
                   sampsize=nrow(d)^0.7,replace=FALSE,ntree=5000)
rf.preds_cart <- rfPredVar(rf,rf.data=d,CI=TRUE,tree.type='rf')
rf.preds_ci <- rfPredVar(rf,rf.data=d,CI=TRUE,tree.type='ci',prog.bar=FALSE)
To plot these, we will bind both prediction data.frames together.
preds <- rbind(data.frame(rf.preds_cart,'tree'='CART'),
               data.frame(rf.preds_ci,'tree'='CI'))
preds$pred.ij.var <- NULL
Next, fit a linear model using the same data and add its predictions with confidence intervals to the data.frame of model predictions. We will specify
interval='prediction' since we are interested in the intervals for future predictions.
LM <- lm(bmi ~ .,data=d)
LM.preds <- data.frame(predict(LM,interval='prediction'))
names(LM.preds) <- c('pred','l.ci','u.ci')
LM.preds$tree <- 'LM'
preds <- rbind(preds,LM.preds)
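The choice of interval type matters for the width of the linear model's intervals. As a small sketch on the built-in cars data (an illustrative stand-in, not the NHANES sample; the object names are made up), the following compares the two interval types that predict() offers for an lm fit. A prediction interval accounts for the residual noise of a new observation in addition to the uncertainty in the fitted mean, so it is wider than the corresponding confidence interval.

```r
# Sketch on the built-in 'cars' data (not the vignette's data)
fit <- lm(dist ~ speed, data = cars)
new_speed <- data.frame(speed = c(10, 15, 20))

# Interval for the mean response at each speed
conf_int <- predict(fit, newdata = new_speed, interval = 'confidence')
# Interval for a single future observation at each speed
pred_int <- predict(fit, newdata = new_speed, interval = 'prediction')

# Both share the same point predictions; only the widths differ
width_conf <- conf_int[, 'upr'] - conf_int[, 'lwr']
width_pred <- pred_int[, 'upr'] - pred_int[, 'lwr']
cbind(speed = new_speed$speed, width_conf, width_pred)
```

This difference is worth keeping in mind when the linear model's interval lengths are set side by side with intervals derived from the random forest prediction variance.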
Also add the measured BMI, the prediction deviance, confidence interval length, and some of the covariates for plotting purposes.
# d$bmi (one value per subject) is recycled across the three stacked sets of predictions
preds$measured <- d[ ,'bmi']
preds$deviance <- preds$pred - preds$measured
preds$CI_length <- (preds$u.ci - preds$l.ci)
preds$age <- rep(d$age,3)
To compare the three model predictions, we can plot the predicted BMI according to the measured BMI for each model type, including the prediction confidence intervals.
ggplot(preds,aes(measured,pred,color=tree)) +
  geom_point(size=0.5) +
  geom_abline(intercept=0,slope=1,lty=2,color='#999999') +
  geom_errorbar(aes(ymin=l.ci,ymax=u.ci)) +
  xlab('Actual') + ylab('Predicted') +
  theme_bw() +
  facet_grid(~tree)
To compare the model performances, calculate the root mean squared error.
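The calculation can be sketched as follows. This is a minimal illustration rather than the vignette's original chunk: the helper name rmse_by_group and the toy numbers are made up. With the preds data.frame built above, the same helper could be called as rmse_by_group(preds$deviance, preds$tree).

```r
# Root mean squared error per model, given prediction errors and a model label
# (hypothetical helper, base R only)
rmse_by_group <- function(error, group) {
  tapply(error, group, function(e) sqrt(mean(e^2)))
}

# Made-up errors for two models, for illustration only
toy <- data.frame(
  tree = rep(c('A', 'B'), each = 4),
  deviance = c(1, -1, 1, -1, 3, -3, 3, -3)
)

toy_rmse <- rmse_by_group(toy$deviance, toy$tree)
toy_rmse  # model A has RMSE 1, model B has RMSE 3
```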
As we expected, the random forest models perform better than the linear model. There doesn't seem to be a large difference between using CART or CI trees within the random forest. It is encouraging to see that the random forest is less confident about its less accurate predictions. To visualize this, we can plot the length of the confidence intervals against the prediction deviance. The same pattern does not hold with the linear model, where some of the highest uncertainty is associated with its most accurate predictions.
ggplot(preds,aes(deviance,CI_length,color=tree)) +
  geom_point() +
  theme_bw() +
  facet_wrap(~tree)
For a linear model, we can simply interpret the regression coefficient to estimate the effect of age on BMI.
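As a sketch of how such a coefficient can be read off a fitted model, the helper below (coef_summary is a made-up name) collects the estimate, standard error, test statistic, p-value, and 95% confidence interval for a single term. It is demonstrated on the built-in cars data so that it runs on its own; for the model fitted above, the analogous call would be coef_summary(LM, 'age').

```r
# Estimate, standard error, t value, p-value and 95% CI for one model term
# (hypothetical helper, base R only)
coef_summary <- function(model, term) {
  est <- summary(model)$coefficients[term, ]
  ci <- confint(model, parm = term, level = 0.95)
  c(est, lower = unname(ci[1, 1]), upper = unname(ci[1, 2]))
}

# Illustration on built-in data: stopping distance as a function of speed
fit <- lm(dist ~ speed, data = cars)
speed_coef <- coef_summary(fit, 'speed')
speed_coef
```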
We can see that age is indeed a significant predictor of BMI, with an estimated increase of 0.16 in BMI per one-year increase in age. However, we don't necessarily expect age to have an exactly linear relationship with BMI.
Since we do not have model coefficients for the random forest models, we can instead use the predictions to interpret the effect of age on BMI.
ggplot(preds,aes(age,pred,color=tree)) +
  geom_point(size=0.5) +
  geom_smooth() +
  geom_errorbar(aes(ymin=l.ci,ymax=u.ci)) +
  xlab('Age') + ylab('Predicted BMI') +
  theme_bw() +
  facet_grid(~tree)
Here, we can see that the relationship between age and BMI flattens out at about age 30 and may even become negative in later life (after age 60). Although the random forest predictions cannot be directly interpreted like a regression coefficient where "all other variables are held constant", because other variables also change with age, we can still see how our sample behaves with changing age. The linear model predictions do not reveal the same nuanced relationship between age and BMI.
To make the interpretations of the random forest more like the regression coefficients, we can predict BMI for a range of ages and hold the other covariates at fixed values. Make sure to define the classes of the objects in
newdat so that they match the classes of the data used to train the original random forest.
newdat <- expand.grid('age'=2:85,
                      'race'=factor(1,levels=1:5),
                      'poverty_income_ratio'=median(d$poverty_income_ratio),
                      'edu'=factor(5,levels=1:9))
class(newdat$age) <- c('labelled','integer')
class(newdat$poverty_income_ratio) <- c('labelled','numeric')
newdat.preds <- rfPredVar(rf,rf.data=d,pred.data=newdat,
                          CI=TRUE,tree.type='ci',prog.bar=FALSE)
newdat.preds.plot <- cbind(newdat,newdat.preds)
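One way to check that the classes line up before predicting is to compare them column by column. The helper below is a sketch with a made-up name (class_mismatches) and toy data frames; with the objects above it could be called as class_mismatches(d, newdat). It only compares class vectors and does not check that factor levels agree.

```r
# Return the names of shared columns whose classes differ between two data frames
# (hypothetical helper, base R only)
class_mismatches <- function(train, new) {
  shared <- intersect(names(train), names(new))
  same <- vapply(shared,
                 function(v) identical(class(train[[v]]), class(new[[v]])),
                 logical(1))
  shared[!same]
}

# Toy training data and two candidate prediction data frames
train <- data.frame(age = c(20L, 35L, 50L),
                    race = factor(c(1, 2, 1), levels = 1:5))
new_ok <- data.frame(age = 30L, race = factor(1, levels = 1:5))
# numeric instead of integer, character instead of factor
new_bad <- data.frame(age = 30, race = '1', stringsAsFactors = FALSE)

class_mismatches(train, new_ok)   # no columns flagged
class_mismatches(train, new_bad)  # both 'age' and 'race' flagged
```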
Now, we can plot the predictions for different values of age at fixed values of race, education, and income-to-poverty ratio. A locally smoothed polynomial regression fit is overlaid for interpretation purposes. Although this prediction occurs at fixed values of the other covariates, it is still more insightful than a regression coefficient that assumes a linear relationship. Furthermore, we did not have to explore the higher-order terms or interactions, or make the multiple model comparisons, that would be required to uncover this relationship using a linear model.
ggplot(newdat.preds.plot,aes(age,pred)) +
  geom_point(size=0.5) +
  geom_smooth() +
  geom_errorbar(aes(ymin=l.ci,ymax=u.ci)) +
  xlab('Age') + ylab('Predicted BMI') +
  theme_bw()