#' Fit random forest metamodel
#' @description This function fits a random forest metamodel using \code{\link[randomForestSRC:rfsrc]{randomForestSRC::rfsrc()}}.
#' @param df a dataframe.
#' @param y_var character. Name of the output variable in the dataframe. This will be the dependent variable of the metamodel.
#' @param x_vars character or a vector of characters. Name(s) of the input variable(s) in the dataframe. These will be the independent variables of the metamodel.
#' @param seed_num numeric. Determine which seed number to use when splitting the dataframe into fitting and validation sets.
#' @param tune logical. Determine whether nodesize and mtry should be tuned. Nodesize is the minimum size of terminal nodes, mtry is number of variables to possibly split at each node. If FALSE, nodesize = 15 (for regression), and mtry = number of x-variables / 3 (for regression). Default is FALSE.
#' @param var_importance logical or character. Determine whether to compute variable importance (TRUE/FALSE), or how to compute variable importance (permute/random/anti). Default is FALSE. TRUE corresponds to "anti".
#' @param pm_plot logical or character. Determine whether to plot the partial ("partial") or marginal ("marginal") effect or both ("both") of an x-variable (which is denoted by pm_vars). Default is FALSE. TRUE corresponds to "both".
#' @param pm_vars character. Name of the input variable(s) for the partial/marginal plot. Default is the first variable from the x_vars.
#' @param validation logical or character. Determine whether to validate the RF model. Choices are "train_test_split" and "cross_validation". TRUE corresponds to "cross_validation", default is FALSE.
#' @param folds numeric. Number of folds for the cross-validation. Default is 5.
#' @param show_intercept logical. Determine whether to extend the axes of the calibration plot to the origin (x = 0, y = 0), so the intercept of the perfect prediction line is visible. Default is FALSE.
#' @param partition numeric. Value between 0 and 1 to determine the proportion of the observations to use to fit the metamodel. Default is 1 (fitting the metamodel using all observations).
#' @param fit_complete_model logical. Determine whether to fit the (final) full model, i.e. the model trained on all available data (as opposed to the validation models, which are trained on the training subsets only). Default is TRUE.
#' @return A list containing the following elements:
#' \itemize{
#' \item fit: a list, see \code{\link[randomForestSRC:rfsrc]{randomForestSRC::rfsrc()}} for a description of the outputs contained in this list.
#' \item model_info: a list containing the following elements:
#' \itemize{
#' \item x_vars: vector of names of parameters included in the metamodel;
#' \item y_var: name of the outcome variable;
#' \item form: formula of the metamodel based on `x_vars` and `y_var`;
#' \item data: dataframe containing the input and output values used to fit (and validate) the metamodel;
#' \item type: "rf" for "random forest".}
#' \item (if `tune` = TRUE) tune_fit: a list containing the results of the tuning process, see \code{\link[randomForestSRC:tune]{randomForestSRC::tune()}} for a description of the elements contained in this list.
#' \item (if `tune` = TRUE) tune_plot: plot showing the out-of-bag error for each tested combination of 'mtry' and 'nodesize'.
#' \item (if `validation` != FALSE) stats_validation: data frame containing the R-squared, mean absolute error, mean relative error, and mean squared error in the validation set(s).
#' \item (if `validation` = "train_test_split") calibration_plot: plot showing the rf-predicted versus observed output values in the test validation set.}
#' If `var_importance` is set to TRUE, the variable importance plot is printed in the console.
#' If `pm_plot` is used, the marginal/partial plot(s) - drawn using \code{\link[randomForestSRC:plot.variable.rfsrc]{randomForestSRC::plot.variable.rfsrc()}} - is (are) printed in the console.
#' @examples
#' # Fitting and tuning a random forest metamodel with two variables using the example data
#' data(df_pa)
#' fit_rf_metamodel(df = df_pa,
#' y_var = "inc_qaly",
#' x_vars = c("p_pfsd", "p_pdd"),
#' tune = TRUE
#' )
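#'
#' # Cross-validating the metamodel (a sketch, using the same example data);
#' # the validation statistics are returned in 'stats_validation'
#' l_fit_cv <- fit_rf_metamodel(df = df_pa,
#' y_var = "inc_qaly",
#' x_vars = c("p_pfsd", "p_pdd"),
#' validation = "cross_validation",
#' folds = 5
#' )
#' l_fit_cv$stats_validation
#'
#' # Validating through a train/test split; 'partition' must then be below 1
#' l_fit_tt <- fit_rf_metamodel(df = df_pa,
#' y_var = "inc_qaly",
#' x_vars = c("p_pfsd", "p_pdd"),
#' validation = "train_test_split",
#' partition = 0.8
#' )
#' l_fit_tt$stats_validation
#' l_fit_tt$calibration_plot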
#' @import interp
#' @import ggplot2
#' @import randomForestSRC
#' @export
fit_rf_metamodel <- function(df,
y_var = NULL,
x_vars = NULL,
seed_num = 1,
tune = FALSE,
var_importance = FALSE, # options: FALSE, TRUE (= "anti"), "permute", "random"
pm_plot = FALSE,
pm_vars = x_vars[1],
validation = FALSE, # options: FALSE, TRUE (= "cross_validation"), "cross_validation", "train_test_split"
folds = 5, # non-integer values are rounded down to the nearest whole number
show_intercept = FALSE,
partition = 1,
fit_complete_model = TRUE
){
# Flag errors
if(length(y_var) > 1) {
stop("Multiple outcomes provided to 'y'.")
}
if(is.null(y_var)) {
stop("Cannot perform random forest regression because there is no value provided for 'y_var'.")
}
if(is.null(x_vars)) {
stop("Cannot perform random forest regression because there is no value provided for the predictors.")
}
if(length(var_importance)>1 || !all(var_importance %in% c(TRUE, FALSE, "anti", "permute", "random"))) {
stop("'var_importance' should be one of: TRUE, FALSE, 'anti','permute','random'.")
}
if(length(pm_plot)>1 || !(pm_plot %in% c(TRUE, FALSE, "partial", "marginal", "both"))) {
stop("'pm_plot' should be one of: TRUE, FALSE, 'partial','marginal','both'.")
}
if(!all(pm_vars %in% x_vars)) {
stop("Cannot produce the partial/marginal plot because at least one of the 'pm_vars' is not in 'x_vars'.")
}
if(partition <= 0 || partition > 1) {
stop("Proportion selected for training the metamodel should be between 0 (excluded) and 1 (included).")
}
if(partition == 1 && validation == "train_test_split") {
stop("Cannot perform validation because all observations are included in the training set. Lower `partition` below 1.")
}
if(isTRUE(validation)) {
validation = "cross_validation"
}
if(!(validation %in% c(FALSE,"cross_validation","train_test_split"))) {
stop("Validation must be one of: FALSE, TRUE, 'cross_validation','train_test_split'.")
}
if(validation == "cross_validation" && (folds < 2 || folds > nrow(df))){
stop("Folds must be larger than 1 and smaller than or equal to the number of rows of the dataframe.")
}
# Remove any possible NA's
df = na.omit(df)
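# Build the model formula from y_var and x_vars; e.g. y_var = "inc_qaly" with
# x_vars = c("p_pfsd", "p_pdd") yields inc_qaly ~ p_pfsd + p_pdd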
v_x <- paste(unique(x_vars), collapse = " + ")
form <- as.formula(paste(y_var, "~", v_x))
# Set up
set.seed(seed_num)
l_out = list(stats_validation = NULL,
calibration_plot = NULL,
fit = NULL,
tune_fit = NULL,
tune_plot = NULL,
model_info = list(x_vars = x_vars,
y_var = y_var,
form = form,
data = df,
type = "rf")
)
# Set default mtry and nodesize
nodesize = NULL
mtry = NULL
# Fit rf model if specified
if (fit_complete_model == TRUE){
# Tune mtry & nodesize
if (tune == TRUE){
rf_tune <- tune(form,
data = df,
splitrule = "mse")
nodesize = rf_tune$optimal[[1]] # optimal nodesize
mtry = rf_tune$optimal[[2]] # optimal mtry
## Tune plot
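### rf_tune$results holds one row per tried combination: nodesize (col 1), mtry (col 2), out-of-bag error (col 3)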
x <- rf_tune$results[,1]
y <- rf_tune$results[,2]
z <- rf_tune$results[,3]
### If there is only one value tried for 'mtry' (=y) (so no 2D interpolation, but 1D over 'nodesize' (=x))
if (length(unique(y))==1){
xi=seq(min(x),max(x),0.1)
# interp1() is assumed to come from pracma; it is not provided by the interp package
z_interp = pracma::interp1(x=unique(x),y=z,xi=xi,method="linear")
idx <- which.min(z)
x0 <- x[idx]
z0 = z[idx]
df_interp = data.frame(xi,z_interp)
tune_plot = ggplot() + geom_line(aes(x=xi,y=z_interp,color=z_interp),data=df_interp,linewidth=1) +
scale_colour_gradient2(low = "blue", mid = "yellow", high = "red",
midpoint=median(df_interp$z_interp),name="OOB error") +
geom_point(aes(x=x,y=z)) +
geom_point(aes(x=x0,y=z0), colour="black",size=8,pch="x") +
xlab("Nodesize") +
ylab("OOB error") +
ggtitle("Error rate for nodesize") +
theme_bw() +
theme(plot.title = element_text(hjust = 0.5))
}
else {
so <- interp(x=x, y=y, z=z, method = "linear",output = "grid")
idx <- which.min(z)
x0 <- x[idx]
y0 <- y[idx]
so_v = c(so$z)
xy_grid = expand.grid(so$x,so$y)
df_interp = data.frame(xy_grid,so_v)
colnames(df_interp) = c("nodesize","mtry","error")
tune_plot = ggplot() + geom_raster(aes(nodesize,mtry,fill=error),df_interp, interpolate = TRUE) +
scale_fill_gradientn(colours=c("yellow","red"),na.value="white") +
geom_point(aes(x=x0,y=y0), colour="black",size=8,pch="x") +
geom_point(aes(x=x,y=y)) +
scale_y_continuous(expand = expansion(mult = .5)) +
guides(fill = guide_colourbar(title="OOB error")) +
xlab("Nodesize") +
ylab("Mtry") +
ggtitle('Error rate for nodesize and mtry') +
theme_bw() +
theme(plot.title = element_text(hjust = 0.5))
}
## return
l_out[["tune_fit"]] = rf_tune
l_out[["tune_plot"]] = tune_plot
}
else {
l_out[c("tune_fit","tune_plot")] = NULL
}
# Fit random forest model with tuned parameters
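# (if tune = FALSE, nodesize and mtry remain NULL and rfsrc() falls back to its
# defaults: nodesize = 15 and mtry = number of x-variables / 3 for regression)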
rf_fit = rfsrc(form,
data = df,
splitrule = "mse",
nodesize = nodesize,
mtry = mtry,
forest = TRUE,
importance = var_importance
)
l_out[["fit"]] = rf_fit
# Show plots
## variable importance plot
if (var_importance != FALSE){
plot(rf_fit, verbose = FALSE, plots.one.page = FALSE)
}
## partial and/or marginal plot
if (pm_plot != FALSE){
if (pm_plot %in% c(TRUE, "both", "partial")) {
plot.variable.rfsrc(rf_fit, xvar.names = pm_vars, partial = TRUE, show.plots = TRUE, sort = TRUE, plots.per.page = 1, main = "Partial dependence plot", ylab = rf_fit$yvar.names)
}
if (pm_plot %in% c(TRUE, "both", "marginal")) {
plot.variable.rfsrc(rf_fit, xvar.names = pm_vars, partial = FALSE, show.plots = TRUE, sort = TRUE, plots.per.page = 1, main = "Marginal dependence plot", ylab = rf_fit$yvar.names)
}
}
}
else {
l_out[["fit"]] = NULL
}
# Validation
if (validation == "cross_validation"){
## Re-sample the data and make folds
df_validation = df[sample(nrow(df)),]
folds_ind = cut(seq(1,nrow(df_validation)),breaks=folds,labels=FALSE)
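# per-fold performance metrics, averaged over the folds after the loop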
r_squared_validation = rep(NA,folds)
mae_validation = rep(NA,folds)
mre_validation = rep(NA,folds)
mse_validation = rep(NA,folds)
for (i in 1:folds){
test_indices = which(folds_ind == i)
df_test = df_validation[test_indices,]
df_train = df_validation[-test_indices,]
## Tune
if (tune == TRUE){
rf_tune_validation <- tune(form,
data = df_train,
splitrule = "mse")
nodesize_validation = rf_tune_validation$optimal[[1]]
mtry_validation = rf_tune_validation$optimal[[2]]
}
else {
nodesize_validation = NULL
mtry_validation = NULL
}
## Fit with tuned (if specified) or default parameters
rf_fit_validation = rfsrc(form,
data = df_train,
splitrule = "mse",
nodesize = nodesize_validation,
mtry = mtry_validation,
forest = TRUE
)
## Test on test data
preds = predict(rf_fit_validation, newdata = df_test)$predicted
tests = df_test[, y_var]
## Store performance metrics
r_squared_validation[i] = cor(preds, tests) ^ 2
mae_validation[i] = mean(abs(preds - tests))
mre_validation[i] = mean(abs(preds - tests)/abs(tests))
mse_validation[i] = mean((preds - tests) ^ 2)
}
## Store results
stats_validation = data.frame(
Statistic = c("R-squared", "Mean absolute error", "Mean relative error","Mean squared error"),
Value = round(c(mean(r_squared_validation),mean(mae_validation),mean(mre_validation),mean(mse_validation)),3)
)
names(stats_validation)[names(stats_validation) == "Value"] <- "Value (method: cross-validation)"
l_out[["stats_validation"]] = stats_validation
l_out[["calibration_plot"]] = NULL
}
else if (validation == "train_test_split"){
## Partition data and fit to train data
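## 'partition' gives the share of rows sampled for training; the remaining rows form the test set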
selection = sample(1:nrow(df), size = round(nrow(df) * partition), replace = FALSE)
df_train = df[selection, ]
df_test = df[-selection, ]
## Tune
if (tune == TRUE){
rf_tune_validation <- tune(form,
data = df_train,
splitrule = "mse")
nodesize_validation = rf_tune_validation$optimal[[1]]
mtry_validation = rf_tune_validation$optimal[[2]]
}
else {
nodesize_validation = NULL
mtry_validation = NULL
}
## Fit with tuned (if specified) or default parameters
rf_fit_validation = rfsrc(form,
data = df_train,
splitrule = "mse",
nodesize = nodesize_validation,
mtry = mtry_validation,
forest = TRUE
)
## Test on test data
preds = predict(rf_fit_validation, newdata = df_test)$predicted
tests = df_test[,y_var]
r_squared_validation = cor(preds,tests)^2
mae_validation = mean(abs(preds-tests))
mre_validation = mean(abs(preds-tests)/abs(tests))
mse_validation = mean((preds-tests)^2)
## Calibration plot: predicted versus observed
df_test$y_pred = preds
df_test$y_obs = df_test[[y_var]] # observed values as a plain column, so a standard aes() mapping can be used
calibration_plot <- ggplot2::ggplot(data = df_test, ggplot2::aes(x = y_pred, y = y_obs)) +
ggplot2::geom_point(shape = 1) +
ggplot2::geom_abline(intercept = 0, slope = 1, colour = "orange") +
ggplot2::xlab("Predicted values") +
ggplot2::ylab("Observed values") +
ggplot2::ggtitle(paste("Calibration plot for", y_var)) +
ggplot2::theme_bw() +
ggplot2::theme(plot.title = ggplot2::element_text(hjust = 0.5))
if(show_intercept == TRUE) {
calibration_plot <- calibration_plot +
ggplot2::expand_limits(x = 0, y = 0) # extend axes to the origin of the perfect prediction line
}
stats_validation = data.frame(
Statistic = c("R-squared", "Mean absolute error", "Mean relative error", "Mean squared error"),
Value = round(c(r_squared_validation,mae_validation,mre_validation,mse_validation),3)
)
names(stats_validation)[names(stats_validation) == "Value"] <- "Value (method: train/test split)"
l_out[["stats_validation"]] = stats_validation
l_out[["calibration_plot"]] = calibration_plot
}
else {
l_out[c("stats_validation","calibration_plot")] = NULL
}
# Export
return(l_out)
}