#' Imputation of Missing Values by Automatic Tuned Chained Tree Ensembles
#'
#' @description Uses \code{randomForest} package to do
#' missing value imputation by automatic chained tree ensembles, see
#' [1, 2]. For every variable with missing values, a random forest is fitted
#' via \code{randomForest::tuneRF} (with \code{doBest = TRUE}) on the
#' remaining variables and used to predict the missing entries.
#' The iterative chaining stops as soon as \code{max_iter} is reached or if the
#' average out-of-bag estimate of performance stops improving.
#' In the latter case, the imputed data of the previous iteration is returned.
#'
#' @param x_miss A \code{data.frame} or \code{tibble} with missing values to
#' impute.
#' @param max_iter Maximum number of chaining iterations.
#' @param seed Positive integer seed to initialize the random generator, or
#' \code{NULL} (default) to leave the generator untouched.
#' @param num_trees Number of trees passed as \code{ntreeTry} to
#' \code{randomForest::tuneRF}.
#' @param verbose Boolean. TRUE (default) to print progress and the OOB
#' prediction error per iteration and variable (1 minus R-squared for
#' regression, misclassification rate for classification), FALSE to print
#' nothing.
#'
#' @return An object of class \code{missTune}: a list with \code{x_imp}, the
#' imputed data (from the last iteration, or from the previous one if the last
#' iteration increased the average OOB error), and \code{oob_list}, the named
#' per-variable OOB errors of every iteration that was run.
#'
#' @references
#' [1] Liaw, Andy, and Matthew Wiener. "Classification and regression by
#' randomForest." R news 2.3 (2002): 18-22.
#'
#' [2] Stekhoven, D.J. and Buehlmann, P. (2012). 'MissForest - nonparametric
#' missing value imputation for mixed-type data', Bioinformatics, 28(1) 2012,
#' 112-118, doi: 10.1093/bioinformatics/btr597
#'
#'
#' @export miss_tune
#'
#' @examples
#'
#' \dontrun{
#' iris_na <- generate_na(iris)
#' iris_imp <- miss_tune(iris_na)
#' head(iris_imp$x_imp)
#' head(iris_na)
#' }
miss_tune <- function(x_miss, max_iter = 10L, seed = NULL, num_trees = 200,
                      verbose = TRUE) {
  # check parameters ##########################################################
  assertthat::assert_that(is.data.frame(x_miss))
  assertthat::assert_that(assertthat::is.count(max_iter))
  # scalar condition: NULL is allowed, otherwise a single positive integer
  assertthat::assert_that(is.null(seed) || assertthat::is.count(seed),
    msg = "seed is not a positive integer"
  )
  assertthat::assert_that(assertthat::is.count(num_trees))
  assertthat::assert_that(assertthat::is.flag(verbose),
    msg = "verbose is not a boolean (TRUE/FALSE)"
  )

  # create data to impute #####################################################
  # variables that are factors or numeric and are not entirely NA
  all_vars <- imputable_variables(x_miss)
  assertthat::assert_that(
    length(all_vars) >= 2,
    msg = "At least two factor and or numeric variables are needed to run the algorithm"
  )
  # report ignored variables on a single line; paste() without collapse would
  # repeat the prefix once per variable and message() would glue them together
  if (verbose && length(all_vars) < ncol(x_miss)) {
    message(
      "Variables ignored in algorithm: ",
      paste(setdiff(names(x_miss), all_vars), collapse = ", ")
    )
  }
  # NA mask of the original data; it fixes which cells get (re)imputed
  x_na <- as.data.frame(is.na(x_miss[, all_vars, drop = FALSE]))
  # impute variables in increasing order of missingness, skipping complete ones
  na_share <- sort(colMeans(x_na))
  impute_vars <- names(na_share)[na_share > 0]
  assertthat::assert_that(
    length(impute_vars) > 0,
    msg = "No appropriate variable found for infilling"
  )

  # perform initial S.W.A.G. on x_miss ########################################
  if (!is.null(seed)) {
    set.seed(seed)
  }
  x_imp <- x_miss
  x_imp[impute_vars] <- lapply(impute_vars, function(j) {
    if (is.numeric(x_imp[[j]])) {
      imp_median(x_imp[[j]])
    } else if (is.factor(x_imp[[j]])) {
      imp_factor(x_imp[[j]])
    }
  })

  # iterate algorithm #########################################################
  i <- 0L
  oob <- vector("list", max_iter)
  oob_error <- rep(Inf, length(impute_vars))
  names(oob_error) <- impute_vars
  criterion <- FALSE
  repeat {
    if (verbose) {
      message(paste("iteration", i + 1, "in progress..."))
    }
    # keep the previous iteration so it can be restored if this one is worse
    oob_last <- oob_error
    x_old <- x_imp
    for (variable in impute_vars) {
      na_index <- x_na[[variable]]
      # NOTE(review): with stepFactor = 1 tuneRF's mtry search stops at once,
      # so mtry stays at tuneRF's default mtryStart -- confirm this is intended.
      # With doBest = TRUE the returned forest is refitted by randomForest()
      # without ntreeTry, i.e. with randomForest's default ntree.
      rf_tune <- randomForest::tuneRF(
        y = x_imp[[variable]],
        x = x_imp[setdiff(all_vars, variable)],
        stepFactor = 1, ntreeTry = num_trees,
        plot = FALSE, doBest = TRUE, trace = FALSE
      )
      # predict the cells that were missing in the original data
      pred <- stats::predict(rf_tune, x_imp[na_index, all_vars, drop = FALSE])
      x_imp[na_index, variable] <- pred
      # OOB error of the complete forest (its last tree). Indexing by
      # num_trees read an intermediate tree, or NA when num_trees exceeded the
      # forest size, which then broke the stopping check below.
      oob_error[[variable]] <- if (rf_tune$type == "regression") {
        # rf's OOB prediction error is 1 - R2
        1 - rf_tune$rsq[length(rf_tune$rsq)]
      } else {
        # rf's OOB prediction error is the ratio of misclassified samples
        rf_tune$err.rate[nrow(rf_tune$err.rate), "OOB"]
      }
      # a NaN error metric is treated as zero
      if (is.nan(oob_error[[variable]])) {
        oob_error[[variable]] <- 0
      }
    }
    if (verbose) {
      print(oob_error)
    }
    # stop at max_iter or as soon as the average OOB error gets worse
    i <- i + 1L
    oob[[i]] <- oob_error
    criterion <- mean(oob_error) > mean(oob_last)
    if (i == max_iter || criterion) {
      break
    }
  }
  # the last iteration was worse than the previous one: return the previous
  if (criterion) {
    x_imp <- x_old
  }
  res <- list(
    x_imp = x_imp,
    oob_list = oob[seq_len(i)]
  )
  class(res) <- "missTune"
  res
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.