#' Print Method
#'
#' Print method for an object of class "missRanger". It reports which
#' iteration was selected as best, together with that iteration's average
#' out-of-bag (OOB) imputation error.
#'
#' @param x An object of class "missRanger".
#' @param ... Further arguments passed from other methods.
#' @returns Invisibly, the input is returned.
#' @export
#' @examples
#' CO2_ <- generateNA(CO2, seed = 1)
#' imp <- missRanger(CO2_, pmm.k = 5, data_only = FALSE, num.threads = 1)
#' imp
print.missRanger <- function(x, ...) {
  best <- x$best_iter
  # Each label is written as "<label> <value> " on its own line by cat().
  fields <- list(
    "- best iteration:" = best,
    "- best average OOB imputation error:" = x$mean_pred_errors[best]
  )
  cat("missRanger object. Extract imputed data via $data\n")
  for (label in names(fields)) {
    cat(label, fields[[label]], "\n")
  }
  invisible(x)
}
#' Summary Method
#'
#' Summary method for an object of class "missRanger". It prints the object,
#' then the per-variable OOB prediction errors of every iteration, the mean
#' error per iteration, and the first three rows of the imputed data.
#'
#' @param object An object of class "missRanger".
#' @param ... Further arguments passed from other methods.
#' @returns Invisibly, the input is returned.
#' @export
#' @examples
#' CO2_ <- generateNA(CO2, seed = 1)
#' imp <- missRanger(CO2_, pmm.k = 5, data_only = FALSE, num.threads = 1)
#' summary(imp)
summary.missRanger <- function(object, ...) {
  print(object)
  # Each section is a (heading, value) pair; headings carry their own newlines.
  sections <- list(
    list("\nSequence of OOB prediction errors:\n\n", object$pred_errors),
    list("\nMean performance per iteration:\n", object$mean_pred_errors),
    list("\nFirst rows of imputed data:\n\n", utils::head(object$data, 3L))
  )
  for (section in sections) {
    cat(section[[1L]])
    print(section[[2L]])
  }
  invisible(object)
}
#' Predict Method
#'
#' @description
#' Impute missing values on `newdata` based on an object of class "missRanger".
#'
#' For multivariate imputation, use `missRanger(..., keep_forests = TRUE)`.
#' For univariate imputation, no forests are required.
#' This can be enforced by `predict(..., iter = 0)` or via `missRanger(. ~ 1, ...)`.
#'
#' Note that out-of-sample imputation works best for rows in `newdata` with only one
#' missing value (counting only missings in variables used as covariates
#' in random forests). We call this the "easy case". In the "hard case",
#' even multiple iterations (set by `iter`) can lead to unsatisfactory results.
#'
#' @details
#' The out-of-sample algorithm works as follows:
#' 1. Impute univariately all relevant columns by randomly drawing values
#' from the original unimputed data. This step will only impact "hard case" rows.
#' 2. Replace univariate imputations by predictions of random forests. This is done
#' sequentially over variables, where the variables are sorted to minimize the impact
#' of univariate imputations. Optionally, this is followed by predictive mean matching (PMM).
#' 3. Repeat Step 2 for "hard case" rows multiple times.
#'
#' @param object 'missRanger' object.
#' @param newdata A `data.frame` with missing values to impute.
#' @param pmm.k Number of candidate predictions of the original dataset
#' for predictive mean matching (PMM). By default the same value as during fitting.
#' @param iter Number of iterations for "hard case" rows. 0 for univariate imputation.
#' @param num.threads Number of threads used by ranger's predict function.
#' The default `NULL` uses all threads.
#' @param seed Integer seed used for initial univariate imputation and PMM.
#' If not `NULL`, it is passed to `set.seed()`, which changes the global
#' random number stream.
#' @param verbose Should info be printed? (1 = yes/default, 0 for no).
#' @param ... Passed to the predict function of ranger.
#' @returns `newdata` with missing values filled in for those variables that
#' were imputed during fitting. If there is nothing to impute, `newdata` is
#' returned unchanged.
#' @export
#' @examples
#' iris2 <- generateNA(iris, seed = 20, p = c(Sepal.Length = 0.2, Species = 0.1))
#' imp <- missRanger(iris2, pmm.k = 5, num.trees = 100, keep_forests = TRUE, seed = 2)
#' predict(imp, head(iris2), seed = 3)
predict.missRanger <- function(
  object,
  newdata,
  pmm.k = object$pmm.k,
  iter = 4L,
  num.threads = NULL,
  seed = NULL,
  verbose = 1L,
  ...
) {
  stopifnot(
    "'newdata' should be a data.frame!" = is.data.frame(newdata),
    "'newdata' should have at least one row!" = nrow(newdata) >= 1L,
    "'iter' should not be negative!" = iter >= 0L,
    "'pmm.k' should not be negative!" = pmm.k >= 0L
  )
  data_raw <- object$data_raw

  # WHICH VARIABLES TO IMPUTE?
  # (a) Only those in newdata
  to_impute <- intersect(object$to_impute, colnames(newdata))
  # (b) Only those with missings
  to_fill <- is.na(newdata[, to_impute, drop = FALSE])
  missing_counts <- colSums(to_fill)
  to_impute <- to_impute[missing_counts > 0L]
  to_fill <- to_fill[, to_impute, drop = FALSE]
  if (length(to_impute) == 0L) {
    return(newdata)
  }

  # CHECK VARIABLES USED TO IMPUTE
  impute_by <- object$impute_by
  if (!all(impute_by %in% colnames(newdata))) {
    stop(
      "Variables not present in 'newdata': ",
      paste(setdiff(impute_by, colnames(newdata)), collapse = ", ")
    )
  }
  # We currently don't do multivariate imputation if a covariate that is not
  # itself imputed has missing values. Only the offending columns are reported.
  only_impute_by <- setdiff(impute_by, to_impute)
  if (length(only_impute_by) > 0L) {
    has_na <- colSums(is.na(newdata[, only_impute_by, drop = FALSE])) > 0L
    if (any(has_na)) {
      stop(
        "Missing values in ", paste(only_impute_by[has_na], collapse = ", "),
        " not allowed."
      )
    }
  }

  # CONSISTENCY CHECKS WITH 'data_raw'
  for (v in union(to_impute, impute_by)) {
    v_new <- newdata[[v]]
    v_orig <- data_raw[[v]]
    if (all(is.na(v_new))) {
      next # NA of wrong class is fine!
    }
    # class() distinguishes numeric, integer, logical, factor, character, Date, ...
    # - variables in to_impute are numeric, integer, logical, factor, or character
    # - variables in impute_by can also be of *mode* numeric, which includes Dates
    if (!identical(class(v_new), class(v_orig))) {
      stop("Inconsistency between 'newdata' and original data in variable ", v)
    }
    # Factor inconsistencies are not okay in 'to_impute'
    if (
      v %in% to_impute && is.factor(v_new) && !identical(levels(v_new), levels(v_orig))
    ) {
      if (all(levels(v_new) %in% levels(v_orig))) {
        newdata[[v]] <- factor(v_new, levels(v_orig), ordered = is.ordered(v_orig))
        if (verbose >= 1L) {
          message("\nExtending factor levels of '", v, "' to those in original data")
        }
      } else {
        stop("New factor levels seen in variable to impute: ", v)
      }
    }
  }

  if (!is.null(seed)) {
    set.seed(seed)
  }

  # UNIVARIATE IMPUTATION
  for (v in to_impute) {
    bad <- to_fill[, v]
    v_orig <- data_raw[[v]]
    v_obs <- v_orig[!is.na(v_orig)]
    if (length(v_obs) == 0L) {
      stop("No non-missing values of '", v, "' in original data to draw from.")
    }
    # Draw indices rather than calling sample(v_obs, ...): for a length-one
    # numeric v_obs >= 1, sample() would draw from 1:v_obs instead of v_obs.
    # In all other cases this consumes the RNG exactly like sample().
    donors <- v_obs[sample.int(length(v_obs), size = sum(bad), replace = TRUE)]
    if (all(bad)) {
      # Handles e.g. case when original is factor, but newdata has all NA of numeric type
      newdata[[v]] <- donors
    } else {
      newdata[[v]][bad] <- donors
    }
  }
  if (length(impute_by) == 0L || iter == 0L) {
    if (verbose >= 1L) {
      message("\nOnly univariate imputations done")
    }
    return(newdata)
  }

  # MULTIVARIATE IMPUTATION
  if (is.null(object$forests)) {
    stop("No random forests in 'object'. Use missRanger(, keep_forests = TRUE).")
  }
  # Do we have a random forest for all variables with missings? If no, we don't repeat
  # its univariate imputation.
  forests_missing <- setdiff(to_impute, names(object$forests))
  if (length(forests_missing) > 0L) {
    if (verbose >= 1L) {
      message(
        "\nNo random forest for ", paste(forests_missing, collapse = ", "),
        ". Univariate imputation done for ",
        if (length(forests_missing) == 1L) "this variable." else "these variables."
      )
    }
    to_impute <- setdiff(to_impute, forests_missing)
  }

  # Do we have rows of "hard case"? If no, a single iteration is sufficient
  hard_cols <- intersect(to_impute, impute_by)
  hard_rows <- rowSums(to_fill[, hard_cols, drop = FALSE]) > 1L
  if (!any(hard_rows)) {
    iter <- 1L
  }

  # We first impute hard columns, then the rest.
  # Sorting hard columns is done in decreasing order of missings, counting only
  # rows of hard case. Sorting of the rest is irrelevant.
  # We ignore the special case where one forest is missing
  hard_counts <- colSums(to_fill[hard_rows, hard_cols, drop = FALSE])
  to_impute <- c(
    hard_cols[order(hard_counts, decreasing = TRUE)],
    setdiff(to_impute, hard_cols) # rest
  )

  for (j in seq_len(iter)) {
    for (v in to_impute) {
      rows <- to_fill[, v]
      # drop = FALSE keeps a data.frame even if 'newdata' has a single column
      pred <- stats::predict(
        object$forests[[v]],
        newdata[rows, , drop = FALSE],
        num.threads = num.threads,
        verbose = verbose >= 1L,
        ...
      )$predictions
      if (pmm.k >= 1) {
        xtrain <- object$forests[[v]]$predictions
        ytrain <- data_raw[[v]]
        if (anyNA(ytrain)) {
          ytrain <- ytrain[!is.na(ytrain)] # To align with OOB predictions
        }
        pred <- pmm(xtrain = xtrain, xtest = pred, ytrain = ytrain, k = pmm.k)
      } else if (is.logical(newdata[[v]])) {
        pred <- as.logical(pred)
      } else if (is.character(newdata[[v]])) {
        pred <- as.character(pred)
      }
      newdata[[v]][rows] <- pred
    }
    if (j == 1L && iter > 1L) {
      # hard_rows is recycled down each column of the logical matrix to_fill
      to_fill <- to_fill & hard_rows
      hard_counts <- colSums(to_fill[, to_impute, drop = FALSE])
      to_impute <- to_impute[hard_counts > 0L] # Need to fill only hard cases when j>1
    }
  }
  newdata
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.