R/predict.R

Defines functions na_predict na_lm na_glm na_rpart na_rf na_nb na_knn

Documented in na_glm na_knn na_lm na_nb na_predict na_rf na_rpart

#' Replace missing values using a multivariate statistical model
#'
#' Fits a model on the rows where the response variable is observed and uses
#' it to predict the response in the rows where it is missing.
#'
#' @param data       a data.frame.
#' @param formula    an object of class "\code{\link[stats]{formula}}":
#'                   a symbolic description of the model to be fitted; its
#'                   left-hand side is the variable to be imputed.
#' @param learnFun   learning function in form \code{learnFun(formula, data, \dots)}.
#' @param predictFun function used for making predictions in form
#'                   \code{predictFun(object, newdata)}; it has to return a
#'                   plain vector with one prediction per missing value.
#' @param family     in \code{na_glm}, this is the \code{family} argument from
#'                   the \code{\link[stats]{glm}} method.
#' @param \dots      further arguments passed to \code{learnFun} (to
#'                   \code{\link[class]{knn}} in \code{na_knn}).
#'
#' @return
#'
#' \code{data} with the missing values of the response variable replaced by
#' model predictions; the response column is converted with
#' \code{as_imputed}. When the response has no missing values no model is
#' fitted and \code{data} is returned with only that conversion applied.
#'
#' @details
#'
#' Convenience wrappers are available for linear regression (\code{na_lm}),
#' generalized linear models (\code{na_glm}), recursive partitioning and
#' regression trees (\code{na_rpart}), random forests (\code{na_rf}) and,
#' for categorical data only, naive Bayes (\code{na_nb}) and k-nearest
#' neighbour classifiers (\code{na_knn}). Both \code{na_rpart} and
#' \code{na_rf} can be used for predicting continuous and categorical
#' variables.
#'
#' The model is trained only on the rows where the response is observed.
#' A warning is issued when the predictors contain missing values.
#'
#' @seealso \code{\link[stats]{lm}}, \code{\link[stats]{glm}}, \code{\link[rpart]{rpart}},
#'          \code{\link[randomForest]{randomForest}}, \code{\link[e1071]{naiveBayes}},
#'          \code{\link[class]{knn}}
#'
#' @examples
#'
#' set.seed(123)
#'
#' dat <- mtcars
#' dat$disp[sample.int(nrow(dat), 10)] <- NA
#' dat$gear[sample.int(nrow(dat), 10)] <- NA
#' dat$gear <- as.factor(dat$gear)
#'
#' na_predict(dat, disp ~ mpg + drat, learnFun = glm, predictFun = function(object, newdata) {
#'            predict(object, newdata = newdata, type = "response") })
#' na_predict(dat, gear ~ mpg + drat, learnFun = e1071::naiveBayes)
#'
#' # continuous variables
#' na_lm(dat, disp ~ mpg + drat)
#' na_glm(dat, disp ~ mpg + drat)
#' na_rpart(dat, disp ~ mpg + drat)
#' na_rf(dat, disp ~ mpg + drat)
#'
#' # categorical variables
#' na_nb(dat, gear ~ mpg + drat)
#' na_knn(dat, gear ~ mpg + drat)
#' na_rpart(dat, factor(gear) ~ mpg + drat)
#' na_rf(dat, factor(gear) ~ mpg + drat)
#'
#' @importFrom rpart rpart
#' @importFrom class knn
#' @importFrom e1071 naiveBayes
#' @importFrom randomForest randomForest
#' @importFrom stats formula model.frame predict lm glm gaussian na.pass
#'
#' @export

na_predict <- function(data, formula, learnFun, predictFun = predict, ...) {

  y_var <- lhs_vars(formula, data)
  x_var <- rhs_vars(formula, data)

  # `[[` extracts the response as a plain vector for data.frames and tibbles
  # alike (`data[, y_var]` stays a one-column tibble for tibbles)
  nas <- is.na(data[[y_var]])
  data[[y_var]] <- as_imputed(data[[y_var]])

  # nothing to impute: do not fit a model only to predict zero rows, which
  # some learners' predict methods may not support
  if (!any(nas)) {
    return(data)
  }

  if (anyNA(data[, x_var, drop = FALSE])) {
    warning("predictors contain missing values")
  }

  # train on the observed rows, predict the missing ones; drop = FALSE keeps
  # single-column data as a data.frame for both learnFun and predictFun
  model <- do.call(learnFun,
                   list(formula, data = data[!nas, , drop = FALSE], ...))
  pred <- predictFun(model, newdata = data[nas, , drop = FALSE])

  if (!is_simple_vector(pred)) {
    stop("invalid format of predictions")
  }
  if (length(pred) != sum(nas)) {
    stop("model failed predict all the missing values")
  }

  data[nas, y_var] <- pred
  data
}

#' @rdname na_predict
#' @export

na_lm <- function(data, formula, ...) {
  # ordinary least squares via lm(); predictions come from the default
  # predictFun of na_predict
  na_predict(data = data, formula = formula, learnFun = lm, ...)
}

#' @rdname na_predict
#' @export

na_glm <- function(data, formula, family = gaussian, ...) {
  # type = "response" yields predictions on the scale of the response
  # variable rather than on the scale of the linear predictor
  response_scale <- function(object, newdata) {
    predict(object, newdata = newdata, type = "response")
  }
  na_predict(data = data, formula = formula, learnFun = glm,
             predictFun = response_scale, family = family, ...)
}

#' @rdname na_predict
#' @export

na_rpart <- function(data, formula, ...) {
  na_predict(data, formula, learnFun = rpart,
             predictFun = function(object, newdata) {
               # classification trees predict class labels, any other rpart
               # method predicts fitted values. The condition is a scalar,
               # so use `if` rather than ifelse(); identical() also gives a
               # plain FALSE when `method` is absent instead of logical(0).
               is_class <- identical(object$method, "class")
               predict(object, newdata,
                       type = if (is_class) "class" else "vector")
             }, ...)
}

#' @rdname na_predict
#' @export

na_rf <- function(data, formula, ...) {
  # random forest learner; usable for continuous and categorical responses,
  # predictions come from the default predictFun of na_predict
  na_predict(data = data, formula = formula, learnFun = randomForest, ...)
}

#' @rdname na_predict
#' @export

na_nb <- function(data, formula, ...) {
  # naive Bayes classifier, meant for categorical responses; predictions
  # come from the default predictFun of na_predict
  na_predict(data = data, formula = formula, learnFun = naiveBayes, ...)
}

#' @rdname na_predict
#' @export

na_knn <- function(data, formula, ...) {

  # resolve the response name the same way na_predict does
  y_var <- lhs_vars(formula, data)

  # na.pass keeps every row, so `nas` lines up with the rows of `data`
  mf <- model.frame(formula, data = data, na.action = na.pass)
  y <- mf[, 1L]
  X <- mf[, -1L, drop = FALSE]
  nas <- is.na(y)
  data[[y_var]] <- as_imputed(data[[y_var]])

  # nothing to impute: do not call knn() with an empty test set
  if (!any(nas)) {
    return(data)
  }

  if (anyNA(X)) {
    warning("predictors contain missing values")
  }

  # observed rows form the training set, rows with a missing response the
  # test set; drop = FALSE keeps both as data.frames with one predictor
  pred <- knn(X[!nas, , drop = FALSE], X[nas, , drop = FALSE], y[!nas],
              prob = FALSE, ...)

  if (!is_simple_vector(pred)) {
    stop("invalid format of predictions")
  }
  if (length(pred) != sum(nas)) {
    stop("model failed predict all the missing values")
  }

  data[nas, y_var] <- pred
  data
}
twolodzko/misster documentation built on May 24, 2019, 2:54 p.m.