PlotForestPred: Prediction plot of an ensemble forest for a grid of two...

View source: R/PlotForestPred.R

PlotForestPred    R Documentation

Prediction plot of an ensemble forest for a grid of two variables

Description

Plots the predictions of an ensemble forest over a grid of two variables of interest, holding all other variables fixed at their means in the coordinating site data.
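To make the idea of a two-variable prediction grid concrete, the following is a minimal base R sketch, not the package's actual implementation. The names 'toy_df' and 'pred_grid' and the simulated covariates are hypothetical; only 'var1', 'var2' and 'grids' mirror arguments of this function.

```r
# Hypothetical toy data standing in for the coordinating site data.
set.seed(1234)
toy_df <- data.frame(X1 = rnorm(50), X2 = rnorm(50), X3 = rnorm(50))

var1 <- "X1"   # variable for the x-axis
var2 <- "X2"   # variable for the y-axis
grids <- 100   # number of points per one-dimensional grid

# One-dimensional grids spanning the observed range of each variable.
g1 <- seq(min(toy_df[[var1]]), max(toy_df[[var1]]), length.out = grids)
g2 <- seq(min(toy_df[[var2]]), max(toy_df[[var2]]), length.out = grids)

# All grids x grids combinations of the two variables of interest.
pred_grid <- expand.grid(g1, g2)
names(pred_grid) <- c(var1, var2)

# Hold every remaining covariate fixed at its mean in the toy data.
others <- setdiff(names(toy_df), c(var1, var2))
for (v in others) pred_grid[[v]] <- mean(toy_df[[v]])

dim(pred_grid)
```

A fitted model would then be evaluated on each row of 'pred_grid', and the predictions drawn as a heat map over the two varying variables.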

Usage

PlotForestPred(
  aug_df,
  coord_df,
  coord_id,
  myfit,
  site,
  covars,
  var1,
  var2,
  site_enc_tab = NULL,
  grids = 100,
  seed = 1234,
  midpoint = 0,
  plt_title = ""
)

Arguments

aug_df

The augmented data frame used to fit an ensemble forest ('data.table').

coord_df

The coordinating site data ('data.table').

coord_id

Site index of the coordinating site.

myfit

A fitted ensemble forest.

site

Variable name for site indicator.

covars

A vector of covariate names used.

var1

A character string giving the name of a numerical predictor to be plotted on the x-axis. If 'var1' is the site indicator, the sites are ordered by their average effect in the augmented data for better visualization.

var2

A character string giving the name of a numerical predictor to be plotted on the y-axis. If 'var2' is the site indicator, the sites are ordered by their average effect in the augmented data for better visualization.

site_enc_tab

A data.table of the mean outcome for each site. Default is NULL. If 'myfit' is of class 'grf', 'site_enc_tab' must not be NULL.

grids

The number of points in the one-dimensional grid along each of the x- and y-axes. Default is 100.

seed

A random seed for reproducing the figure. Default is 1234.

midpoint

The midpoint (in data value) of the diverging scale. Default is 0.

plt_title

Title of the plot. Default is "".
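The site ordering mentioned under 'var1' and 'var2' can be illustrated with a small base R sketch. It assumes ascending order of the per-site mean; the direction used by the package is not stated here. The data frame 'toy_aug' and the column 'tau_hat' are hypothetical stand-ins for the augmented data and its effect column.

```r
# Hypothetical augmented data: 'site' and 'tau_hat' are illustrative names.
toy_aug <- data.frame(
  site    = rep(1:4, each = 3),
  tau_hat = c(0.9, 1.1, 1.0,
              -0.2, 0.0, 0.2,
              2.0, 2.1, 1.9,
              0.4, 0.5, 0.6)
)

# Average effect per site.
site_means <- tapply(toy_aug$tau_hat, toy_aug$site, mean)

# Site labels sorted by average effect (ascending is an assumption).
site_order <- names(sort(site_means))

# Re-level the site factor so an axis follows the effect ordering.
toy_aug$site_ord <- factor(toy_aug$site, levels = site_order)
levels(toy_aug$site_ord)
```

Ordering the factor levels this way places sites with similar average effects next to each other on the axis, which tends to make gradients in the heat map easier to read.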

Value

A ggplot2 object.
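As an illustration of what a diverging scale with a midpoint looks like in ggplot2, here is a minimal sketch on made-up data. Whether PlotForestPred uses these exact layers internally is an assumption; 'toy' and 'pred' are hypothetical names, while 'midpoint' and 'plt_title' mirror the arguments above.

```r
library(ggplot2)

# Hypothetical predictions on a small 20 x 20 grid.
toy <- expand.grid(x = seq(-1, 1, length.out = 20),
                   y = seq(-1, 1, length.out = 20))
toy$pred <- toy$x + toy$y

midpoint  <- 0    # data value where the diverging scale changes colour
plt_title <- ""   # empty title

# Heat map with a diverging fill scale centred at 'midpoint'.
p <- ggplot(toy, aes(x = x, y = y, fill = pred)) +
  geom_raster() +
  scale_fill_gradient2(low = "blue", mid = "white", high = "red",
                       midpoint = midpoint) +
  ggtitle(plt_title)
```

Because the result is an ordinary ggplot2 object, further layers, themes or labels can be added with '+', and the figure can be written to disk with ggplot2::ggsave().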

Examples

data(SimDataLst)
K <- length(SimDataLst)
covars <- grep("^X", names(SimDataLst[[1]]), value=TRUE)
## Fit a causal forest on each site's data
fit_lst <- vector("list", K)
for (k in seq_len(K)) {
    tmpdf <- SimDataLst[[k]]
    fit_lst[[k]] <- grf::causal_forest(X=as.matrix(tmpdf[, covars, with=FALSE]),
                                       Y=tmpdf$Y, W=tmpdf$Z)
}
coord_id <- 1
coord_df <- SimDataLst[[coord_id]]
aug_df <- GenAugData(coord_id, coord_df, fit_lst, covars)

## Treat each site as a distinct factor
myfit <- EnsemForest(coord_id, aug_df, "site", covars)$myfit
PlotForestPred(aug_df, coord_df, coord_id, myfit, "site", covars, "site", "X1")
PlotForestPred(aug_df, coord_df, coord_id, myfit, "site", covars, "X1", "site")
PlotForestPred(aug_df, coord_df, coord_id, myfit, "site", covars, "X1", "X2")

## Mean encoding as surrogate for site index
res_ef <- EnsemForest(coord_id, aug_df, "site", covars, is_encode=TRUE)
myfit <- res_ef$myfit
site_enc_tab <- res_ef$site_enc_tab
PlotForestPred(aug_df, coord_df, coord_id, myfit,
               "site", covars, "site", "X1", site_enc_tab)
PlotForestPred(aug_df, coord_df, coord_id, myfit,
               "site", covars, "X1", "site", site_enc_tab)
PlotForestPred(aug_df, coord_df, coord_id, myfit,
               "site", covars, "X1", "X2", site_enc_tab)



ellenxtan/ifedtree documentation built on March 28, 2023, 9:09 a.m.