# R/crossval_ml.R

# Defines functions crossval_ml

# Documented in crossval_ml

#' Generic cross-validation function
#'
#' Generic (optionally repeated) k-fold cross-validation for any model that can
#' be fitted by \code{fit_func} and predicted by \code{predict_func}.
#'
#' @param x input covariates' matrix (coerced with \code{as.matrix})
#' @param y response variable; a vector (a factor for classification)
#' @param fit_func a function for fitting the model; it is called as
#' \code{fit_func(x = training covariates, y = training response, ...)}, with
#' the elements of \code{fit_params} passed through \code{...}
#' @param predict_func a function for predicting values from the model; it is
#' first called as \code{predict_func(fit_obj, newdata = covariates)} and, if
#' that fails, as \code{predict_func(fit_obj, newx = covariates)}. If both
#' calls fail, the predictions are \code{NA}
#' @param fit_params a list; additional (model-specific) parameters to be passed
#' to \code{fit_func}
#' @param k an integer, at least 2; number of folds in k-fold cross validation
#' @param repeats an integer, at least 1; number of repeats for the k-fold cross
#' validation
#' @param p a double between 0.5 and 1 (inclusive); proportion of data used for
#' the k-fold cross validation (training/testing set), default is 1. If
#' \code{p} < 1, a validation set error is also calculated on the remaining
#' 1-\code{p} fraction of the data
#' @param seed random seed for reproducibility of results
#' @param eval_metric a function \code{function(preds, actual)} returning a
#' single number that measures the test errors; if not provided: RMSE for
#' regression and accuracy for classification
#' @param cl an integer; the number of clusters for parallel execution
#' @param errorhandling specifies how a task evaluation error should be handled.
#' If value is "stop", then execution will be stopped if an error occurs. If value
#' is "remove", the result for that task will not be returned. If value is "pass",
#' then the error object generated by task evaluation will be included with the
#' rest of the results. The default value is "stop".
#' @param packages character vector of packages that the tasks depend on
#' @param verbose logical flag enabling verbose messages. This can be very useful for
#' troubleshooting.
#' @param show_progress logical; show a progress bar and the elapsed time
#' @param ... additional parameters (currently unused)
#'
#' @return A list. If \code{p == 1}, it contains \code{folds}, a matrix of test
#' errors with one row per fold (\code{fold_*}) and one column per repeat
#' (\code{repeat_*}), together with their \code{mean}, \code{sd} and
#' \code{median}. If \code{p < 1}, \code{folds} has two consecutive rows per
#' fold: the fold's test error (\code{fold_training_*}) and the validation set
#' error of the same fitted model (\code{fold_validation_*}). When
#' \code{repeats == 1}, \code{folds} is a named vector in that same order. The
#' list then also contains \code{mean_training}, \code{mean_validation},
#' \code{sd_training}, \code{sd_validation}, \code{median_training} and
#' \code{median_validation}. \code{NA} errors are ignored in all summaries.
#' @export
#'
#' @examples
#'
#'# dataset
#'
#' set.seed(123)
#' n <- 1000 ; p <- 10
#' X <- matrix(rnorm(n * p), n, p)
#' y <- rnorm(n)
#'
#'# linear model example -----
#'
#' crossval::crossval_ml(x = X, y = y, k = 5, repeats = 3)
#'
#'
#'# randomForest example -----
#'
#'library(randomForest)
#'
#'# fit randomForest with mtry = 2
#'
#'crossval::crossval_ml(x = X, y = y, k = 5, repeats = 3,
#'                   fit_func = randomForest::randomForest, predict_func = predict,
#'                   packages = "randomForest", fit_params = list(mtry = 2))
#'
#'# fit randomForest with mtry = 4
#'
#'crossval::crossval_ml(x = X, y = y, k = 5, repeats = 3,
#'                   fit_func = randomForest::randomForest, predict_func = predict,
#'                   packages = "randomForest", fit_params = list(mtry = 4))
#'
#'# fit randomForest with mtry = 4, with a validation set
#'
#'crossval::crossval_ml(x = X, y = y, k = 5, repeats = 2, p = 0.8,
#'                   fit_func = randomForest::randomForest, predict_func = predict,
#'                   packages = "randomForest", fit_params = list(mtry = 4))
#'
crossval_ml <- function(x,
                        y,
                        fit_func = crossval::fit_lm,
                        predict_func = crossval::predict_lm,
                        fit_params = NULL,
                        # and hyperparameters
                        k = 5,
                        repeats = 3,
                        p = 1,
                        seed = 123,
                        eval_metric = NULL,
                        cl = NULL,
                        errorhandling = c('stop', 'remove', 'pass'),
                        packages = c("stats", "Rcpp"),
                        verbose = FALSE,
                        show_progress = TRUE,
                        ...) {
  # Validate every input before the data is split. Coercing `x` first makes
  # nrow() meaningful even when a plain vector is supplied.
  x <- as.matrix(x)
  n_y <- length(y)
  stopifnot(n_y == nrow(x))
  errorhandling <- match.arg(errorhandling)
  stopifnot(k == floor(k) && k >= 2)
  stopifnot(p >= 0.5 && p <= 1)
  stopifnot(repeats == floor(repeats) && repeats >= 1)

  # Optional hold-out validation set. It must be taken from the full data
  # *before* `x` and `y` are replaced by the training part; drop = FALSE keeps
  # single-row / single-column subsets as matrices.
  set.seed(seed)
  if (p < 1) {
    index_train <- sample.int(n_y, size = floor(p * n_y))
    x_validation <- x[-index_train, , drop = FALSE]
    y_validation <- y[-index_train]
    x <- x[index_train, , drop = FALSE]
    y <- y[index_train]
  }

  # Default evaluation metric: accuracy for a factor response
  # (classification), root mean squared error otherwise (regression).
  if (is.null(eval_metric)) {
    if (is.factor(y)) {
      eval_metric <- function(preds, actual) {
        res <- mean(preds == actual)
        names(res) <- "accuracy"
        res
      }
    } else {
      eval_metric <- function(preds, actual) {
        res <- sqrt(mean((preds - actual) ^ 2))
        names(res) <- "rmse"
        res
      }
    }
  }

  # One set of k folds per repeat; list_folds[[j]][[i]] holds the indices of
  # the test observations of fold i in repeat j.
  set.seed(seed)
  list_folds <- lapply(seq_len(repeats),
                       function(i) crossval::create_folds(y = y, k = k))

  # Predict with `predict_func`, first through a `newdata` argument, then
  # through a `newx` argument; if both calls fail, return `n_out` NA
  # predictions instead of stopping.
  predict_safely <- function(fit_obj, new_x, n_out) {
    preds <- try(predict_func(fit_obj, newdata = new_x), silent = TRUE)
    # inherits() rather than class() == ...: predictions may carry several
    # classes (e.g. c("matrix", "array")), which would break `if`
    if (inherits(preds, "try-error")) {
      preds <- try(predict_func(fit_obj, newx = new_x), silent = TRUE)
      if (inherits(preds, "try-error")) {
        preds <- rep(NA, n_out)
      }
    }
    preds
  }

  # Fit the model on every observation outside fold `i` of repeat `j` and
  # return its error on that fold; when p < 1, also return the error of the
  # same fitted model on the validation set.
  evaluate_fold <- function(i, j) {
    test_index <- list_folds[[j]][[i]]

    set.seed(seed) # in case the algo is randomized
    fit_obj <- do.call(what = fit_func,
                       args = c(list(x = x[-test_index, , drop = FALSE],
                                     y = y[-test_index]),
                                fit_params))

    preds <- predict_safely(fit_obj, x[test_index, , drop = FALSE],
                            length(test_index))
    error_measure <- eval_metric(preds, y[test_index])
    if (p == 1) {
      return(error_measure)
    }

    preds_validation <- predict_safely(fit_obj, x_validation,
                                       length(y_validation))
    c(error_measure, eval_metric(preds_validation, y_validation))
  }

  # Sequential loop over the repeats of fold `i`. Results are bound
  # column-wise explicitly so that a fold always yields a matrix (one row per
  # error measure, one column per repeat), even when repeats == 1.
  `%inner%` <- foreach::`%do%`
  run_fold <- function(i) {
    fold_res <- foreach::foreach(
      j = seq_len(repeats),
      .packages = packages,
      .verbose = FALSE,
      .errorhandling = errorhandling
    ) %inner% evaluate_fold(i, j)
    do.call(cbind, fold_res)
  }

  ptm <- proc.time()
  i <- j <- NULL # foreach iteration variables (avoids R CMD check notes)

  if (!is.null(cl) && cl > 0) {
    # parallel exec.: one task per fold, progress reported by doSNOW
    cl_SOCK <- parallel::makeCluster(cl, type = "SOCK")
    # stop the workers even if a task error aborts the loop
    on.exit(snow::stopCluster(cl_SOCK), add = TRUE)
    doSNOW::registerDoSNOW(cl_SOCK)
    `%op%` <- foreach::`%dopar%`

    opts <- list()
    if (show_progress) {
      pb <- utils::txtProgressBar(min = 0, max = k, style = 3)
      opts$progress <- function(n) utils::setTxtProgressBar(pb, n)
    }

    fold_results <- foreach::foreach(
      i = seq_len(k),
      .packages = packages,
      .errorhandling = errorhandling,
      .options.snow = opts,
      .verbose = verbose
    ) %op% run_fold(i)
  } else {
    # sequential exec.
    `%op%` <- foreach::`%do%`

    if (show_progress) {
      pb <- utils::txtProgressBar(min = 0, max = k, style = 3)
    }

    fold_results <- foreach::foreach(
      i = seq_len(k),
      .packages = packages,
      .errorhandling = errorhandling,
      .verbose = verbose
    ) %op% {
      fold_res <- run_fold(i)
      if (show_progress) {
        utils::setTxtProgressBar(pb, i)
      }
      fold_res
    }
  }

  if (show_progress) {
    close(pb)
    cat("\n")
    print(proc.time() - ptm)
    cat("\n")
  }

  # Stack the per-fold matrices: rows are folds (x error measures), columns
  # are repeats.
  res <- do.call(rbind, fold_results)
  colnames(res) <- paste0("repeat_", seq_len(ncol(res)))

  if (p == 1) {
    rownames(res) <- paste0("fold_", seq_len(nrow(res)))
    return(list(
      folds = res,
      mean = mean(res, na.rm = TRUE),
      sd = stats::sd(res, na.rm = TRUE),
      median = stats::median(res, na.rm = TRUE)
    ))
  }

  # With a validation set every fold contributes two consecutive rows: its
  # test error, then the validation error of the same fitted model.
  n_folds <- nrow(res) %/% 2
  rownames(res) <- paste0(rep(c("fold_training_", "fold_validation_"), n_folds),
                          rep(seq_len(n_folds), each = 2))
  is_training_row <- seq_len(nrow(res)) %% 2 == 1
  train_test_errors <- res[is_training_row, , drop = FALSE]
  validation_errors <- res[!is_training_row, , drop = FALSE]

  if (repeats == 1) {
    # a single repeat is returned as a named vector; with one column,
    # as.numeric() keeps the row (training/validation) order
    res <- stats::setNames(as.numeric(res), rownames(res))
  }

  list(
    folds = res,
    mean_training = mean(train_test_errors, na.rm = TRUE),
    mean_validation = mean(validation_errors, na.rm = TRUE),
    sd_training = stats::sd(train_test_errors, na.rm = TRUE),
    sd_validation = stats::sd(validation_errors, na.rm = TRUE),
    median_training = stats::median(train_test_errors, na.rm = TRUE),
    median_validation = stats::median(validation_errors, na.rm = TRUE)
  )
}
# Byte-compile crossval_ml. cmpfun() returns a new, compiled closure instead
# of modifying its argument, so the result must be assigned back for the call
# to have any effect.
crossval_ml <- compiler::cmpfun(crossval_ml)
# thierrymoudiki/crossval documentation built on Aug. 17, 2020, 5:51 a.m.