R/misc.R

Defines functions get_from_info recipes_names_outcomes recipes_names_predictors recipes_remove_cols remove_original_cols vec_paste0 conditionMessage.recipes_error recipes_error_context stop_recipes_step stop_recipes check_new_data uses_dim_red dimred_data eval_dimred_call changelog print_col_names cast is_tune get_keep_original_cols check_training_set check_nominal_type rand_id check_name simple_terms to_character sel2char is_trained check_type is_qual is_skipable detect_step fully_trained prepare printer ellipse_check merge_term_info train_info flatten_na n_complete_rows strings2factors has_lvls get_levels fun_calls dummy_extract_names dummy_names names0 get_rhs_vars get_lhs_vars

Documented in check_name check_new_data check_type detect_step dummy_extract_names dummy_names ellipse_check fully_trained get_keep_original_cols is_trained names0 prepare printer rand_id recipes_names_outcomes recipes_names_predictors recipes_remove_cols remove_original_cols sel2char

# Extract the variable names that appear on the left-hand side of `formula`.
#
# The left-hand side is re-packaged as the right-hand side of a one-sided
# formula so that `get_rhs_vars()` can do the name extraction. This lets
# multiple outcomes be written as additions (no `cbind()` needed).
get_lhs_vars <- function(formula, data) {
  is_fml <- rlang::is_formula(formula)
  if (!is_fml) {
    formula <- as.formula(formula)
  }
  outcome_side <- f_lhs(formula)
  lhs_as_rhs <- rlang::new_formula(lhs = NULL, rhs = outcome_side)
  get_rhs_vars(lhs_as_rhs, data)
}

# Extract the variable names used on the right-hand side of `formula`.
#
# `formula` is coerced with `as.formula()` when it is not already a formula.
# With `no_lhs = TRUE` the left-hand side is discarded before names are
# collected. A `.` on the right-hand side is expanded to `colnames(data)`, and
# any name that also appears on the left-hand side is dropped from the result.
#
# In-line functions such as `poly(x)` or `log(x)` are not handled specially
# here: only symbol names are collected (`functions = FALSE`). When called
# from `form2args`, `inline_check` stops earlier if in-line functions are used.
get_rhs_vars <- function(formula, data, no_lhs = FALSE) {
  if (!rlang::is_formula(formula)) {
    formula <- as.formula(formula)
  }
  if (no_lhs) {
    formula <- rlang::new_formula(lhs = NULL, rhs = f_rhs(formula))
  }

  lhs_names <- all.names(
    rlang::f_lhs(formula),
    functions = FALSE,
    unique = TRUE
  )
  rhs_names <- all.names(
    rlang::f_rhs(formula),
    functions = FALSE,
    unique = TRUE
  )

  # Expand `.` to every column of `data`, keeping explicitly named terms first
  if ("." %in% rhs_names) {
    explicit <- rhs_names[rhs_names != "."]
    rhs_names <- unique(c(explicit, colnames(data)))
  }

  # Outcomes must not also be reported as predictors
  if (length(rhs_names) > 0 && length(lhs_names) > 0) {
    rhs_names <- setdiff(rhs_names, lhs_names)
  }

  rhs_names
}

#' Naming Tools
#'
#' `names0` creates a series of `num` names with a common prefix.
#'  The names are numbered with leading zeros (e.g.
#'  `prefix01`-`prefix10` instead of `prefix1`-`prefix10`).
#'  `dummy_names` can be used for renaming unordered and ordered
#'  dummy variables (in [step_dummy()]).
#'
#' @param num A single integer for how many elements are created.
#' @param prefix A character string that will start each name.
#' @param var A single string for the original factor name.
#' @param lvl A character vector of the factor levels (in order).
#'  When used with [step_dummy()], `lvl` would be the suffixes
#'  that result _after_ `model.matrix` is called (see the
#'  example below).
#' @param ordinal A logical; was the original factor ordered?
#' @param sep A single character value for the separator between the names and
#'  levels.
#'
#' @details When using `dummy_names()`, factor levels that are not valid
#'  variable names (e.g. "some text  with spaces") will be changed to valid
#'  names by [base::make.names()]; see example below. This function will also
#'  change the names of ordinal dummy variables. Instead of values such as
#'  "`.L`", "`.Q`", or "`^4`", ordinal dummy variables are given simple integer
#'  suffixes such as "`_1`", "`_2`", etc.
#'
#' @return `names0` returns a character vector of length `num` and
#'  `dummy_names` generates a character vector the same length as
#'  `lvl`.
#'
#' @seealso [developer_functions]
#'
#' @examples
#' names0(9, "a")
#' names0(10, "a")
#'
#' example <- data.frame(
#'   x = ordered(letters[1:5]),
#'   y = factor(LETTERS[1:5]),
#'   z = factor(paste(LETTERS[1:5], 1:5))
#' )
#'
#' dummy_names("y", levels(example$y)[-1])
#' dummy_names("z", levels(example$z)[-1])
#'
#' after_mm <- colnames(model.matrix(~x, data = example))[-1]
#' after_mm
#' levels(example$x)
#'
#' dummy_names("x", substring(after_mm, 2), ordinal = TRUE)
#' @export
names0 <- function(num, prefix = "x") {
  if (num < 1) {
    cli::cli_abort("{.arg num} should be > 0.")
  }
  # `format()` pads the indices with leading spaces to a common width;
  # turning those spaces into zeros gives names that sort correctly.
  padded_index <- gsub(" ", "0", format(seq_len(num)))
  paste0(prefix, padded_index)
}

#' @export
#' @rdname names0
dummy_names <- function(var, lvl, ordinal = FALSE, sep = "_") {
  # Recycle up front: `paste()` mishandles recycling of zero-length input
  recycled <- vctrs::vec_recycle_common(var, lvl)
  var <- recycled[[1]]
  lvl <- recycled[[2]]

  if (ordinal) {
    # Ordinal contrasts get simple integer suffixes; the levels are assumed
    # to already be in order.
    return(paste0(var, names0(length(lvl), sep)))
  }

  # Levels are passed through `make.names()` so the results are valid names
  paste(var, make.names(lvl), sep = sep)
}

#' @export
#' @rdname names0
dummy_extract_names <- function(var, lvl, ordinal = FALSE, sep = "_") {
  # Recycle up front: `paste()` mishandles recycling of zero-length input
  recycled <- vctrs::vec_recycle_common(var, lvl)
  var <- recycled[[1]]
  lvl <- recycled[[2]]

  nms <- if (ordinal) {
    # assuming the levels are in order
    paste0(var, names0(length(lvl), sep))
  } else {
    paste(var, make.names(lvl), sep = sep)
  }

  # `make.names()` can map distinct levels to the same name. Append the
  # running occurrence count to every repeat, and keep going until the
  # suffixed names no longer collide with anything else.
  while (anyDuplicated(nms) > 0) {
    occurrence <- vapply(
      seq_along(nms),
      function(pos) sum(nms[pos] == nms[seq_len(pos)]),
      integer(1)
    )
    repeated <- occurrence > 1
    nms[repeated] <- paste(nms[repeated], occurrence[repeated], sep = sep)
  }
  nms
}

# Names used in `f` (symbols and function names, as reported by `all.names()`)
# that are not column names of `data`.
fun_calls <- function(f, data) {
  used_names <- all.names(f)
  data_cols <- colnames(data)
  setdiff(used_names, data_cols)
}

# Summarise the levels of a vector.
#
# Factors report their levels and whether they are ordered; character vectors
# report their sorted unique values. Anything else yields `NA` placeholders
# (and no `factor` element).
get_levels <- function(x) {
  if (is.factor(x)) {
    return(
      list(values = levels(x), ordered = is.ordered(x), factor = TRUE)
    )
  }
  if (is.character(x)) {
    return(
      list(values = sort(unique(x)), ordered = FALSE, factor = FALSE)
    )
  }
  list(values = NA, ordered = NA)
}

# For each entry of `info`, is there at least one non-missing value stored in
# its `values` element?
has_lvls <- function(info) {
  all_missing <- vapply(
    info,
    function(entry) all(is.na(entry$values)),
    logical(1)
  )
  !all_missing
}

# Convert columns of `x` to factors using the level information in `info`.
#
# Only entries of `info` that carry levels and whose names match a column of
# `x` are used. `exclude = NULL` keeps `NA` as a level when it is listed in
# the stored values.
strings2factors <- function(x, info) {
  with_levels <- has_lvls(info)
  if (!any(with_levels)) {
    return(x)
  }

  info <- info[with_levels]
  info <- info[names(info) %in% names(x)]

  for (idx in seq_along(info)) {
    col_name <- names(info)[idx]
    col_info <- info[[idx]]
    x[, col_name] <- factor(
      as.character(x[[col_name]]),
      levels = col_info$values,
      ordered = col_info$ordered,
      exclude = NULL
    )
  }
  x
}


# ------------------------------------------------------------------------------

# `complete.cases` fails on list columns. This version counts an element of a
# list column as missing only if _all_ of its values are missing. For example,
# if a list column element is a data frame with one missing value, that
# element will be counted as complete.
n_complete_rows <- function(x) {
  list_col_idx <- which(purrr::map_lgl(x, is.list))

  # Collapse every list column to a logical vector (NA = entirely missing
  # element, FALSE = otherwise) so that `complete.cases()` can handle it.
  for (idx in list_col_idx) {
    x[[idx]] <- purrr::map_lgl(x[[idx]], flatten_na)
  }

  sum(complete.cases(x))
}

# Collapse `x` to a single flag: `NA` when every value of `x` is missing,
# `FALSE` otherwise.
flatten_na <- function(x) {
  if (!all(is.na(x))) {
    return(FALSE)
  }
  NA
}

## short summary of training set
train_info <- function(x) {
  # One-row summary: total row count and number of rows without missing data
  n_total <- nrow(x)
  n_complete <- n_complete_rows(x)
  data.frame(nrows = n_total, ncomplete = n_complete)
}

# ------------------------------------------------------------------------------



## `merge_term_info` takes the information on the current variable
## list and the information on the new set of variables (after each step)
## and merges them. Special attention is paid to cases where the
## _type_ of data is changed for a common column in the data.


merge_term_info <- function(.new, .old) {
  # Look for conflicts where the new variable type is different from
  # the original value.
  #
  # `.new` and `.old` are joined on `variable`; the `type` column of `.new` is
  # carried along as `new_type`, so after the join `type` refers to the value
  # coming from `.old` (NA for variables that only exist in `.new`).
  .new %>%
    dplyr::rename(new_type = type) %>%
    dplyr::left_join(.old, by = "variable", multiple = "all") %>%
    dplyr::mutate(
      # NOTE(review): the third argument is the string literal "type", not the
      # column `type`, so after this line `type` only ever holds "other" or
      # "type". Combined with the next line, the result is `new_type` wherever
      # the two differ. Confirm whether the column was intended here.
      type = ifelse(is.na(type), "other", "type"),
      type = ifelse(type != new_type, new_type, type)
    ) %>%
    # `new_type` was only needed for the comparison above
    dplyr::select(-new_type)
}


#' Check for Empty Ellipses
#'
#' `ellipse_check()` is deprecated. Instead, empty selections should be
#' supported by all steps.
#'
#' @param ... Arguments passed in from a call to `step`.
#' 
#' @return `ellipse_check()`: If not empty, a list of quosures. If empty, an 
#'   error is thrown.
#' 
#' @keywords internal
#' @rdname recipes-internal
#' @export
ellipse_check <- function(...) {
  terms <- quos(...)

  # Happy path first: hand the captured quosures straight back
  if (!is_empty(terms)) {
    return(terms)
  }

  cli::cli_abort(
    c(
      "!" = "Please supply at least one variable specification.",
      "i" = "See {.help [?selections](recipes::selections)} \\
              for more information."
    )
  )
}

#' Printing Workhorse Function
#'
#' `printer()` is used for printing steps.
#'
#' @param tr_obj A character vector of names that have been
#'  resolved during preparing the recipe (e.g. the `columns` object
#'  of [step_log()]).
#' @param untr_obj An object of selectors prior to prepping the
#'  recipe (e.g. `terms` in most steps).
#' @param trained A logical for whether the step has been trained.
#' @param width An integer denoting where the output should be wrapped.
#' 
#' @return `printer()`: `NULL`, invisibly.
#' 
#' @keywords internal
#' @rdname recipes-internal
#' @export
printer <- function(tr_obj = NULL,
                    untr_obj = NULL,
                    trained = FALSE,
                    width = max(20, options()$width - 30)) {
  # Trained steps print the resolved column names; untrained steps print the
  # selectors as the user wrote them.
  txt <- if (trained) {
    format_ch_vec(tr_obj, width = width)
  } else {
    format_selectors(untr_obj, width = width)
  }

  if (length(txt) == 0L) {
    txt <- "<none>"
  }

  cat(txt)
  cat(if (trained) " [trained]\n" else "\n")

  invisible(NULL)
}


#' @keywords internal
#' @rdname recipes-internal
#' @export
prepare <- function(x, ...) {
  # Retired entry point: always errors and points the user to `prep()`
  msg <-
    "As of version 0.0.1.9006 please use {.fn prep} instead of {.fn prepare}."
  cli::cli_abort(msg)
}


#' Check to see if a recipe is trained/prepared
#'
#' @param x A recipe
#' @return A logical which is true if all of the recipe steps have been run
#'  through `prep`. If no steps have been added to the recipe, `TRUE` is
#'  returned only if the recipe has been prepped.
#' @export
#'
#' @seealso [developer_functions]
#'
#' @examples
#' rec <- recipe(Species ~ ., data = iris) %>%
#'   step_center(all_numeric())
#'
#' rec %>% fully_trained()
#'
#'
#' rec %>%
#'   prep(training = iris) %>%
#'   fully_trained()
fully_trained <- function(x) {
  # No steps at all: the recipe counts as trained only when it carries a
  # `last_term_info` element.
  if (is.null(x$steps)) {
    has_term_info <- any(names(x) == "last_term_info")
    return(if (has_term_info) TRUE else FALSE)
  }

  # Otherwise every step must be flagged as trained
  step_flags <- purrr::map_lgl(x$steps, function(step) isTRUE(step$trained))
  all(step_flags)
}

#' Detect if a particular step or check is used in a recipe
#'
#' @param recipe A recipe to check.
#' @param name Character name of a step or check, omitting the prefix. That is,
#'   to check if `step_intercept` is present, use `name = "intercept"`.
#' @return Logical indicating if recipes contains given step.
#' @export
#'
#' @seealso [developer_functions]
#'
#' @examples
#' rec <- recipe(Species ~ ., data = iris) %>%
#'   step_intercept()
#'
#' detect_step(rec, "intercept")
detect_step <- function(recipe, name) {
  # `tidy()` on the recipe reports each step/check in its `type` column
  step_types <- tidy(recipe)$type
  name %in% step_types
}

# to be used in a recipe
# Return the `skip` element of `x`, or `FALSE` when `x` has no such element.
is_skipable <- function(x) {
  if (any(names(x) == "skip")) x$skip else FALSE
}

# Is `x` qualitative, i.e. a factor or a character vector?
is_qual <- function(x) {
  is.factor(x) || is.character(x)
}

#' Quantitatively check on variables
#'
#' This internal function is to be used in the prep function to ensure that
#'   the type of the variables matches the expectation. Throws an error if
#'   check fails.
#' @param dat A data frame or tibble of the training data.
#' @param quant A logical indicating whether the data is expected to be numeric
#'   (TRUE) or a factor/character (FALSE). Is ignored if `types` is specified.
#' @param types Character vector of allowed types. Following the same types as
#'   [has_role()]. See details for more.
#'
#' @details
#' Using `types` is a more fine-tuned way to use this function compared to using
#' `quant`. `types` should specify all allowed types as designated by
#' [.get_data_types]. Suppose you want to allow doubles, integers, characters,
#' factors and ordered factors, then you should specify
#' `types = c("double", "integer", "string", "factor", "ordered")` to get a
#' clear error message.
#'
#' @seealso [developer_functions]
#'
#' @export
#' @keywords internal
check_type <- function(dat, quant = TRUE, types = NULL, call = caller_env()) {
  # Step 1: flag each column of `dat` as acceptable or not. Without `types`,
  # `quant` chooses between a numeric check and a factor/character check, and
  # `types` is overwritten with the label used later in the error message.
  if (is.null(types)) {
    if (quant) {
      all_good <- vapply(dat, is.numeric, logical(1))
      types <- "numeric"
    } else {
      all_good <- vapply(dat, is_qual, logical(1))
      types <- "factor or character"
    }
  } else {
    # A column passes when any of its type labels is among the allowed `types`.
    # `get_types()` is defined elsewhere; here its result is used as a table
    # with `variable` and `type` columns.
    all_good <- purrr::map_lgl(get_types(dat)$type, ~ any(.x %in% types))
  }

  # Step 2: on failure, build one bullet per offending type and abort.
  if (!all(all_good)) {
    info <- get_types(dat)[!all_good, ]
    # Only the first type label of each offending column is reported
    classes <- map_chr(info$type, function(x) x[1])
    # Group the offending variable names by that label (`key` / `val`)
    counts <- vctrs::vec_split(info$variable, classes)
    counts$count <- lengths(counts$val)
    # Width left for the variable list on each bullet: console width minus
    # a fixed allowance of 18, the type label, and one extra character each
    # when count > 1 and when count > 2.
    # NOTE(review): the allowance presumably covers the fixed bullet text
    # (plural "s", separators) — confirm before changing the message.
    counts$text_len <- cli::console_width() - 18 - (counts$count > 1) -
      nchar(counts$key) - (counts$count > 2)

    # cli::ansi_collapse() doesn't work for length(x) == 1
    # https://github.com/r-lib/cli/issues/590
    variable_collapse <- function(x, width) {
      # Wrap each name in cli `.var` markup before collapsing/truncating
      x <- paste0("{.var ", x, "}")
      if (length(x) == 1) {
        res <- cli::ansi_strtrim(x, width = width)
      } else if (length(x) == 2) {
        res <- cli::ansi_collapse(
          x, last = " and ", width = width, style = "head"
        )
      } else {
        res <- cli::ansi_collapse(
          x, width = width, style = "head"
        )
      }
      res
    }

    # One message line per offending type, evaluated with the columns of
    # `counts` (`count`, `key`, `val`, `text_len`) in scope
    problems <- glue::glue_data(
      counts,
      "{count} {key} variable{ifelse(count == 1, '', 's')} \\
      found: {purrr::map2_chr(val, text_len, variable_collapse)}"
    )
    # Name every line "*" so cli renders them as bullets
    names(problems) <- rep("*", length(problems))

    message <- "All columns selected for the step should be {.or {types}}."

    cli::cli_abort(
      c("x" = message, problems),
      call = call
    )
  }

  # All columns passed: return the per-column flags invisibly
  invisible(all_good)
}

## Support functions

#' Check to see if a step or check has been trained
#'
#' `is_trained()` is a helper function that returns a single logical to
#' indicate whether a step or check is trained or not.
#' 
#' @param x a step object.
#' @return `is_trained()`: A single logical.
#' 
#' @seealso [developer_functions]
#' @keywords internal
#' 
#' @rdname recipes-internal
#' @export
is_trained <- function(x) {
  # Returns the `trained` element of `x` as stored: no coercion or validation
  # is applied, so the result is NULL when the element is absent.
  x$trained
}


#' Convert Selectors to Character
#'
#' `sel2char()` takes a list of selectors (e.g. `terms` in most steps) and 
#' returns a character vector version for printing.
#' 
#' @param x A list of selectors
#' @return `sel2char()`: A character vector.
#' 
#' @seealso [developer_functions]
#' 
#' @keywords internal
#' @rdname recipes-internal
#' @export
sel2char <- function(x) {
  # Convert each selector to text, then drop the names from the result
  labels <- map_chr(x, to_character)
  unname(labels)
}

# Text version of a single selector: quosures are deparsed with
# `rlang::quo_text()`, anything else goes through `as_character()`.
to_character <- function(x) {
  if (rlang::is_quosure(x)) {
    return(rlang::quo_text(x))
  }
  as_character(x)
}


# One-column tibble (`terms`) describing a step: the resolved columns once the
# step is trained, otherwise the selectors rendered as text.
simple_terms <- function(x, ...) {
  term_values <- if (is_trained(x)) {
    unname(x$columns)
  } else {
    sel2char(x$terms)
  }
  tibble(terms = term_values)
}

#' check that newly created variable names don't overlap
#'
#' `check_name` is to be used in the bake function to ensure that
#'   newly created variable names don't overlap with existing names.
#'   Throws an error if check fails.
#' @param res A data frame or tibble of the newly created variables.
#' @param new_data A data frame or tibble passed to the bake function.
#' @param object A trained object passed to the bake function.
#' @param newname A string of variable names if prefix isn't specified
#'   in the trained object.
#' @param names A logical determining if the names should be set using
#' the names function (TRUE) or colnames function (FALSE).
#' @param call The execution environment of a currently running function, e.g.
#'   `caller_env()`. The function will be mentioned in error messages as the
#'   source of the error. See the call argument of [rlang::abort()] for more
#'   information.
#'
#' @seealso [developer_functions]
#'
#' @export
#' @keywords internal
check_name <- function(res, new_data, object, newname = NULL, names = FALSE,
                       call = caller_env()) {
  # Fall back to zero-padded names built from the step's prefix
  if (is.null(newname)) {
    newname <- names0(ncol(res), object$prefix)
  }

  # Any existing column that shares a name with a new one is a collision
  existing <- colnames(new_data)
  nms <- existing[existing %in% newname]
  if (length(nms) > 0) {
    cli::cli_abort(
      c("Name collision occurred. The following variable names already exist:",
        "*" = "{.var {nms}}"),
      call = call
    )
  }

  # `names` selects which setter is used to apply the new names
  if (names) {
    names(res) <- newname
  } else {
    colnames(res) <- newname
  }
  res
}

#' Make a random identification field for steps
#'
#'
#' @export
#' @param prefix A single character string
#' @param len An integer for the number of random characters
#' @return A character string with the prefix and random letters separated by
#'  an underscore.
#'
#' @seealso [developer_functions]
#' @keywords internal
rand_id <- function(prefix = "step", len = 5) {
  # Pool of lower-case letters, upper-case letters and the digits 0-9
  pool <- c(letters, LETTERS, paste(0:9))
  drawn <- sample(pool, len, replace = TRUE)
  suffix <- paste0(drawn, collapse = "")
  paste(prefix, suffix, sep = "_")
}


# Warn when columns that were factors when the recipe was prepped are no
# longer factors in `x`. Always returns `NULL` invisibly.
check_nominal_type <- function(x, lvl) {
  present_cols <- names(x)

  # Some columns recorded in `lvl` (like outcome data) may not be in the
  # current data, so those are removed up-front.
  lvl <- lvl[names(lvl) %in% present_cols]

  # Columns expected to be factors, based on the data _before_ the recipe
  # was prepped.
  expected_flag <- purrr::map_lgl(lvl, function(info) isTRUE(info$factor))
  fac_ref_cols <- names(lvl)[expected_flag]

  if (length(fac_ref_cols) == 0) {
    return(invisible(NULL))
  }

  # Columns that are actually factors in the data at hand
  actual_flag <- purrr::map_lgl(x, is.factor)
  fac_act_cols <- names(actual_flag)[actual_flag]

  # There may be some original factors that are not factors any more
  was_factor <- fac_ref_cols[!(fac_ref_cols %in% fac_act_cols)]

  if (length(was_factor) > 0) {
    cli::cli_warn(
      c(
        "!" = "There {?w/was/were} {length(was_factor)} column{?s} that \\
                {?was a factor/were factors} when the recipe was prepped: \\
                ",
        "*" = "{.and {.var {was_factor}}}",
        "i" = "This may cause errors when processing new data."
      )
    )
  }

  invisible(NULL)
}

# Work out which data set `prep()` should train on.
#
# When `x` is supplied it is coerced to a tibble, validated against the recipe
# and reduced to the recipe's variables; otherwise the recipe's stored
# template is used. If some steps are already trained and `fresh` is not set,
# the stored template always wins.
check_training_set <- function(x, rec, fresh, call = rlang::caller_env()) {
  # A variable with multiple roles is listed more than once in `var_info`
  vars <- unique(rec$var_info$variable)
  training_null <- is.null(x)

  if (!training_null) {
    if (!is_tibble(x)) {
      x <- as_tibble(x)
    }
    recipes_ptype_validate(rec, new_data = x, stage = "prep", call = call)
    x <- x[, vars]
  } else {
    if (fresh) {
      cli::cli_abort(
        "A training set must be supplied to the {.arg training} argument \\
        when {.code fresh = TRUE}.",
        call = call
      )
    }
    x <- rec$template
  }

  has_trained_step <- any(vapply(rec$steps, is_trained, logical(1)))
  if (has_trained_step & !fresh) {
    # Continuing a partially trained recipe requires the retained data
    if (!rec$retained) {
      cli::cli_abort(
        "To prep new steps after prepping the original recipe, \\
        {.code retain = TRUE} must be set each time that the recipe is \\
        trained.",
        call = call
      )
    }
    if (!training_null) {
      cli::cli_warn(
        c(
          "!" = "The previous data will be used by {.fn prep}.",
          "i" = "The data passed using {.arg training} will be ignored."
        ),
        call = call
      )
    }
    x <- rec$template
  }

  x
}

#' Get the `keep_original_cols` value of a recipe step
#'
#' @export
#' @param object A recipe step
#' @return A logical to keep the original variables in the output
#' @keywords internal
#' @seealso [developer_functions]
get_keep_original_cols <- function(object) {
  if (!is.null(object$keep_original_cols)) {
    return(object$keep_original_cols)
  }

  # Allow prepping of old recipes created before `keep_original_cols` was
  # added: warn and fall back to FALSE.
  step_class <- class(object)[1]
  cli::cli_warn(
    c(
      "{.arg keep_original_cols} was added to {.fn {step_class}} after this \\
         recipe was created.",
      "i" = "Regenerate your recipe to avoid this warning."
    )
  )
  FALSE
}

# ------------------------------------------------------------------------------

# from tune package
# Is `x` a call to `tune()`?
#
# Returns `TRUE` only when `x` is a call whose function name, as reported by
# `rlang::call_name()`, is "tune"; every other input gives `FALSE`.
is_tune <- function(x) {
  if (!is.call(x)) {
    return(FALSE)
  }
  # `call_name()` returns NULL for calls without a simple name (e.g. a call to
  # an anonymous function). `identical()` turns that into FALSE instead of
  # letting `if (NULL == "tune")` fail with a zero-length condition.
  identical(rlang::call_name(x), "tune")
}


# ------------------------------------------------------------------------------
# For all imputation functions that substitute elements into an existing vector:
# vctrs's cast functions would be better but we'll deal with the known cases
# to avoid a dependency.

# Coerce `x` so that it matches the type of `ref`.
#
# Factor references rebuild `x` as a factor with the same levels and ordering;
# integer references round `x` and convert it to integer; anything else
# returns `x` unchanged.
cast <- function(x, ref) {
  if (is.factor(ref)) {
    return(factor(x, levels = levels(ref), ordered = is.ordered(ref)))
  }
  if (is.integer(ref)) {
    return(as.integer(round(x, 0)))
  }
  x
}

## -----------------------------------------------------------------------------

# Print the names in `x` on one line, prefixed by `prefix` and the number of
# names, truncating with ", ..." when the line would exceed the console width.
# Returns `TRUE` invisibly.
print_col_names <- function(x, prefix = "") {
  if (length(x) == 0) {
    return(invisible(TRUE))
  }

  wdth <- getOption("width")
  if (length(prefix) > 0) {
    prefix <- paste0(prefix, " (", length(x), "): ")
  }

  # Running width of the line: the prefix first, then each name plus its
  # ", " separator. An item fits while 5 spare characters remain.
  item_widths <- c(nchar(prefix), purrr::map_int(x, nchar) + 2)
  fits <- cumsum(item_widths) + 5 < wdth

  # The first flag belongs to the prefix, not to a name
  shown <- x[fits[-1]]
  line <- paste0(prefix, paste0(shown, collapse = ", "))
  if (!all(fits)) {
    line <- paste0(line, ", ...")
  }

  cat(line, "\n", sep = "")
  invisible(TRUE)
}

# Print which columns a step added or removed, when `show` is TRUE.
#
# `before` and `after` are the column names on either side of the step `x`;
# the header line shows the step's first class and its `id`.
changelog <- function(show, before, after, x) {
  if (!show) {
    return(invisible(TRUE))
  }

  removed <- setdiff(before, after)
  added <- setdiff(after, before)

  cat(class(x)[1], " (", x$id, "): ", sep = "")

  unchanged <- length(added) == 0 && length(removed) == 0
  if (unchanged) {
    cat("same number of columns\n\n")
  } else {
    cat("\n")
    print_col_names(added, " new")
    print_col_names(removed, " removed")
    cat("\n")
  }
}

# ------------------------------------------------------------------------------

# Build a call to the function named `fn` in the dimRed namespace, passing
# `...` as its arguments, and evaluate it.
eval_dimred_call <- function(fn, ...) {
  rlang::eval_tidy(
    rlang::call2(fn, .ns = "dimRed", ...)
  )
}

# Evaluate `dimRed::dimRedData(as.matrix(dat))`. The call is assembled with
# rlang and evaluated here, where `dat` is in scope.
dimred_data <- function(dat) {
  as_matrix_expr <- rlang::expr(as.matrix(dat))
  ctor_call <- rlang::call2("dimRedData", .ns = "dimRed", as_matrix_expr)
  rlang::eval_tidy(ctor_call)
}

# Abort when `x` is a `dimRedResult` object, i.e. an estimate stored in the
# format used by recipes 0.1.16 or earlier. Otherwise returns `NULL` invisibly.
uses_dim_red <- function(x) {
  if (inherits(x, "dimRedResult")) {
    cli::cli_abort(
      "Recipes version >= 0.1.17 represents the estimates using a different \\
      format. Please recreate this recipe or use version 0.1.16 or less. See \\
      issue {.href [#823](https://github.com/tidymodels/recipes/issues/823)}."
    )
  }
  invisible(NULL)
}

# ------------------------------------------------------------------------------

#' Check for required column at bake-time
#'
#' When baking a step, create an informative error message when a column that
#' is used by the step is not present in `new_data`.
#'
#' @param req A character vector of required columns.
#' @param object A step object.
#' @param new_data A tibble of data being baked.
#' @return Invisible NULL. Side effects are the focus of the function.
#' @keywords internal
#'
#' @seealso [developer_functions]
#'
#' @export
check_new_data <- function(req, object, new_data) {
  # Nothing is required (`NULL` also has length zero)
  if (length(req) == 0L) {
    return(invisible(NULL))
  }

  # Required columns that are absent from the data being baked
  col_diff <- setdiff(req, names(new_data))
  if (length(col_diff) == 0) {
    return(invisible(NULL))
  }

  # Report the error as coming from the step itself: its first class is used
  # as the call, and its `id` is quoted in the message.
  step_cls <- class(object)[1]
  step_id <- object$id
  # `\\` at the end of the line is the line-continuation marker used by the
  # other cli messages in this file; it keeps the source line break and
  # indentation out of the message text.
  cli::cli_abort(
    "The following required {cli::qty(col_diff)} column{?s} {?is/are} \\
    missing from `new_data` in step '{step_id}': {col_diff}.",
    class = "new_data_missing_column",
    call = rlang::call2(step_cls)
  )
}

# Signal an error with class "recipes_error", optionally preceded by more
# specific classes given in `class`. `call` and `parent` are forwarded to
# `rlang::abort()`.
stop_recipes <- function(class = NULL,
                         call = NULL,
                         parent = NULL) {
  error_classes <- c(class, "recipes_error")
  rlang::abort(class = error_classes, call = call, parent = parent)
}

# Signal a recipes error tagged with the extra class "recipes_error_step".
stop_recipes_step <- function(call = NULL,
                              parent = NULL) {
  step_error_class <- "recipes_error_step"
  stop_recipes(class = step_error_class, call = call, parent = parent)
}

# Evaluate `expr`; if it signals an error, re-signal it as a step error whose
# call is `step_name()` and whose parent is the original condition.
recipes_error_context <- function(expr, step_name) {
  rethrow_as_step_error <- function(cnd) {
    stop_recipes_step(call = call(step_name), parent = cnd)
  }
  withCallingHandlers(
    expr = force(expr),
    error = rethrow_as_step_error
  )
}

#' @method conditionMessage recipes_error
#' @export
conditionMessage.recipes_error <- function(c) {
  # Delegate message construction to rlang, passing `prefix = TRUE` through.
  # NOTE(review): presumably this includes the "Error in ..." style prefix in
  # the returned text — confirm against the `rlang::cnd_message()` docs.
  rlang::cnd_message(c, prefix = TRUE)
}

# `paste0()` with vctrs recycling: the inputs in `...` are recycled to a
# common size with `vctrs::vec_recycle_common()` before being pasted.
vec_paste0 <- function(..., collapse = NULL) {
  recycled <- vctrs::vec_recycle_common(...)
  rlang::inject(
    paste0(!!!recycled, collapse = collapse)
  )
}

#' Removes original columns if options apply
#'
#' This helper function should be used whenever the argument
#' `keep_original_cols` is used in a function.
#'
#' @param new_data A tibble.
#' @param object A step object.
#' @param col_names A character vector, denoting columns to remove.
#' @return new_data with `col_names` removed if
#'     `get_keep_original_cols(object) == FALSE` or `object$preserve == FALSE`.
#' @keywords internal
#'
#' @seealso [developer_functions]
#'
#' @export
remove_original_cols <- function(new_data, object, col_names) {
  keep_original_cols <- get_keep_original_cols(object)

  # Drop the originals when `preserve` is explicitly FALSE or when the step
  # is not set to keep them.
  drop_originals <- any(isFALSE(object$preserve), !keep_original_cols)
  if (!drop_originals) {
    return(new_data)
  }
  recipes_remove_cols(new_data, object, col_names)
}

#' Removes columns if options apply
#'
#' This helper function removes columns based on character vectors.
#'
#' @param new_data A tibble.
#' @param object A step object.
#' @param col_names A character vector, denoting columns to remove. Will
#'   overwrite `object$removals` if set.
#'
#' @return `new_data` with column names removed if specified by `col_names` or
#'   `object$removals`.
#' @keywords internal
#'
#' @seealso [developer_functions]
#'
#' @export
recipes_remove_cols <- function(new_data, object, col_names = character()) {
  # `col_names` takes precedence over the removals stored on the step
  removals <- if (length(col_names) > 0) col_names else object$removals
  if (length(removals) == 0) {
    return(new_data)
  }

  # drop = FALSE in case someone uses this on a data.frame
  keep <- !(colnames(new_data) %in% removals)
  new_data[, keep, drop = FALSE]
}

#' Role indicators
#'
#' This helper function is meant to be used in `prep()` methods to identify
#' predictors and outcomes by names.
#'
#' @param info data.frame with variable information of columns. 
#'
#' @return Character vector of column names.
#' @keywords internal
#'
#' @seealso [developer_functions]
#'
#' @name recipes-role-indicator
NULL

#' @rdname recipes-role-indicator
#' @export
recipes_names_predictors <- function(info) {
  # Variables whose role is recorded as "predictor"
  get_from_info(info, role = "predictor")
}

#' @rdname recipes-role-indicator
#' @export
recipes_names_outcomes <- function(info) {
  # Variables whose role is recorded as "outcome"
  get_from_info(info, role = "outcome")
}

# Names in `info$variable` whose `info$role` equals `role`; rows with a
# missing role never match.
get_from_info <- function(info, role) {
  role_known <- !is.na(info$role)
  role_matches <- info$role == role
  info$variable[role_matches & role_known]
}

Try the recipes package in your browser

Any scripts or data that you put into this service are public.

recipes documentation built on July 4, 2024, 9:06 a.m.