#' Convert Resampling Objects to Other Formats
#'
#' These functions can convert resampling objects between
#' \pkg{rsample} and \pkg{caret}.
#'
#' @param object An `rset` object. Currently,
#' `nested_cv` is not supported.
#' @return `rsample2caret` returns a list that mimics the
#' `index` and `indexOut` elements of a
#' `trainControl` object. `caret2rsample` returns an
#' `rset` object of the appropriate class.
#' @export
#' @importFrom purrr map
rsample2caret <- function(object, data = c("analysis", "assessment")) {
  # Only resampling sets can be converted.
  if (!inherits(object, "rset")) {
    stop("`object` must be an `rset`", call. = FALSE)
  }

  # NOTE(review): `data` is validated here but not used afterwards; both the
  # analysis and the assessment indices are always returned.
  data <- match.arg(data)

  # Integer row indices of every split, for the requested partition.
  rows_for <- function(which_set) {
    purrr::map(object$splits, as.integer, data = which_set)
  }

  analysis_rows <- rows_for("analysis")
  names(analysis_rows) <- labels(object)

  # The held-out indices reuse the labels given to the analysis indices.
  assessment_rows <- rows_for("assessment")
  names(assessment_rows) <- names(analysis_rows)

  list(index = analysis_rows, indexOut = assessment_rows)
}
#' @rdname rsample2caret
#' @param ctrl An object produced by `trainControl` that has
#' had the `index` and `indexOut` elements populated by
#' integers. One method of getting this is to extract the
#' `control` objects from an object produced by `train`.
#' @param data The data that was originally used to produce the
#' `ctrl` object.
#' @importFrom purrr map map2
#' @importFrom tibble tibble
#' @importFrom dplyr bind_cols
#' @export
caret2rsample <- function(ctrl, data = NULL) {
  # Validate the inputs up front, one explicit message per problem.
  if (is.null(data)) {
    stop("Must supply original data", call. = FALSE)
  }
  if (!any(names(ctrl) == "index")) {
    stop("`ctrl` should have an element `index`", call. = FALSE)
  }
  if (!any(names(ctrl) == "indexOut")) {
    stop("`ctrl` should have an element `indexOut`", call. = FALSE)
  }
  if (is.null(ctrl$index)) {
    stop("`ctrl$index` should be populated with integers", call. = FALSE)
  }
  if (is.null(ctrl$indexOut)) {
    stop("`ctrl$indexOut` should be populated with integers", call. = FALSE)
  }

  # Pair each resample's `index` rows with its `indexOut` rows.
  row_pairs <- purrr::map2(ctrl$index, ctrl$indexOut, extract_int)

  # The resample names feed the id column(s); the splits are stored unnamed.
  resample_names <- names(row_pairs)
  splits <- unname(row_pairs)

  # Attach the original data to every split, then tag each with its class.
  splits <- purrr::map(splits, add_data, y = data)
  splits <- purrr::map(
    splits,
    add_rsplit_class,
    cl = map_rsplit_method(ctrl$method)
  )
  split_col <- tibble::tibble(splits = splits)

  if (ctrl$method %in% c("repeatedcv", "adaptive_cv")) {
    # Names are split on a literal "."; the second piece becomes `id` and
    # the first piece becomes `id2`.
    name_parts <- strsplit(resample_names, split = ".", fixed = TRUE)
    id_cols <- tibble::tibble(
      id = vapply(name_parts, `[`, character(1), 2),
      id2 = vapply(name_parts, `[`, character(1), 1)
    )
  } else {
    id_cols <- tibble::tibble(id = resample_names)
  }

  out <- dplyr::bind_cols(split_col, id_cols)

  # Copy the method-specific settings onto the result as attributes.
  rset_attrs <- map_attr(ctrl)
  for (attr_name in names(rset_attrs)) {
    attr(out, attr_name) <- rset_attrs[[attr_name]]
  }

  add_rset_class(out, map_rset_method(ctrl$method))
}
# Bundle the two index sets of one resample into a named list:
# `x` is stored as `in_id` and `y` as `out_id`.
extract_int <- function(x, y) {
  list(
    in_id = x,
    out_id = y
  )
}
# Put `y` in front of the elements of `x` as a leading element named `data`.
add_data <- function(x, y) {
  data_part <- list(data = y)
  append(data_part, x)
}
# Replace the class of `x` with "rsplit" followed by the classes in `cl`.
add_rsplit_class <- function(x, cl) {
  split_classes <- c("rsplit", cl)
  class(x) <- split_classes
  x
}
# Replace the class of `x` with `cl` followed by the rset and tibble classes.
add_rset_class <- function(x, cl) {
  rset_classes <- c(cl, "rset", "tbl_df", "tbl", "data.frame")
  class(x) <- rset_classes
  x
}
# Translate a caret resampling method name into the class name of the
# corresponding split object. Unrecognised methods raise an error.
map_rsplit_method <- function(method) {
  split_class <- switch(
    method,
    cv = ,
    repeatedcv = ,
    adaptive_cv = "vfold_split",
    boot = ,
    boot_all = ,
    boot632 = ,
    optimism_boot = ,
    adaptive_boot = "boot_split",
    LOOCV = "loo_split",
    LGOCV = ,
    adaptive_LGOCV = "mc_split",
    timeSlice = "rof_split",
    "error"
  )

  # "error" is the fall-through value for a method with no mapping.
  if (split_class != "error") {
    return(split_class)
  }

  stop("Resampling method `",
       method,
       "` cannot be converted into an `rsplit` object",
       call. = FALSE)
}
# Translate a caret resampling method name into the class name of the
# corresponding resampling-set object. Unrecognised methods raise an error.
map_rset_method <- function(method) {
  rset_class <- switch(
    method,
    cv = ,
    repeatedcv = ,
    adaptive_cv = "vfold_cv",
    boot = ,
    boot_all = ,
    boot632 = ,
    optimism_boot = ,
    adaptive_boot = "bootstraps",
    LOOCV = "loo_cv",
    LGOCV = ,
    adaptive_LGOCV = "mc_cv",
    timeSlice = "rolling_origin",
    "error"
  )

  # "error" is the fall-through value for a method with no mapping.
  if (rset_class != "error") {
    return(rset_class)
  }

  stop("Resampling method `",
       method,
       "` cannot be converted into an `rset` object",
       call. = FALSE)
}
# Build the list of attributes to attach to a converted resampling set,
# based on the resampling method recorded in `object$method`.
#
# `object` is a list-like control object; the fields read depend on the
# method: `number`, `repeats`, `p`, `initialWindow`, `horizon`,
# `fixedWindow` and `skip`.
#
# Returns a named list (empty for "LOOCV"). Stops with an error when the
# method matches none of the handled families.
map_attr <- function(object) {
  method <- object$method

  if (grepl("cv$", method)) {
    # When `repeats` is NA or absent, record a single repeat.
    repeats <- object$repeats
    if (length(repeats) == 0L || is.na(repeats[[1]])) {
      repeats <- 1
    }
    # NOTE(review): `strata` is always recorded as TRUE for these methods;
    # nothing in `object` is consulted for it -- confirm this is intended.
    out <- list(v = object$number,
                repeats = repeats,
                strata = TRUE)
  } else if (grepl("boot", method)) {
    out <- list(times = object$number,
                apparent = FALSE,
                strata = FALSE)
  } else if (grepl("LGOCV$", method)) {
    out <- list(times = object$number,
                prop = object$p,
                strata = FALSE)
  } else if (method == "LOOCV") {
    out <- list()
  } else if (method == "timeSlice") {
    out <- list(
      initial = object$initialWindow,
      assess = object$horizon,
      # A fixed window means the analysis set does not accumulate.
      cumulative = !object$fixedWindow,
      skip = object$skip
    )
  } else {
    # stop() pastes its arguments with no separator, so the spacing and
    # quoting must be part of the literals.
    stop("Method `", method, "` cannot be converted", call. = FALSE)
  }
  out
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.