# R/sdm_grid.R
#
# Defines functions: sdm_grid, fit_predict_models, sdm_pipe
#
# Documented in: sdm_grid

#' Fit a grid of SDMs
#'
#' Builds every combination of the supplied argument lists and runs the
#' modelling pipeline (`sdm_pipe()`) once per combination, either sequentially
#' or on a parallel cluster.
#'
#' @param data_args A list of (sf) data frames, or a single data frame. A
#'   single data frame is not crossed; it is used as-is in every combination.
#' @param predictor_args A list of RasterStacks with predictors.
#' @param sample_args A list of arguments for sampling background points.
#' @param resample_args A list of arguments for creating train/test splits.
#' @param fit_args A list of arguments for fitting the models.
#' @param predict_args A list of arguments for predicting the models. Combined
#'   element-wise with `fit_args`.
#' @param score_args A list of arguments for scoring the models. The same
#'   value is used for every combination.
#' @param n_cores Default is NULL. Can be number of cores to use for parallel
#'   processing.
#' @param keep_cols Default is NULL. Can be character vector to indicate which
#'   columns of model data frame to select.
#' @param bind Default is TRUE. If TRUE the results of all combinations are
#'   combined with `dplyr::bind_rows()`; if FALSE they are returned as a list
#'   with one element per combination.
#'
#' @return A data frame with resamples, model fits, predictions and scores, or
#'   a list of such results when `bind = FALSE`.
#' @export
sdm_grid <- function(data_args = NULL,
                     predictor_args = NULL,
                     sample_args = NULL,
                     resample_args = NULL,
                     fit_args = NULL,
                     predict_args = NULL,
                     score_args = NULL,
                     n_cores = NULL,
                     keep_cols = NULL,
                     bind = TRUE) {

  # A single data frame is shared by all combinations instead of being crossed.
  single_data <- is.data.frame(data_args)

  # Pair each set of fit arguments with its prediction arguments.
  fit_predict_args <- purrr::map2(fit_args, predict_args, c)

  args_list <- list(data_args = data_args,
                    predictor_args = predictor_args,
                    sample_args = sample_args,
                    resample_args = resample_args,
                    fit_predict_args = fit_predict_args)
  args_list <- purrr::compact(args_list)
  if (single_data) {
    args_list[["data_args"]] <- NULL
  }
  args_cross <- purrr::cross(args_list)

  # The element names of each argument list are crossed in the same order so
  # that every combination can be labelled.
  names_list <- list(data_name = names(data_args),
                     predictor_name = names(predictor_args),
                     sampling_name = names(sample_args),
                     resampling_name = names(resample_args),
                     model_name = names(fit_args))
  if (single_data) {
    names_list[["data_name"]] <- NULL
  }
  names_list <- purrr::compact(names_list)
  names_cross <- purrr::cross(names_list)

  # Attach the shared scoring arguments, the labels and (if applicable) the
  # shared data frame to every combination.
  args_cross <- purrr::map2(args_cross, names_cross, function(combo, combo_names) {
    combo[["score_args"]] <- score_args
    combo[["name_args"]] <- combo_names
    if (single_data) {
      combo[["data_args"]] <- data_args
    }
    combo
  })

  if (is.null(n_cores)) {

    model_list <- purrr::map(args_cross, function(args) {
      out <- sdm_pipe(args)
      if (!is.null(keep_cols)) {
        out <- dplyr::select(out, dplyr::one_of(keep_cols))
      }
      out
    })

  } else {

    cl <- parallel::makeCluster(n_cores)
    # Shut the workers down even if one of the pipelines fails.
    on.exit(parallel::stopCluster(cl), add = TRUE)
    doParallel::registerDoParallel(cl)

    # Bind the operator and the pipeline function to local names: the operator
    # then works without foreach being attached, and foreach auto-exports
    # variables defined in this frame to the workers.
    `%dopar%` <- foreach::`%dopar%`
    run_pipe <- sdm_pipe

    loaded_pkgs <- .packages()

    model_list <- foreach::foreach(i = args_cross,
                                   .packages = loaded_pkgs) %dopar% {
      out <- run_pipe(i)
      if (!is.null(keep_cols)) {
        out <- dplyr::select(out, dplyr::one_of(keep_cols))
      }
      out
    }
  }

  if (bind) {
    dplyr::bind_rows(model_list)
  } else {
    model_list
  }
}


# Fit models and immediately predict from them.
#
# `data`, `model_call` and `drop_cols` are handed to `fit_models()`; the
# resulting fit, `select_cols` and any further arguments in `...` are handed to
# `predict_models()`, whose value is returned.
fit_predict_models <- function(data, model_call, drop_cols=NULL, select_cols = NULL, ...) {
  fitted_models <- fit_models(
    data,
    model_call = model_call,
    drop_cols = drop_cols
  )

  predict_models(
    fitted_models,
    select_cols = select_cols,
    ...
  )
}


# Run a single combination of arguments through the modelling pipeline.
#
# `args` is a list that may hold `data_args`, `sample_args`, `predictor_args`,
# `resample_args`, `fit_predict_args`, `score_args` and `name_args`. Starting
# from `data_args`, each stage whose arguments are not NULL receives the
# current result (as `x` for the first two stages, as `data` for the others)
# and its value becomes the input of the next stage. The naming stage always
# runs last and its value is returned.
sdm_pipe <- function(args) {

  current <- args[["data_args"]]

  # Background sampling.
  sampling <- args[["sample_args"]]
  if (!is.null(sampling)) {
    sampling$x <- current
    current <- do.call(sample_background, sampling)
  }

  # Predictor extraction.
  predictors <- args[["predictor_args"]]
  if (!is.null(predictors)) {
    predictors$x <- current
    current <- do.call(extract_preds, predictors)
  }

  # Train/test splitting.
  resampling <- args[["resample_args"]]
  if (!is.null(resampling)) {
    resampling$data <- current
    current <- do.call(split_train_test, resampling)
  }

  # Model fitting and prediction.
  # NOTE(review): this stage uses rlang::invoke() rather than do.call(),
  # presumably so that unevaluated model calls in the arguments are passed
  # through as values -- confirm before replacing it.
  fit_predict <- args[["fit_predict_args"]]
  if (!is.null(fit_predict)) {
    fit_predict$data <- current
    current <- rlang::invoke(fit_predict_models, fit_predict)
  }

  # Scoring.
  scoring <- args[["score_args"]]
  if (!is.null(scoring)) {
    scoring$data <- current
    current <- do.call(score_models, scoring)
  }

  # Labelling of the result.
  naming <- args[["name_args"]]
  naming$data <- current
  do.call(name_models, naming)

}
# juoe/sdmflow documentation built on Feb. 23, 2020, 7:38 p.m.