View source: R/PlotForestPred.R
PlotForestPred | R Documentation |
Plots the predictions of an ensemble forest over two variables of interest, holding all other variables fixed at their means in the coordinating site data.
PlotForestPred(
  aug_df,
  coord_df,
  coord_id,
  myfit,
  site,
  covars,
  var1,
  var2,
  site_enc_tab = NULL,
  grids = 100,
  seed = 1234,
  midpoint = 0,
  plt_title = ""
)
aug_df |
The augmented data frame used to fit an ensemble forest ('data.table'). |
coord_df |
The coordinating site data ('data.table'). |
coord_id |
Site index for coordinating site. |
myfit |
A fitted ensemble forest. |
site |
Variable name for site indicator. |
covars |
A vector of covariate names used. |
var1 |
A character string with the name of a numerical predictor to plot on the X-axis. If 'var1' is the site indicator, the sites are ordered by their average effect in the augmented data for better visualization. |
var2 |
A character string with the name of a numerical predictor to plot on the Y-axis. If 'var2' is the site indicator, the sites are ordered by their average effect in the augmented data for better visualization. |
site_enc_tab |
A data.table of the mean outcome for each site. Default is NULL. If 'myfit' is of class 'grf', 'site_enc_tab' must not be NULL. |
grids |
The number of grid points along each of the X- and Y-axes. Default is 100. |
seed |
A random seed for reproducing the figure. Default is 1234. |
midpoint |
The midpoint (in data value) of the diverging scale. Default is 0. |
plt_title |
Title of the plot. Default is "". |
A ggplot2 object.
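The description says predictions are shown over two variables while the others are held at their means. The sketch below illustrates one way such a prediction grid could be built in base R. It is not the package's internal code: `make_pred_grid` and the toy data are hypothetical names introduced only for illustration, and the grid is assumed to span the observed range of each plotted variable.

```r
# Hypothetical helper (not part of the package): build a prediction grid
# over two numeric covariates, fixing the remaining covariates at their means.
make_pred_grid <- function(df, covars, var1, var2, grids = 100) {
  stopifnot(all(c(var1, var2) %in% covars), var1 != var2)
  # Evenly spaced points across the observed range of each plotted variable
  g1 <- seq(min(df[[var1]]), max(df[[var1]]), length.out = grids)
  g2 <- seq(min(df[[var2]]), max(df[[var2]]), length.out = grids)
  grid <- expand.grid(g1, g2)
  names(grid) <- c(var1, var2)
  # Hold every other covariate at its sample mean
  for (v in setdiff(covars, c(var1, var2))) {
    grid[[v]] <- mean(df[[v]])
  }
  # Return columns in the same order as `covars`
  grid[, covars, drop = FALSE]
}

# Toy data standing in for the coordinating site data (values are simulated)
set.seed(1234)
toy <- data.frame(X1 = rnorm(50), X2 = rnorm(50), X3 = rnorm(50))
pred_grid <- make_pred_grid(toy, c("X1", "X2", "X3"), "X1", "X2", grids = 10)
dim(pred_grid)  # 10 x 10 grid points, 3 covariate columns
```

Each row of `pred_grid` could then be passed to a fitted model's prediction method, with the result drawn as a surface over the two plotted variables.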
data(SimDataLst)
K <- length(SimDataLst)
covars <- grep("^X", names(SimDataLst[[1]]), value=TRUE)
fit_lst <- list()
for (k in 1:K) {
  tmpdf <- SimDataLst[[k]]
  fit_lst[[k]] <- grf::causal_forest(X=as.matrix(tmpdf[, covars, with=FALSE]),
                                     Y=tmpdf$Y, W=tmpdf$Z)
}
coord_id <- 1
coord_df <- SimDataLst[[coord_id]]
aug_df <- GenAugData(coord_id, coord_df, fit_lst, covars)
## Treat each site as a distinct factor
myfit <- EnsemForest(coord_id, aug_df, "site", covars)$myfit
PlotForestPred(aug_df, coord_df, coord_id, myfit, "site", covars, "site", "X1")
PlotForestPred(aug_df, coord_df, coord_id, myfit, "site", covars, "X1", "site")
PlotForestPred(aug_df, coord_df, coord_id, myfit, "site", covars, "X1", "X2")
## Mean encoding as surrogate for site index
res_ef <- EnsemForest(coord_id, aug_df, "site", covars, is_encode=TRUE)
myfit <- res_ef$myfit
site_enc_tab <- res_ef$site_enc_tab
PlotForestPred(aug_df, coord_df, coord_id, myfit,
               "site", covars, "site", "X1", site_enc_tab)
PlotForestPred(aug_df, coord_df, coord_id, myfit,
               "site", covars, "X1", "site", site_enc_tab)
PlotForestPred(aug_df, coord_df, coord_id, myfit,
               "site", covars, "X1", "X2", site_enc_tab)
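To show what the `midpoint` argument refers to, here is a minimal ggplot2 sketch of a heat map with a diverging colour scale centred on a chosen value. It is an assumption about how such a figure can be drawn, not the function's actual implementation: the surface `surf`, its `pred` column, and the colour choices are made up for illustration.

```r
library(ggplot2)

# Toy surface standing in for forest predictions (values are made up)
surf <- expand.grid(X1 = seq(-2, 2, length.out = 50),
                    X2 = seq(-2, 2, length.out = 50))
surf$pred <- surf$X1 - 0.5 * surf$X2

p <- ggplot(surf, aes(x = X1, y = X2, fill = pred)) +
  geom_raster() +
  # Diverging palette: values below `midpoint` shade towards `low`,
  # values above it towards `high`
  scale_fill_gradient2(low = "blue", mid = "white", high = "red",
                       midpoint = 0) +
  labs(title = "", fill = "Prediction")
```

Printing `p` draws the figure. Moving `midpoint` away from 0 shifts the value that maps to the neutral colour, which can help when predictions are not centred on zero.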