R/full_models.R

Defines functions full_models

Documented in full_models

#' Get full models from sdmflow data.frame
#'
#' Refits every model on the complete data (training and test rows combined)
#' for each distinct combination of data, predictor, sampling, resampling and
#' model names found in `data`.
#'
#' @param data A data frame containing models and resamples. It must contain
#'   the columns `data_name`, `predictor_name`, `sampling_name`,
#'   `resampling_name` and `model_name`, plus the list columns `train` and
#'   `test`, whose elements must have the same columns so they can be
#'   row-bound.
#' @param fit_args Named list of arguments that was used for fitting the
#'   models, with one element per model name. Each element must contain a
#'   quoted `model_call`, which is evaluated with the combined data bound to
#'   `train`. It may also contain `drop_cols` (column names removed before
#'   fitting) and `drop_zero_var_preds` (if `TRUE`, columns with fewer than
#'   two distinct values are removed before fitting).
#' @param n_cores Number of worker processes used to fit the models in
#'   parallel. If `NULL` (the default), models are fitted sequentially.
#'
#' @return A data frame containing the full model fits using the entire
#'   training data. Columns: `train` (the combined data), `data_name`,
#'   `predictor_name`, `sampling_name`, `model_name`, `call` (the matching
#'   `fit_args` element) and `model` (the fitted model).
#' @export
full_models <- function(data, fit_args, n_cores = NULL) {
  # Names of the columns of `train` holding fewer than two distinct values.
  zero_var <- function(train) {
    is_constant <- vapply(
      train,
      function(x) length(unique(x)) <= 1,
      logical(1)
    )
    names(train)[is_constant]
  }

  # Remove `cols` from `df`. When there is nothing to drop, return `df` as is
  # rather than relying on how the installed tidyselect version treats a
  # negated empty selection.
  drop_columns <- function(df, cols) {
    if (length(cols) == 0) {
      return(df)
    }
    dplyr::select(df, -dplyr::one_of(cols))
  }

  # Apply the optional preprocessing described by `call`, then evaluate its
  # model call. The argument names `train` and `call` are part of the
  # contract: `call$model_call` is evaluated in this function's frame, so it
  # refers to the preprocessed data as `train`.
  fit_one <- function(train, call) {
    if (!is.null(call$drop_cols)) {
      train <- drop_columns(train, call$drop_cols)
    }
    if (isTRUE(call$drop_zero_var_preds)) {
      train <- drop_columns(train, zero_var(train))
    }
    eval(call$model_call)
  }

  dat <- data %>%
    dplyr::distinct(data_name, predictor_name, sampling_name,
                    resampling_name, model_name, .keep_all = TRUE) %>%
    dplyr::mutate(full_dat = purrr::map2(train, test, function(train, test) {
      rbind(as.data.frame(train), as.data.frame(test))
    })) %>%
    dplyr::select(full_dat, data_name, predictor_name, sampling_name,
                  resampling_name, model_name)

  # Without this check an unknown model name would silently produce a NULL
  # call and therefore a NULL "model".
  unknown_models <- setdiff(as.character(unique(dat$model_name)),
                            names(fit_args))
  if (length(unknown_models) > 0) {
    stop("No entry in `fit_args` for model(s): ",
         paste(unknown_models, collapse = ", "), call. = FALSE)
  }
  # as.character() so factor model names index by name, not by level code.
  dat <- dplyr::mutate(dat, call = fit_args[as.character(model_name)])

  if (is.null(n_cores)) {
    models <- purrr::map2(dat$full_dat, dat$call, fit_one)
  } else {
    `%dopar%` <- foreach::`%dopar%`
    cl <- parallel::makeCluster(n_cores)
    # Release the workers even if a model fit fails, and do not leave foreach
    # registered to a cluster that no longer exists.
    on.exit({
      parallel::stopCluster(cl)
      foreach::registerDoSEQ()
    }, add = TRUE)
    doParallel::registerDoParallel(cl)

    models <- foreach::foreach(
      train = dat$full_dat,
      call = dat$call,
      .packages = .packages()
    ) %dopar% {
      fit_one(train, call)
    }
  }

  dat %>%
    dplyr::mutate(model = models) %>%
    dplyr::rename(train = full_dat) %>%
    dplyr::select(-dplyr::one_of("resampling_name"))
}
juoe/sdmflow documentation built on Feb. 23, 2020, 7:38 p.m.