#' Replace missing values using multivariate statistical model
#'
#' Fits a model on the rows where the response variable is observed and uses
#' it to predict (impute) the rows where the response is missing.
#'
#' @param data a data.frame.
#' @param formula an object of class "\code{\link[stats]{formula}}":
#' a symbolic description of the model to be fitted. Its left-hand side
#' names the variable whose missing values are replaced.
#' @param learnFun model-fitting function, called as
#' \code{learnFun(formula, data, \dots)}.
#' @param predictFun function used for making predictions, called as
#' \code{predictFun(object, newdata)}. It must return a plain vector with
#' one prediction per row of \code{newdata}.
#' @param family in \code{na_glm}, this is the \code{family} argument from
#' the \code{\link[stats]{glm}} method.
#' @param \dots further arguments passed to \code{learnFun} (in
#' \code{na_knn}, passed to \code{\link[class]{knn}}).
#'
#' @return \code{data} with the missing values of the response variable
#' replaced by the model's predictions. If the response has no missing
#' values, \code{data} is returned without fitting a model.
#'
#' @details
#'
#' Multiple convenience wrappers allow the user to impute with: linear
#' regression (\code{na_lm}), generalized linear models (\code{na_glm}),
#' recursive partitioning and regression trees (\code{na_rpart}), random
#' forests (\code{na_rf}) and additionally, for categorical data: naive
#' Bayes (\code{na_nb}) and k-nearest neighbour classifiers
#' (\code{na_knn}). Both \code{na_rpart} and \code{na_rf} can be used for
#' predicting continuous and categorical variables.
#'
#' A warning is issued when the predictors contain missing values. In that
#' case the model may fail to produce a prediction for every missing value
#' of the response, which is an error.
#'
#' @seealso \code{\link[stats]{lm}}, \code{\link[stats]{glm}}, \code{\link[rpart]{rpart}},
#' \code{\link[randomForest]{randomForest}}, \code{\link[e1071]{naiveBayes}},
#' \code{\link[class]{knn}}
#'
#' @examples
#'
#' set.seed(123)
#'
#' dat <- mtcars
#' dat$disp[sample.int(nrow(dat), 10)] <- NA
#' dat$gear[sample.int(nrow(dat), 10)] <- NA
#' dat$gear <- as.factor(dat$gear)
#'
#' na_predict(dat, disp ~ mpg + drat, learnFun = glm, predictFun = function(object, newdata) {
#' predict(object, newdata = newdata, type = "response") })
#' na_predict(dat, gear ~ mpg + drat, learnFun = e1071::naiveBayes)
#'
#' # continuous variables
#' na_lm(dat, disp ~ mpg + drat)
#' na_glm(dat, disp ~ mpg + drat)
#' na_rpart(dat, disp ~ mpg + drat)
#' na_rf(dat, disp ~ mpg + drat)
#'
#' # categorical variables
#' na_nb(dat, gear ~ mpg + drat)
#' na_knn(dat, gear ~ mpg + drat)
#' na_rpart(dat, factor(gear) ~ mpg + drat)
#' na_rf(dat, factor(gear) ~ mpg + drat)
#'
#' @importFrom rpart rpart
#' @importFrom class knn
#' @importFrom e1071 naiveBayes
#' @importFrom randomForest randomForest
#' @importFrom stats formula model.frame predict lm glm gaussian na.pass
#'
#' @export
na_predict <- function(data, formula, learnFun, predictFun = predict, ...) {
  y_var <- lhs_vars(formula, data)
  x_var <- rhs_vars(formula, data)

  # `[[` yields the column vector for both data.frames and tibbles, whereas
  # `data[, y_var]` returns a one-column tibble for the latter (is.na() of
  # which is a matrix and breaks the row subsetting below)
  nas <- is.na(data[[y_var]])
  data[, y_var] <- as_imputed(data[[y_var]])

  if (anyNA(data[, x_var])) {
    warning("predictors contain missing values")
  }

  # Nothing to impute: skip fitting, and never call predictFun on an empty
  # newdata, which some predict methods do not support
  if (!any(nas)) {
    return(data)
  }

  # fit on observed rows only; drop = FALSE keeps both subsets data.frames
  model <- do.call(
    learnFun,
    list(formula, data = data[!nas, , drop = FALSE], ...)
  )
  pred <- predictFun(model, newdata = data[nas, , drop = FALSE])

  if (!is_simple_vector(pred)) {
    stop("invalid format of predictions")
  }
  if (length(pred) != sum(nas)) {
    stop("model failed predict all the missing values")
  }

  data[nas, y_var] <- pred
  data
}
#' @rdname na_predict
#' @export
na_lm <- function(data, formula, ...) {
  # Ordinary least-squares imputation; predictions come from na_predict's
  # default predictFun (predict).
  na_predict(data = data, formula = formula, learnFun = lm, ...)
}
#' @rdname na_predict
#' @export
na_glm <- function(data, formula, family = gaussian, ...) {
  # predict() on a glm defaults to the link scale; request response-scale
  # values so the imputations are on the scale of the variable itself.
  predict_response <- function(object, newdata) {
    predict(object, newdata = newdata, type = "response")
  }
  na_predict(
    data, formula,
    learnFun = glm,
    predictFun = predict_response,
    family = family,
    ...
  )
}
#' @rdname na_predict
#' @export
na_rpart <- function(data, formula, ...) {
  # Classification trees must return class labels (type = "vector" would give
  # numeric class codes); other tree methods return their predicted
  # responses as a plain vector.
  predict_rpart <- function(object, newdata) {
    # object$method is a single string, so a scalar if/else is the right
    # tool here rather than the vectorised ifelse()
    type <- if (object$method == "class") "class" else "vector"
    predict(object, newdata = newdata, type = type)
  }
  na_predict(data, formula, learnFun = rpart, predictFun = predict_rpart, ...)
}
#' @rdname na_predict
#' @export
na_rf <- function(data, formula, ...) {
  # Random-forest imputation; predictions come from na_predict's default
  # predictFun (predict).
  na_predict(data = data, formula = formula, learnFun = randomForest, ...)
}
#' @rdname na_predict
#' @export
na_nb <- function(data, formula, ...) {
  # Naive Bayes imputation for categorical responses; predictions come from
  # na_predict's default predictFun (predict).
  na_predict(data = data, formula = formula, learnFun = naiveBayes, ...)
}
#' @rdname na_predict
#' @export
na_knn <- function(data, formula, ...) {
  # k-nearest-neighbour imputation. knn() has no fit/predict split, so this
  # does not go through na_predict; `...` is forwarded to knn() (e.g. `k`).

  # resolve the response the same way na_predict does
  y_var <- lhs_vars(formula, data)

  # na.pass keeps rows with missing values, so `mf` stays row-aligned with
  # `data` and the missing responses can be located and filled in
  mf <- model.frame(formula, data = data, na.action = stats::na.pass)
  y <- mf[, 1L]
  X <- mf[, -1L, drop = FALSE]
  nas <- is.na(y)
  data[, y_var] <- as_imputed(data[[y_var]])

  if (anyNA(X)) {
    warning("predictors contain missing values")
  }

  # nothing to impute
  if (!any(nas)) {
    return(data)
  }

  # drop = FALSE keeps single-predictor subsets as data.frames; a dropped
  # vector passed as `test` would be treated by knn() as a single row
  pred <- knn(
    X[!nas, , drop = FALSE],
    X[nas, , drop = FALSE],
    y[!nas],
    prob = FALSE,
    ...
  )

  if (!is_simple_vector(pred)) {
    stop("invalid format of predictions")
  }
  if (length(pred) != sum(nas)) {
    stop("model failed predict all the missing values")
  }

  data[nas, y_var] <- pred
  data
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.