R/create_case_when.R

Defines functions create_case_when create_sql_case_when is_case_when formulas formulas.case_when variable.names.case_when print.case_when .create_case_when .translate_to_sql

Documented in create_case_when create_sql_case_when formulas formulas.case_when print.case_when variable.names.case_when

#' @include utils.R
#' @import rlang methods
#' @importFrom assertthat assert_that
#' @importFrom pryr modify_lang
#' @importFrom purrr map walk
#' @importFrom dplyr case_when
#' @importFrom crayon cyan magenta green
#' @importFrom utils capture.output
#' @importFrom stats variable.names
NULL

#' A case_when factory
#'
#' `create_case_when()` lets you create reusable [dplyr::case_when()]
#'  functions. It returns a function that can be used in place of
#'  [dplyr::case_when()]. The arguments of the returned function are
#'  given by the `vars` argument of `create_case_when()`.
#'
#' The returned function is of class `case_when` (more precisely, of classes
#' `dplyr_case_when` and `case_when`).
#'
#' @inheritParams dplyr::case_when
#' @param vars A character vector that determines the names of the arguments
#'     of the returned function.
#' @return A function, usable in place of [dplyr::case_when()].
#' @export
#' @examples
#' x <- 1:50
#' y <- 51:100
#'
#' cw_fb <- create_case_when(
#'   number %% 35 == 0 ~ "fizz buzz",
#'   number %% 5 == 0 ~ "fizz",
#'   number %% 7 == 0 ~ "buzz",
#'   TRUE ~ as.character(number),
#'   vars = "number"
#' )
#'
#' cw_fb(number = x)
#' cw_fb(number = y)
#'
#' # Formulas and variable names can be extracted
#' patterns <- formulas(cw_fb)
#' var_name <- variable.names(cw_fb)
#'
#' # Dots support splicing
#' create_case_when(!!! patterns, vars = var_name)
create_case_when <- function(..., vars = "x") {
  # Collect the formulas up front; rlang::dots_list() also resolves any
  # `!!!` splicing done by the caller.
  dots <- rlang::dots_list(...)
  cw_fun <- .create_case_when(!!! dots, vars = vars)
  # Tag the generated closure so the case_when S3 methods apply to it.
  class(cw_fun) <- c("dplyr_case_when", "case_when", "function")
  cw_fun
}

#' Create a reusable SQL case_when function
#'
#' A helper for developers who want to add a custom `case_when` function to
#' a [SQL translator][dbplyr::sql_variant()]. In that situation, supply the
#' translation function through the `fn` argument rather than the `con`
#' argument.
#'
#' @inheritParams create_case_when
#' @inheritParams dplyr::sql_translate_env
#' @param fn A function to be used by the returned function.
#' @keywords internal
#' @export
#' @examples
#' con <- structure(
#'   list(),
#'   class = c("TestCon", "Oracle", "DBITestConnection", "DBIConnection")
#' )
#'
#' fn <- dplyr::sql_translate_env(con)$scalar$case_when
#'
#' cw_fb <- create_sql_case_when(
#'   number %% 35 == 0 ~ "fizz buzz",
#'   number %% 5 == 0 ~ "fizz",
#'   number %% 7 == 0 ~ "buzz",
#'   TRUE ~ as.character(number),
#'   vars = "number",
#'   fn = fn
#' )
#'
#' testcon_var <- dbplyr::sql_variant(
#'   dbplyr::sql_translator(
#'     cw_fb = cw_fb,
#'     .parent = dplyr::sql_translate_env(con)$scalar
#'   ),
#'   dplyr::sql_translate_env(con)$aggregate,
#'   dplyr::sql_translate_env(con)$window
#' )
#'
#' sql_translate_env.TestCon <- function(x) testcon_var
#'
#' dbplyr::translate_sql(cw_fb(x), con = con)
create_sql_case_when <- function(
    ...,
    vars = "x",
    con = NULL,
    fn = dplyr::sql_translate_env(con = con)$scalar$case_when) {
  # Formulas first, then the factory options. do.call() passes the evaluated
  # `fn` value (not the `fn` symbol) on to the internal factory.
  factory_args <- c(rlang::dots_list(...), list(vars = vars, fn = fn))
  sql_fun <- do.call(.create_case_when, factory_args)
  class(sql_fun) <- c("sql_case_when", "case_when", "function")
  sql_fun
}

# Does `x` inherit from class "case_when"?
# Lists (checked first) are tested element by element, recursively, and give
# one logical per element (names are kept); anything else gives a single
# TRUE/FALSE.
is_case_when <- function(x) {
  if (!is.list(x)) {
    return(inherits(x, "case_when"))
  }
  vapply(x, is_case_when, logical(1))
}

#' Get formulas
#'
#' `formulas()` is an S3 generic that retrieves the formulas an object
#' depends on. Each method decides where those formulas are stored.
#'
#' @export
#' @param x An object used to select a method.
#' @param ... Other arguments passed on to methods.
#' @keywords internal
#' @return A list of `formula` objects.
formulas <- function(x, ...) {
  UseMethod("formulas")
}

#' @describeIn create_case_when Get the formulas of a `case_when` function.
#' @param x,object A `case_when` function, created by `create_case_when`.
#' @export
formulas.case_when <- function(x, ...) {
  # The formulas are stored in the enclosing environment of the generated
  # function. `inherits = FALSE` stops get() from searching parent
  # environments, where it could find an unrelated `formulas` object (for
  # instance the generic itself) and return it silently.
  cw_env <- environment(x)
  if (!is.environment(cw_env) ||
      !exists("formulas", envir = cw_env, inherits = FALSE)) {
    stop("`x` was not created by `create_case_when()`: no formulas found.",
         call. = FALSE)
  }
  get("formulas", envir = cw_env, inherits = FALSE)
}

#' @describeIn create_case_when Get the variable names of a `case_when` function.
#' @export
variable.names.case_when <- function(object, ...) {
  # The variable names are stored as `vars` in the enclosing environment of
  # the generated function. `inherits = FALSE` prevents get() from returning
  # some other `vars` found in a parent environment or on the search path.
  cw_env <- environment(object)
  if (!is.environment(cw_env) ||
      !exists("vars", envir = cw_env, inherits = FALSE)) {
    stop("`object` was not created by `create_case_when()`: ",
         "no variable names found.", call. = FALSE)
  }
  get("vars", envir = cw_env, inherits = FALSE)
}

#' @describeIn create_case_when Print information about a `case_when` function.
#' @export
print.case_when <- function(x, ...) {
  cw_formulas <- formulas(x)
  var_names <- variable.names(x)
  n_forms <- length(cw_formulas)
  n_vars <- length(var_names)
  # print.formula() output is captured line by line (without environments),
  # and each captured line is then prefixed with an arrow.
  formula_lines <- utils::capture.output(
    purrr::walk(cw_formulas, print, showEnv = FALSE)
  )
  # paste() recycles a zero-length argument to "", so without this guard a
  # function with no conditions would print a dangling "-> " line.
  formula_lines <- if (length(formula_lines) > 0) {
    crayon::green(paste("->", formula_lines))
  } else {
    character(0)
  }
  out <- c(
    crayon::cyan("<CASE WHEN>"),
    crayon::magenta(n_vars, paste0("variable", plural(var_names), ":"),
                    paste(var_names, collapse = ", ")),
    crayon::magenta(n_forms, paste0("condition", plural(cw_formulas), ":")),
    formula_lines,
    ""
  )
  cat(paste0(out, collapse = "\n"))
  invisible(x)
}

# Internal factory shared by create_case_when() and create_sql_case_when().
#
# Builds a new function whose formal arguments are named after `vars` (all
# without default values). When that generated function is called, each
# formula passed in `...` is rewritten: every symbol whose name matches an
# argument supplied in the call is replaced by the caller's unevaluated
# argument expression. The rewritten formulas are then passed to `fn` via
# do.call().
#
# @param ... Formulas (normally two-sided, `condition ~ value`), collected
#   with rlang::dots_list() so `!!!` splicing works. Only
#   rlang::is_formula() is checked; any non-formula raises an error.
# @param vars Character vector of argument names for the generated function.
# @param fn Function receiving the rewritten formulas; defaults to
#   dplyr::case_when.
# @return A function built with rlang::new_function(). Its enclosing
#   environment is this call's frame, which is where formulas.case_when()
#   and variable.names.case_when() look up `formulas` and `vars`.
.create_case_when <- function(..., vars, fn = dplyr::case_when) {
  assertthat::assert_that(is.character(vars))
  # One formal per variable name, each with no default value.
  fun_fmls <- purrr::map(rlang::set_names(vars), ~ rlang::missing_arg())
  # Template for the generated function's body.
  #
  # NOTE(order matters): substitute() replaces symbols bound in this frame.
  # `fn` is a formal (a promise), so it is replaced by the expression
  # supplied for it (or its default). `formulas` is deliberately assigned
  # only AFTER this call, so it stays a plain symbol that the generated
  # function resolves at call time in its enclosing environment (this
  # frame). Moving the `formulas <-` line above this point would inline the
  # list's value into the body instead.
  #
  # At call time the generated body:
  #   1. captures the supplied arguments unevaluated (match.call());
  #   2. rewrites each formula's lhs and rhs with pryr::modify_lang(),
  #      replacing symbols named like a supplied argument with that
  #      argument's expression (other symbols are left unchanged);
  #   3. rebuilds each formula with parent.frame() as its environment, i.e.
  #      the environment the generated function was called from;
  #   4. calls `fn` with the rewritten formulas.
  fun_body <- substitute({
    match_call <- match.call()
    args_call <- as.list(match_call[-1])
    modify_vars <- function(x) {
      if (is.name(x)) {
        if (as.character(x) %in% names(args_call))
          return(args_call[[as.character(x)]])
      }
      x
    }
    n <- length(formulas)
    new_formulas <- vector("list", n)
    for (i in seq_len(n)) {
      lhs <- rlang::f_lhs(formulas[[i]])
      rhs <- rlang::f_rhs(formulas[[i]])
      new_lhs <- pryr::modify_lang(lhs, modify_vars)
      new_rhs <- pryr::modify_lang(rhs, modify_vars)
      new_formulas[[i]] <- rlang::new_formula(new_lhs, new_rhs, env = parent.frame())
    }
    do.call(fn, new_formulas)
  })
  formulas <- rlang::dots_list(...)
  # Validate eagerly so that a bad argument fails when the case_when
  # function is created, not when it is later called.
  purrr::walk(formulas,
              ~ assertthat::assert_that(rlang::is_formula(.x),
                                        msg = "An argument is not a formula."
              )
  )
  # new_function()'s default `env` is the caller's environment, i.e. this
  # frame, so the closure keeps access to `formulas`, `vars` and `fn`.
  rlang::new_function(fun_fmls, fun_body)
}

# Convert a `case_when` function (or a list of them, recursively) into the
# SQL variant built by create_sql_case_when(), reusing its formulas and
# variable names with the translation environment of `con`.
.translate_to_sql <- function(cw_fn, con) {
  if (!is.list(cw_fn)) {
    cw_formulas <- formulas(cw_fn)
    cw_vars <- variable.names(cw_fn)
    return(create_sql_case_when(!!! cw_formulas, vars = cw_vars, con = con))
  }
  lapply(cw_fn, .translate_to_sql, con = con)
}
RLesur/casewhen documentation built on May 5, 2019, 12:29 a.m.