# R/caret.R

#' Convert Resampling Objects to Other Formats
#'
#' These functions translate resampling objects between the
#'  \pkg{rsample} and \pkg{caret} representations.
#'
#' @param object An `rset` object. Currently,
#'  `nested_cv` is not supported.
#' @return `rsample2caret` returns a list that mimics the
#'  `index` and `indexOut` elements of a
#'  `trainControl` object. `caret2rsample` returns an
#'  `rset` object of the appropriate class.
#' @export
#' @importFrom purrr map
rsample2caret <- function(object, data = c("analysis", "assessment")) {
  if (!inherits(object, "rset")) {
    stop("`object` must be an `rset`", call. = FALSE)
  }
  # `data` is checked against its allowed values but is not consulted
  # afterwards: the result always carries both the analysis rows (`index`)
  # and the assessment rows (`indexOut`).
  data <- match.arg(data)

  fold_ids <- labels(object)
  index <- purrr::map(object$splits, as.integer, data = "analysis")
  index_out <- purrr::map(object$splits, as.integer, data = "assessment")

  # Both lists share one set of resample names.
  names(index) <- fold_ids
  names(index_out) <- names(index)

  list(index = index, indexOut = index_out)
}

#' @rdname rsample2caret
#' @param ctrl An object produced by `trainControl` that has
#'  had the `index` and `indexOut` elements populated by
#'  integers. One method of getting this is to extract the
#'  `control` objects from an object produced by `train`.
#'  Its `method` element must be a single character string.
#' @param data For `caret2rsample`, the data that was originally used
#'  to produce the `ctrl` object. For `rsample2caret`, one of
#'  `"analysis"` or `"assessment"`; it is validated but does not change
#'  the result, which always contains both `index` and `indexOut`.
#' @importFrom purrr map map2
#' @importFrom tibble tibble
#' @importFrom dplyr bind_cols
#' @export
caret2rsample <- function(ctrl, data = NULL) {
  if (is.null(data)) {
    stop("Must supply original data", call. = FALSE)
  }
  if (!("index" %in% names(ctrl))) {
    stop("`ctrl` should have an element `index`", call. = FALSE)
  }
  if (!("indexOut" %in% names(ctrl))) {
    stop("`ctrl` should have an element `indexOut`", call. = FALSE)
  }
  if (is.null(ctrl$index)) {
    stop("`ctrl$index` should be populated with integers", call. = FALSE)
  }
  if (is.null(ctrl$indexOut)) {
    stop("`ctrl$indexOut` should be populated with integers", call. = FALSE)
  }
  method <- ctrl$method
  # Validate up front: the method-mapping helpers dispatch with `switch()`
  # and `grepl()`, which fail with cryptic messages for a NULL, NA, or
  # multi-element method.
  if (!is.character(method) || length(method) != 1L || is.na(method)) {
    stop("`ctrl$method` should be a single character string", call. = FALSE)
  }

  # One (in_id, out_id) pair per resample; the list names become the ids.
  indices <- purrr::map2(ctrl$index, ctrl$indexOut, extract_int)
  id_data <- names(indices)
  indices <- unname(indices)
  indices <- purrr::map(indices, add_data, y = data)
  indices <- purrr::map(
    indices,
    add_rsplit_class,
    cl = map_rsplit_method(method)
  )
  indices <- tibble::tibble(splits = indices)

  if (method %in% c("repeatedcv", "adaptive_cv")) {
    # Resample names are split on "."; the second piece becomes `id` and the
    # first becomes `id2` (presumably names like "Fold01.Rep1" -- confirm
    # against caret's naming).
    pieces <- strsplit(id_data, split = ".", fixed = TRUE)
    id_data <- tibble::tibble(
      id = vapply(pieces, `[`, character(1), 2),
      id2 = vapply(pieces, `[`, character(1), 1)
    )
  } else {
    id_data <- tibble::tibble(id = id_data)
  }

  out <- dplyr::bind_cols(indices, id_data)
  # Copy the method-specific settings (e.g. v, repeats, times) onto the rset.
  attrib <- map_attr(ctrl)
  for (i in names(attrib)) {
    attr(out, i) <- attrib[[i]]
  }
  add_rset_class(out, map_rset_method(method))
}

extract_int <- function(x, y) {
  # Pair one resample's row indices: `x` (the rows from caret's `index`)
  # is stored as `in_id` and `y` (the rows from `indexOut`) as `out_id`.
  list(
    in_id = x,
    out_id = y
  )
}

add_data <- function(x, y) {
  # Place the original data set `y` first, as element `data`, followed by
  # every existing element of `x`.
  data_part <- list(data = y)
  c(data_part, x)
}

add_rsplit_class <- function(x, cl) {
  # Replace the class of `x` with "rsplit" followed by the method-specific
  # split class `cl`.
  split_classes <- c("rsplit", cl)
  class(x) <- split_classes
  x
}

add_rset_class <- function(x, cl) {
  # Class order: the method-specific rset class `cl`, then "rset", then the
  # tibble/data.frame classes.
  tibble_classes <- c("tbl_df", "tbl", "data.frame")
  class(x) <- c(cl, "rset", tibble_classes)
  x
}

map_rsplit_method <- function(method) {
  # caret resampling method -> rsample split class. Methods that are not
  # listed fall through to the "error" sentinel and are rejected below.
  split_class <- switch(
    method,
    cv = ,
    repeatedcv = ,
    adaptive_cv = "vfold_split",
    boot = ,
    boot_all = ,
    boot632 = ,
    optimism_boot = ,
    adaptive_boot = "boot_split",
    LOOCV = "loo_split",
    LGOCV = ,
    adaptive_LGOCV = "mc_split",
    timeSlice = "rof_split",
    "error"
  )
  if (split_class != "error") {
    return(split_class)
  }
  stop(
    "Resampling method `", method,
    "` cannot be converted into an `rsplit` object",
    call. = FALSE
  )
}

map_rset_method <- function(method) {
  # caret resampling method -> rsample `rset` subclass. Unlisted methods
  # resolve to the "error" sentinel, which is rejected before returning.
  rset_class <- switch(
    method,
    cv = ,
    repeatedcv = ,
    adaptive_cv = "vfold_cv",
    boot = ,
    boot_all = ,
    boot632 = ,
    optimism_boot = ,
    adaptive_boot = "bootstraps",
    LOOCV = "loo_cv",
    LGOCV = ,
    adaptive_LGOCV = "mc_cv",
    timeSlice = "rolling_origin",
    "error"
  )
  if (rset_class != "error") {
    return(rset_class)
  }
  stop(
    "Resampling method `", method,
    "` cannot be converted into an `rset` object",
    call. = FALSE
  )
}


# Translate the method-specific settings of a caret control list into the
# attributes attached to the corresponding rsample `rset`.
#
# `object` is the control list passed to `caret2rsample()`. The fields read
# depend on `object$method`:
#   * names ending in "cv":    `number`, `repeats`
#   * names containing "boot": `number`
#   * names ending in "LGOCV": `number`, `p`
#   * "LOOCV":                 nothing (an empty list is returned)
#   * "timeSlice":             `initialWindow`, `horizon`, `fixedWindow`, `skip`
# Returns a named list. Any other method is an error.
map_attr <- function(object) {
  method <- object$method
  if (grepl("cv$", method)) {
    # A missing or NA `repeats` is treated as a single repeat. A plain `if`
    # is used because `ifelse()` returns a zero-length value when `repeats`
    # is NULL.
    repeats <- object$repeats
    if (length(repeats) == 0L || anyNA(repeats)) {
      repeats <- 1
    }
    out <- list(v = object$number, repeats = repeats, strata = TRUE)
  } else if (grepl("boot", method)) {
    out <- list(times = object$number, apparent = FALSE, strata = FALSE)
  } else if (grepl("LGOCV$", method)) {
    out <- list(times = object$number, prop = object$p, strata = FALSE)
  } else if (method == "LOOCV") {
    out <- list()
  } else if (method == "timeSlice") {
    out <- list(
      initial = object$initialWindow,
      assess = object$horizon,
      # `cumulative` is the negation of caret's `fixedWindow` flag.
      cumulative = !object$fixedWindow,
      skip = object$skip
    )
  } else {
    stop("Method `", method, "` cannot be converted", call. = FALSE)
  }
  out
}
# topepo/rsample documentation built on May 4, 2019, 4:25 p.m.