R/split_train_test.R

Defines functions split_train_test

Documented in split_train_test

#' Create training-testing data splits for cross-validation.
#'
#' Dispatches to one of several resampling functions according to
#' \code{method}. For every method except \code{"grid"}, an \code{sf} object
#' is first converted to a plain data frame by dropping its geometry.
#'
#' @param data A data frame or an \code{sf} object.
#' @param method A single string naming the resampling scheme. One of
#'   \code{"kfold"} (\code{\link[modelr]{crossv_kfold}}),
#'   \code{"mc"} (\code{\link[modelr]{crossv_mc}}), \code{"cluster"},
#'   \code{"grid"}, \code{"predstrat"}, \code{"loo"},
#'   \code{"temporal_blocking"} or \code{"year_window"}. The last six are
#'   handled by the package's \code{crossv_<method>} helper of the same name.
#'   Anything else (including \code{NA} or a vector of length != 1) raises
#'   the error "Unknown method name.".
#' @param ... Arguments passed on to the selected resampling function.
#'
#' @return The object returned by the selected resampling function: a data
#'   frame with one row per split and list-columns \code{train} and
#'   \code{test} containing \code{\link[modelr]{resample}} objects.
#' @export

split_train_test <- function(data, method = "kfold", ...) {
  valid_methods <- c(
    "kfold", "mc", "cluster", "grid", "predstrat", "loo",
    "temporal_blocking", "year_window"
  )

  # Validate before any scalar `if`: NA, NULL or multi-element `method`
  # would otherwise fail with an unrelated condition-length/NA error.
  if (!is.character(method) || length(method) != 1L ||
      !(method %in% valid_methods)) {
    stop("Unknown method name.", call. = FALSE)
  }

  # Geometry is kept only for "grid"; every other resampling function
  # receives a plain data frame.
  if (inherits(data, "sf") && method != "grid") {
    data <- st_set_geometry(data, NULL)
  }

  switch(method,
    kfold = modelr::crossv_kfold(data, ...),
    mc = modelr::crossv_mc(data, ...),
    cluster = crossv_cluster(data, ...),
    grid = crossv_grid(data, ...),
    predstrat = crossv_predstrat(data, ...),
    loo = crossv_loo(data, ...),
    temporal_blocking = crossv_temporal_blocking(data, ...),
    year_window = crossv_year_window(data, ...)
  )
}
juoe/sdmflow documentation built on Feb. 23, 2020, 7:38 p.m.