# R/multiCV.R
#
# Defines functions: err, multiCV, summary.multiCV, print.multiCV
#
# Documented in: err, multiCV

#' Common error metrics
#' 
#' Computes one or more summary measures of the discrepancy between observed
#' and predicted values.
#' 
#' @param y,yhat vectors of true and predicted values.
#' @param wt     weights for observations, uniform by default. Only the
#'               weighted metrics (\code{wME}, \code{wMAE}, \code{wRMSE}) use
#'               them; they are applied as given and are not normalised.
#' @param method vector of names of error metrics to compute.
#' 
#' @return 
#' 
#' Named numeric vector with one element per entry of \code{method}, in the
#' same order. An unknown metric name raises an error that lists the
#' offending name(s). Missing values are not removed, so they propagate to
#' the result.
#' 
#' @details 
#' 
#' The following error metrics are implemented:
#' 
#' \tabular{ll}{
#' method \tab function \cr
#' ME     \tab \code{mean(y-yhat)} \cr
#' MAE    \tab \code{mean(abs(y-yhat))} \cr
#' RMSE   \tab \code{sqrt(mean((y-yhat)^2))} \cr
#' MAPE   \tab \code{mean(abs((y-yhat)/y))} \cr
#' MdAE   \tab \code{median(abs(y-yhat))} \cr
#' RMdSE  \tab \code{sqrt(median((y-yhat)^2))} \cr
#' MdAPE  \tab \code{median(abs((y-yhat)/y))} \cr
#' MRAE   \tab \code{mean(abs(y-yhat))/mean(abs(mean(y)-y))} \cr
#' MRRSE  \tab \code{mean((y-yhat)^2)/mean((mean(y)-y)^2)} \cr
#' wME    \tab \code{mean((y-yhat) * wt)} \cr
#' wMAE   \tab \code{mean(abs(y-yhat) * wt)} \cr
#' wRMSE  \tab \code{sqrt(mean((y-yhat)^2 * wt))}
#' }
#' 
#' @examples 
#' 
#' x <- rnorm(15)
#' err(x, mean(x))
#' 
#' @references 
#' 
#' Shcherbakov, M. V., Brebels, A., Shcherbakova, N. L.,
#' Tyukov, A. P., Janovsky, T. A., & Kamaev, V. A. E. (2013).
#' A survey of forecast error measures.
#' World Applied Sciences Journal, 24, 171-176.
#' 
#' @importFrom stats median
#' @export

err <- function(y, yhat, wt = 1, method = c("MAE", "RMSE", "MdAE")) {

  supported <- c("ME", "MAE", "RMSE", "MAPE", "MdAE", "RMdSE", "MdAPE",
                 "MRAE", "MRRSE", "wME", "wMAE", "wRMSE")

  # Reject unknown names before doing any work, and say which ones they are.
  method <- as.character(method)
  unknown <- setdiff(method, supported)
  if (length(unknown) > 0L)
    stop("unsupported method: ", paste(unknown, collapse = ", "))

  # Every metric is a function of the residual, so compute it only once.
  res <- y - yhat

  compute <- function(m) {
    switch(m,
           ME    = mean(res),
           MAE   = mean(abs(res)),
           RMSE  = sqrt(mean(res^2)),
           MAPE  = mean(abs(res/y)),
           MdAE  = median(abs(res)),
           RMdSE = sqrt(median(res^2)),
           MdAPE = median(abs(res/y)),
           MRAE  = mean(abs(res))/mean(abs(mean(y)-y)),
           MRRSE = mean(res^2)/mean((mean(y)-y)^2),
           wME   = mean(res * wt),
           wMAE  = mean(abs(res) * wt),
           wRMSE = sqrt(mean(res^2 * wt)),
           stop("unsupported method"))
  }

  # vapply() guarantees a numeric vector even when `method` is empty.
  out <- vapply(method, compute, numeric(1L), USE.NAMES = FALSE)
  names(out) <- method
  out
}


#' Cross-validation for multiple models
#' 
#' Refits every supplied model on the training part of \code{data} and
#' scores its predictions on the held-out part with the requested error
#' measures.
#' 
#' @param \dots    models.
#' @param data     data to be used for training and testing the models.
#' @param weights  vector of weights to be used for weighted error measures.
#' @param measures character vector of names of error measures to be used
#'                 (see \code{\link{err}}).
#' @param nfolds   number of folds for \emph{k}-fold cross-validation, if 
#'                 \code{nfolds < 2}, it defaults to holdout sample
#'                 cross-validation.
#' @param split    when using holdout sample cross-validation
#'                 (\code{nfolds < 2}), this is a fraction of data to be used
#'                 as a \emph{training} set.
#' 
#' @details 
#' 
#' Models are refitted with \code{update(model, data = ...)} and the observed
#' response is read from the column of \code{data} named by the first
#' variable of the model formula. In holdout mode every row is assigned to
#' the test set independently with probability \code{1 - split}, so the
#' realised split is only approximately equal to \code{split}.
#' 
#' Test observations with a missing response, prediction or weight are
#' dropped before the error measures are computed; \code{n_test} counts the
#' observations that were actually used. If a model cannot be refitted or
#' used for prediction on some fold, a warning is issued and the measures for
#' that fold are \code{NA}.
#' 
#' @return 
#' 
#' Data frame of class \code{c("multiCV", "data.frame")} with one row per
#' fold and model and columns \code{fold}, \code{model}, one column per
#' entry of \code{measures}, \code{n_test} and \code{frac} (\code{n_test}
#' divided by the number of rows of \code{data}).
#'                 
#' @examples 
#' 
#' model1 <- lm(mpg ~ 1, data = mtcars)
#' model2 <- lm(mpg ~ cyl + disp, data = mtcars)
#' 
#' multiCV(model1, model2, data = mtcars)
#'
#' @importFrom stats predict formula update
#' @importFrom utils txtProgressBar setTxtProgressBar
#' @importFrom dplyr %>%     
#' @export

multiCV <- function(..., data, measures = c("MAE", "RMSE", "MdAE"),
                    weights, nfolds = 1L, split = 0.8) {
  
  dots <- list(...)
  k <- length(dots)
  if (k == 0L)
    stop("no models supplied")
  
  # Label every model with the expression the caller wrote. deparse() gives
  # exactly one label per model even when a model is passed as an inline
  # call rather than as a variable name.
  nams <- vapply(as.list(substitute(list(...)))[-1L],
                 function(expr) paste(deparse(expr), collapse = " "),
                 character(1L), USE.NAMES = FALSE)

  data <- as.data.frame(data)
  n <- nrow(data)
  
  nfolds <- as.integer(nfolds[1L])
  if (is.na(nfolds))
    stop("nfolds needs to be a number")
  
  if (nfolds >= 2L) {
    # k-fold: balanced fold labels in random order
    idx <- sample(rep(seq_len(nfolds), length.out = n))
  } else {
    # holdout: label 1 marks the test rows, label 2 the training rows
    nfolds <- 1L
    split <- split[1L]
    if (split <= 0 || split >= 1)
      stop("split needs to be in (0,1)")
    idx <- sample.int(2L, n, prob = c(1-split, split), replace = TRUE)
  }
  
  if (missing(weights))
    weights <- 1
  weights <- rep(weights, length.out = n)
  
  out <- as.data.frame(matrix(NA, nfolds * k, length(measures) + 4L))
  colnames(out) <- c("fold", "model", measures, "n_test", "frac")
  
  pb <- txtProgressBar(style = 3L)
  # close the progress bar even if something below fails
  on.exit(close(pb), add = TRUE)
  iter <- 1L
  
  for (i in seq_len(nfolds)) {
    
    # `sel` is a logical mask, so the training rows are its complement.
    # (Negating it arithmetically would yield 0/-1 indices and drop only
    # the first row, leaking the test rows into the training set.)
    sel <- idx == i
    data_train <- data[!sel, , drop = FALSE]
    data_test <- data[sel, , drop = FALSE]
    weights_test <- weights[sel]
    
    for (j in seq_len(k)) {
      tmp <- tryCatch({
        model <- update(dots[[j]], data = data_train)
        response <- all.vars(formula(model))[1L]
        list(y = unname(data_test[, response]),
             yhat = unname(predict(model, newdata = data_test)))
      }, error = function(e) {
        warning("model '", nams[[j]], "' failed on fold ", i, ": ",
                conditionMessage(e), call. = FALSE)
        list(y = NA, yhat = NA)
      })
      
      # score only the complete cases; a failed fit leaves none
      ok <- !(is.na(tmp$y) | is.na(tmp$yhat) | is.na(weights_test))
      nt <- sum(ok)
      
      scores <- if (nt > 0L) {
        err(tmp$y[ok], tmp$yhat[ok], weights_test[ok], measures)
      } else {
        rep(NA_real_, length(measures))
      }

      out$fold[iter] <- i
      out$model[iter] <- nams[[j]]
      out$n_test[iter] <- nt
      out$frac[iter] <- nt/n
      out[iter, measures] <- scores
      
      setTxtProgressBar(pb, iter/(nfolds*k))
      iter <- iter + 1L
    }
  }
 
  structure(out, class = c("multiCV", "data.frame"))

}

#' Summarise cross-validation results per model
#' 
#' Aggregates every error measure stored in a \code{multiCV} object over the
#' folds, separately for each model, and prints the resulting table.
#' 
#' @param object    object of class \code{multiCV}.
#' @param aggregate aggregating function(s): a function, a character vector
#'                  of function names, or a list of either. Each function
#'                  must return a single number. With more than one function
#'                  the result columns are named
#'                  \code{<measure>_<function>}.
#' @param digits    number of significant digits used for printing.
#' @param \dots     further arguments passed on to \code{print}.
#' 
#' @return 
#' 
#' Data frame with one row per model (sorted by model label), returned
#' invisibly after it has been printed.
#' 
#' @importFrom dplyr select group_by summarise_all funs_
#' @export

summary.multiCV <- function(object, aggregate = "mean",
                            digits = getOption("digits"), ...) {
  
  # Function names are looked up starting from the caller, so user-defined
  # aggregators are found as well as base ones.
  caller <- parent.frame()
  
  class(object) <- "data.frame"
  
  # Everything except the bookkeeping columns is an error measure.
  measure_cols <- setdiff(names(object), c("fold", "model", "n_test", "frac"))
  
  # Normalise `aggregate` to a list of functions plus a label for each. This
  # is done in base R so the method does not rely on dplyr's deprecated
  # funs_() helper.
  specs <- if (is.function(aggregate)) list(aggregate) else as.list(aggregate)
  if (length(specs) == 0L)
    stop("'aggregate' must name at least one function")
  labels <- names(specs)
  if (is.null(labels))
    labels <- rep("", length(specs))
  funs <- vector("list", length(specs))
  for (s in seq_along(specs)) {
    spec <- specs[[s]]
    if (is.function(spec)) {
      funs[[s]] <- spec
      if (!nzchar(labels[s]))
        labels[s] <- paste0("fn", s)
    } else {
      fname <- as.character(spec)[1L]
      funs[[s]] <- get(fname, envir = caller, mode = "function")
      if (!nzchar(labels[s]))
        labels[s] <- fname
    }
  }
  
  # Row indices of each model; split() orders the groups by model label.
  groups <- split(seq_len(nrow(object)), object$model)
  
  summarise_with <- function(fun) {
    lapply(object[measure_cols], function(values) {
      vapply(groups, function(rows) fun(values[rows]), numeric(1L),
             USE.NAMES = FALSE)
    })
  }
  columns <- lapply(funs, summarise_with)
  
  # A single aggregator keeps the measure names; several get a suffix.
  if (length(funs) > 1L && length(measure_cols) > 0L) {
    for (s in seq_along(columns))
      names(columns[[s]]) <- paste(measure_cols, labels[s], sep = "_")
  }
  
  printme <- structure(
    c(list(model = as.character(names(groups))),
      unlist(unname(columns), recursive = FALSE)),
    class = "data.frame",
    row.names = seq_along(groups)
  )
  
  print(printme, digits = digits, ...)
  
}

#' @export

print.multiCV <- function(x, ...) {
  # Printing a multiCV object means showing its summary: delegate to the
  # summary() generic and hand back whatever it returns.
  summary(object = x, ...)
}






# is_unit <- function(x) all(0 <= x & x <= 1)
#
# #============================================
#
# crEnt = function(y, yhat, wt) {
#   if (!is_unit(y) || !is_unit(yhat)) {
#     warning("values not in [0,1]")
#     return(NA)
#   }
#   -mean(y * log(yhat))
# },
# sqBin = function(y, yhat, wt) {
#   if (!is_unit(y) || !is_unit(yhat)) {
#     warning("values not in [0,1]")
#     return(NA)
#   }
#   mean(1 - (y * yhat)^2)
# },
# Hinge = function(y, yhat, wt) {
#   if (!is_unit(y) || !is_unit(yhat)) {
#     warning("values not in [0,1]")
#     return(NA)
#   }
#   mean(pmax(0, 1 - (y * yhat)))
# },
# lgsls = function(y, yhat, wt) {
#   if (!is_unit(y) || !is_unit(yhat)) {
#     warning("values not in [0,1]")
#     return(NA)
#   }
#   mean(log(1 + exp(-y*yhat)))
# }
# twolodzko/twextras documentation built on May 3, 2019, 1:52 p.m.