#' Fit a grid of SDMs
#'
#' Builds every combination of the supplied data, predictor, sampling,
#' resampling and model (fit + predict) arguments, runs each combination
#' through `sdm_pipe()`, and either row-binds or returns the results.
#' The names of the argument lists are crossed the same way and passed on
#' as labels for each combination.
#'
#' @param data_args A list of (sf) data frames, or a single data frame that is
#'   used unchanged for every combination.
#' @param predictor_args A list of RasterStacks with predictors.
#' @param sample_args A list of arguments for sampling background points.
#' @param resample_args A list of arguments for creating train/test splits.
#' @param fit_args A list of arguments for fitting the models.
#' @param predict_args A list of arguments for predicting the models.
#' @param score_args A list of arguments for scoring the models.
#' @param n_cores Default is NULL. Can be number of cores to use for parallel processing.
#' @param keep_cols Default is NULL. Can be character vector to indicate which columns of model data frame to select.
#' @param bind Default is TRUE. If TRUE, the per-combination results are
#'   combined with `dplyr::bind_rows()`; if FALSE, the list of per-combination
#'   results is returned.
#'
#' @return A data frame with resamples, model fits, predictions and scores,
#'   or a list of such results when `bind = FALSE`.
#' @export
sdm_grid <- function(data_args = NULL,
                     predictor_args = NULL,
                     sample_args = NULL,
                     resample_args = NULL,
                     fit_args = NULL,
                     predict_args = NULL,
                     score_args = NULL,
                     n_cores = NULL,
                     keep_cols = NULL,
                     bind = TRUE) {
  # Each model's fit and predict arguments travel together as one list.
  fit_predict_args <- purrr::map2(fit_args, predict_args, c)
  args_list <- list(
    data_args = data_args,
    predictor_args = predictor_args,
    sample_args = sample_args,
    resample_args = resample_args,
    fit_predict_args = fit_predict_args
  )
  args_list <- purrr::compact(args_list)

  # A single data frame is not crossed; it is attached to every combination
  # after the grid has been built.
  single_data <- is.data.frame(data_args)
  if (single_data) {
    args_list[["data_args"]] <- NULL
  }
  args_cross <- purrr::cross(args_list)
  args_cross <- purrr::map(args_cross, function(x) {
    x[["score_args"]] <- score_args
    x
  })

  # Labels are crossed in the same order as `args_list`.
  # NOTE(review): this pairing assumes every supplied argument list is named;
  # an unnamed list is dropped by compact() here and the two grids no longer
  # line up element by element.
  names_list <- list(
    data_name = names(data_args),
    predictor_name = names(predictor_args),
    sampling_name = names(sample_args),
    resampling_name = names(resample_args),
    model_name = names(fit_args)
  )
  if (single_data) {
    names_list[["data_name"]] <- NULL
  }
  names_list <- purrr::compact(names_list)
  names_cross <- purrr::cross(names_list)
  args_cross <- purrr::map2(args_cross, names_cross, function(x, y) {
    x[["name_args"]] <- y
    x
  })
  if (single_data) {
    args_cross <- purrr::map(args_cross, function(x) {
      x[["data_args"]] <- data_args
      x
    })
  }

  if (is.null(n_cores)) {
    model_list <- purrr::map(args_cross, function(args) {
      out <- sdm_pipe(args)
      if (!is.null(keep_cols)) {
        out <- dplyr::select(out, dplyr::one_of(keep_cols))
      }
      out
    })
  } else {
    cl <- parallel::makeCluster(n_cores)
    # Release the workers even if a model errors, and do not leave the
    # stopped cluster registered as the foreach backend.
    on.exit({
      parallel::stopCluster(cl)
      foreach::registerDoSEQ()
    }, add = TRUE)
    doParallel::registerDoParallel(cl)
    # Bind foreach's operator locally so foreach need not be attached.
    `%dopar%` <- foreach::`%dopar%`
    loaded_pkgs <- .packages()
    model_list <- foreach::foreach(
      i = args_cross,
      .packages = loaded_pkgs
    ) %dopar% {
      out <- sdm_pipe(i)
      if (!is.null(keep_cols)) {
        out <- dplyr::select(out, dplyr::one_of(keep_cols))
      }
      out
    }
  }

  if (bind) {
    dplyr::bind_rows(model_list)
  } else {
    model_list
  }
}
# Fit models on `data`, then predict from the fitted result.
#
# `model_call` and `drop_cols` are forwarded to fit_models(); `select_cols`
# and anything supplied through `...` are forwarded to predict_models().
# Returns whatever predict_models() returns.
fit_predict_models <- function(data, model_call, drop_cols = NULL, select_cols = NULL, ...) {
  fitted_models <- fit_models(
    data,
    model_call = model_call,
    drop_cols = drop_cols
  )
  predict_models(fitted_models, select_cols = select_cols, ...)
}
# Run one combination of the SDM grid through each configured stage.
#
# `args` is one element of the grid built by sdm_grid(). Each optional stage
# runs only when its entry in `args` is non-NULL. The stage receives the
# current data under the argument name shown:
#   sample_args      -> sample_background(x = ...)
#   predictor_args   -> extract_preds(x = ...)
#   resample_args    -> split_train_test(data = ...)
#   fit_predict_args -> fit_predict_models(data = ...)
#   score_args       -> score_models(data = ...)
# The final data is always passed as `data` to name_models(), together with
# the entries of `args[["name_args"]]`.
sdm_pipe <- function(args) {
  current <- args[["data_args"]]

  sampling <- args[["sample_args"]]
  if (!is.null(sampling)) {
    sampling$x <- current
    current <- do.call(sample_background, sampling)
  }

  predictors <- args[["predictor_args"]]
  if (!is.null(predictors)) {
    predictors$x <- current
    current <- do.call(extract_preds, predictors)
  }

  resampling <- args[["resample_args"]]
  if (!is.null(resampling)) {
    resampling$data <- current
    current <- do.call(split_train_test, resampling)
  }

  fit_predict <- args[["fit_predict_args"]]
  if (!is.null(fit_predict)) {
    fit_predict$data <- current
    # NOTE(review): rlang::invoke() is superseded, but unlike do.call() it
    # does not inline argument values into the evaluated call. That matters
    # if `model_call` is a call object, so confirm before replacing it.
    current <- rlang::invoke(fit_predict_models, fit_predict)
  }

  scoring <- args[["score_args"]]
  if (!is.null(scoring)) {
    scoring$data <- current
    current <- do.call(score_models, scoring)
  }

  # `$<-` turns a NULL name_args into list(data = ...), as the original did.
  naming <- args[["name_args"]]
  naming$data <- current
  do.call(name_models, naming)
}
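# NOTE: the two lines below are website embed boilerplate, not R code; they
# are commented out so that this file parses.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.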