R/update.R

Defines functions update.recipe update.pipeline

Documented in update.pipeline update.recipe

#' Update a recipe with named arguments
#'
#' Replaces the values of step parameters that were set to `varying()` in a
#' recipe.
#'
#' @param object recipe object
#' @param ... named arguments used to update the recipe. Each name must be the
#'   step's `id` followed by a double underscore and the parameter name, e.g.
#'   `corr__threshold = 0.98`. Values are passed to the step unchanged, so they
#'   may be of any type or length.
#'
#' @return recipe with the supplied varying parameters updated. Varying
#'   parameters that are not supplied are left untouched, and supplied
#'   arguments that do not match a varying parameter are ignored.
#' @export
#' @importFrom purrr map2 map_int partial
#' @importFrom stringr str_split
#' @importFrom rlang enquos list2 sym
#' @importFrom dplyr filter select left_join
#' @importFrom stats update
#' @importFrom generics tidy
#' @importFrom parsnip varying_args
update.recipe <- function(object, ...) {
  args <- list2(...)

  if (length(args) == 0) {
    stop("Please pass at least one named argument.", call. = FALSE)
  }

  arg_names <- names(args)

  # list2() gives NULL names when nothing is named and "" for the unnamed
  # entries of a partially named call; neither can be routed to a step
  if (is.null(arg_names) || any(!nzchar(arg_names))) {
    stop("All recipe arguments must be named.", call. = FALSE)
  }

  # each name must split at its first "__" into a non-empty step id and a
  # non-empty parameter name
  parts <- str_split(arg_names, "__", n = 2)
  well_formed <- vapply(
    parts,
    function(p) length(p) == 2 && all(nzchar(p)),
    logical(1)
  )

  if (!all(well_formed)) {
    stop(
      "All recipe arguments must follow convention of {step_id}__{name}: ",
      paste(arg_names[!well_formed], collapse = ", "),
      call. = FALSE
    )
  }

  # parameters set to varying() in the recipe, together with the position
  # ("number") of their step within object$steps
  step_positions <- tidy(object) %>%
    select(!!sym("id"), !!sym("number"))

  varying_params <- varying_args(object) %>%
    left_join(step_positions, by = "id") %>%
    filter(!!sym("varying") == TRUE)

  if (nrow(varying_params) == 0) {
    stop("No arguments in recipes set to `varying()`", call. = FALSE)
  }

  # index of the supplied value for each varying parameter, or NA when the
  # caller did not supply one. Unmatched parameters are skipped rather than
  # being overwritten with NA.
  param_keys <- paste(varying_params$id, varying_params$name, sep = "__")
  arg_pos <- match(param_keys, arg_names)

  for (i in which(!is.na(arg_pos))) {
    step_num <- varying_params$number[[i]]

    # take the value with [[ ]] (not unlist()) so it keeps its own type and
    # length instead of being coerced alongside the other arguments
    new_value <- list(args[[arg_pos[[i]]]])
    names(new_value) <- varying_params$name[[i]]

    object$steps[[step_num]] <-
      do.call(update, c(list(object$steps[[step_num]]), new_value))
  }

  object
}


#' Update a pipeline object with new recipe and model_spec parameters
#'
#' Routes each named argument to either the pipeline's recipe or its model
#' specification, depending on whether the argument name contains a double
#' underscore. Designed to be used internally to the package.
#'
#' @param object pipeline object with `recipe` and `model_spec` elements
#' @param ... named arguments. Names containing `"__"` (e.g. `corr__threshold`)
#'   are forwarded to `update()` on the recipe; all other arguments are
#'   forwarded to `set_args()` on the model specification.
#'
#' @return pipeline object
#' @export
#' @importFrom rlang list2
#' @importFrom stringr str_detect
#' @importFrom stats update
#' @importFrom parsnip set_args
update.pipeline <- function(object, ...) {
  dots <- list2(...)

  # a "__" in the name marks a recipe argument of the form {step_id}__{param}
  is_recipe_arg <- str_detect(names(dots), "__")
  model_args <- dots[!is_recipe_arg]
  recipe_args <- dots[is_recipe_arg]

  if (length(model_args) != 0) {
    object$model_spec <- set_args(object$model_spec, !!!model_args)
  }

  if (length(recipe_args) != 0) {
    object$recipe <- update(object$recipe, !!!recipe_args)
  }

  object
}
stevenpawley/tidycrossval documentation built on Oct. 3, 2019, 3:32 p.m.