R/deprec-all.R

Defines functions prop_ci count_string

Documented in prop_ci

# Temporary alias kept after `count_string()` was renamed to `count_pattern()`.
# It emits a deprecation warning and then forwards all arguments unchanged.
#' @noRd
count_string <- function(...){
  # `call. = FALSE` keeps the call itself out of the warning text.
  warning("`count_string()` is deprecated. Use `count_pattern()` instead (renamed).", call. = FALSE)
  count_pattern(...)
}

#' Vectorised inverse logit function
#'
#' Calculate the inverse logit function \eqn{exp(x) / (1 + exp(x))}. Modified from Faraway package.
#' For very large `x`, `exp(x)` overflows to `Inf` and the ratio becomes `NaN`. Since
#' \eqn{exp(x) / (1 + exp(x))} differs from 1 by less than about 2e-9 for \eqn{x >= 20}, every
#' element with `x > 20` is set to exactly 1. Missing values in `x` are returned as `NA`.
#'
#' @param x A numeric vector.
#' @return A numeric vector the same length as `x` with values in \eqn{[0, 1]}.
#' @keywords internal
#' @export
ilogit <- function (x){
  warning("`ilogit()` is deprecated. Use `jemodel::ilogit()` instead.", call. = FALSE)
  # Single vectorised pass: no recursion, so the deprecation warning is raised only once per call.
  out <- exp(x) / (1 + exp(x))
  # Overwrite large inputs with 1; this also replaces the NaN produced when exp(x) is Inf.
  # The `!is.na()` guard keeps NA out of the logical index so missing inputs stay NA.
  out[!is.na(x) & x > 20] <- 1
  out
}

#########################################################################################
# prop_ci: Confidence intervals for binary target variable by values of a discrete predictor.
#########################################################################################
#'
#' Confidence intervals for binary target variable by values of a discrete predictor
#'
#' For a data frame input, one variable is a binary target (`target_name`) and another is selected to
#' be a predictor variable (`var_name`). Mean response and a confidence interval are calculated for the
#' target variable for each level or value of the predictor. The results are plotted and returned as a
#' table. The function is most appropriate for factor predictors but will work with other variable types also.
#'
#' The target variable must be binary. To compute confidence intervals it is converted to 0 and 1 values.
#' If it is not obvious which value corresponds to 1 and which to 0 then the first observation is
#' treated as class 1. Giving the value corresponding to 1 in the argument `pos_class` will override this.
#' Rows with a missing target are dropped (with a message) before any calculation.
#'
#' Use the `plot` and `return_plot` arguments to control output. By default (designed to be
#'  used interactively) the function returns a table and prints a plot. If `return_plot = TRUE` then just the
#'  plot is returned. If `return_plot = FALSE` and `plot = FALSE` then the table is returned and no
#'  plot is generated.
#'
#' @param dt A data frame.
#' @param target_name Column to use as target variable. Column name (quoted or unquoted) or position.
#' @param var_name Column to use as predictor variable. Column name (quoted or unquoted) or position.
#' @param min_n Integer >= 1. Predictor levels with less than `min_n` observations are not displayed.
#' @param show_all Logical. Defaults to `TRUE`. If `FALSE` will not show levels whose confidence interval
#'   overlaps the mean response of all observations.
#' @param order_n Logical. Whether to force plot and table to order by number of observations of the
#'   predictor. The default setting `NULL` retains ordering if predictor is numeric or ordered factor
#'   and orders by number of observations otherwise.
#' @param conf_level Numeric in (0,1). Confidence level used for confidence intervals.
#' @param prop_lim Optional x axis limits passed to `ggplot2::ylim()` e.g. `c(0,1)`.
#' @param pos_class Optional. Specify value in target to associate with class 1.
#' @param plot Optional logical. Output a plot or not.
#' @param return_plot Optional logical. If `TRUE` the plot is returned instead of the table (this overrides
#'   `plot` argument).
#'
#' @return If `return_plot = TRUE`, a ggplot object. Otherwise a data frame with one row per displayed
#'   predictor value and columns `value`, `n`, `n_pos`, `prop`, `lo`, `hi` and `sig`.
#'
#' @importFrom ggplot2 aes
#' @keywords internal
#' @export
prop_ci <- function(dt, target_name, var_name, min_n = 1, show_all = TRUE, order_n = NULL,
                    conf_level = 0.95, prop_lim = NULL, pos_class = NULL, plot = TRUE,
                    return_plot = FALSE) {
  lifecycle::deprecate_warn("0.3.1", "prop_ci()", "response::response()")
  if (!is.data.frame(dt)) stop("`dt` must be a data frame.", call. = FALSE)
  # Capture the predictor's original column name before it is renamed; used as an axis label.
  y_label <- names(dplyr::select(dt, {{var_name}}))
  # Work with fixed internal column names so the rest of the function need not use tidy eval.
  dt <- dplyr::rename(dt,
                      target = {{target_name}},
                      var = {{var_name}})
  if (any(is.na(dt$target))){
    dt <- dplyr::filter(dt, !is.na(.data$target))
    message("There are NA values in target variable. These rows will be excluded from any calculations.")
  }
  # A string selector is used here because `.data$` inside tidyselect contexts is deprecated.
  targ_vec <- dplyr::pull(dt, "target")

  # Target vector checks and conversion to 0/1.
  if (length(unique(targ_vec)) > 2L){
    stop("Target variable must be binary.", call. = FALSE)
  }
  # A factor holding only "TRUE"/"FALSE" is treated as logical so TRUE is the positive class.
  if (is.factor(targ_vec) && all(targ_vec %in% c("TRUE", "FALSE"))) targ_vec <- as.logical(targ_vec)
  if (!all(targ_vec %in% c(0, 1))){
    # No natural 0/1 coding: fall back to the first observation unless `pos_class` was given.
    if (is.null(pos_class)) pos_class <- as.vector(targ_vec)[1]
    neg_class <- setdiff(as.vector(targ_vec), pos_class)
    targ_vec <- ifelse(targ_vec == pos_class, 1L, 0L)
    message("Target not binary or logical. Treating ",
            pos_class,
            " as 1 and ",
            neg_class,
            " as 0.\nUse 'pos_class' argument to set different value for class 1.")
  }
  # A factor with levels "0"/"1" reaches here unchanged; map it to its integer labels.
  if (is.factor(targ_vec)){
    targ_vec <- as.integer(levels(targ_vec))[targ_vec]
  }
  dt <- dplyr::mutate(dt, target = targ_vec)

  # Give missing factor values their own explicit level so they appear as a group.
  # NOTE(review): `forcats::fct_explicit_na()` is marked deprecated in recent forcats releases;
  # kept so the minimum forcats version this package needs does not change.
  if (is.factor(dt$var) && !("(Missing)" %in% dt$var)){
    dt <- dplyr::mutate(dt, var = forcats::fct_explicit_na(.data$var))
  }
  if (is.null(order_n)){
    # Keep the natural ordering for numeric/ordered predictors, otherwise order by frequency.
    order_n <- !(is.numeric(dt$var) || is.ordered(dt$var))
  }

  mean_all <- mean(targ_vec)

  # One row per predictor value: count, number of positives and observed proportion.
  dt_summ <- dt %>%
    dplyr::group_by(.data$var) %>%
    dplyr::summarise(n = dplyr::n(),
                     n_pos = sum(.data$target),
                     prop = mean(.data$target)) %>%
    dplyr::rename(value = "var")

  # Wilson intervals are computed once from the exact positive count. Reconstructing the count as
  # `prop * n` can differ from the integer count through floating-point rounding.
  ci <- binom::binom.wilson(dt_summ$n_pos, dt_summ$n, conf.level = conf_level)

  dt_summ <- dt_summ %>%
    dplyr::mutate(lo = ci[["lower"]],
                  hi = ci[["upper"]],
                  # 3 = interval entirely above the overall mean, 1 = entirely below, 2 = overlaps.
                  sig = dplyr::case_when(
                    .data$lo > mean_all ~ 3,
                    .data$hi < mean_all ~ 1,
                    TRUE ~ 2),
                  sig = factor(.data$sig, levels = c(1, 2, 3), labels = c("lo", "none", "hi"))
    ) %>%
    dplyr::filter(.data$n >= min_n)

  # Plain conditionals replace the deprecated `purrr::when()`; `filter` is namespace-qualified so it
  # cannot resolve to `stats::filter()`.
  if (!show_all){
    dt_summ <- dplyr::filter(dt_summ, .data$sig != "none")
  }
  if (order_n){
    dt_summ <- dplyr::arrange(dt_summ, dplyr::desc(.data$n))
    # Reorder the plotted values by count so the axis follows frequency.
    dt_plot <- dplyr::mutate(dt_summ, value = stats::reorder(.data$value, .data$n))
  }else{
    dt_summ <- dplyr::arrange(dt_summ, .data$value)
    dt_plot <- dt_summ
  }

  if (plot || return_plot){
    cols <- c("#F8766D", "#00BA38", "#619CFF")
    gg <- ggplot2::ggplot(dt_plot,
                          ggplot2::aes(x = .data$value, y = .data$prop, color = .data$sig)) +
      ggplot2::geom_point() +
      ggplot2::geom_errorbar(ggplot2::aes(ymin = .data$lo, ymax = .data$hi)) +
      ggplot2::coord_flip() +
      # Dashed reference line at the overall mean response.
      ggplot2::geom_hline(yintercept = mean_all, linetype = 2) +
      ggplot2::ylab("Mean Proportion Target") +
      ggplot2::xlab(y_label) +
      ggplot2::theme(legend.position = "none") +
      ggplot2::scale_colour_manual(values = c("lo" = cols[1], "none" = cols[3], "hi" = cols[2]))
    if (!is.null(prop_lim)){
      gg <- gg + ggplot2::ylim(prop_lim[1], prop_lim[2])
    }
    if (!return_plot) print(gg)
  }
  if (return_plot) gg else dt_summ
}

#' Tabulate the coefficients of a glmnet fit
#'
#' Extracts the coefficients of `fit` via `coef()` and returns them as a tibble with columns
#' `name` and `coef`, sorted from largest to smallest coefficient. Rows whose absolute
#' coefficient is below `min_coef` are dropped.
#'
#' This is a cut-down version of `jemodel::glmnet_to_table()` that remains here for backwards
#' compatibility. It may be removed from this package in the longer term.
#'
#' @param fit A fitted glmnet model.
#' @param ... Arguments passed to `coef()`. For `glmnet`, most commonly used will be `s` (see `predict.cv.glmnet()`).
#' @param min_coef Coefficients with smaller absolute value than this are excluded from the table.
#' @keywords internal
#' @export
glmnet_to_table <- function(fit, ..., min_coef=1E-10) {
  lifecycle::deprecate_warn("0.3.1", "glmnet_to_table()", "jemodel::coef_to_table()")
  # Densify the coefficient object so row names and the first column can be read directly.
  coef_mat <- as.matrix(coef(fit, ...))
  all_coefs <- tibble::tibble(name = rownames(coef_mat), coef = coef_mat[, 1])
  # Keep only coefficients whose magnitude reaches the threshold, largest first.
  kept_coefs <- dplyr::filter(all_coefs, abs(coef) >= min_coef)
  dplyr::arrange(kept_coefs, dplyr::desc(coef))
}

#' @description
#' `count_matches2()` is a deprecated function that has been replaced by the `detail` argument.
#' @rdname count_matches
#' @keywords internal
#' @export
count_matches2 <- function(df, values = string_missing(), all = FALSE, prop = FALSE) {
  lifecycle::deprecate_warn("0.3.1", "count_matches2()",
                            details = "Replaced by `count_matches(detail=TRUE)`")
  if (!is.list(df)) {
    stop("Argument \"df\" must be a list.", call. = FALSE)
  }
  if (!is.character(values)||!is.vector(values)){
    stop("Argument \"values\" must be a character vector.", call. = FALSE)
  }
  # One result per element of `values`, stacked into a table with a leading `string` column.
  # `all = TRUE` is always passed to the helper; the `all` argument of this function is applied
  # below, after the results have been combined.
  tb <- lapply(values,
               FUN = count_matches_simple,
               df = df,
               all = TRUE,
               prop = prop) %>%
    dplyr::bind_rows() %>%
    dplyr::mutate(string = values) %>%
    # Select by name: `.data$` inside tidyselect contexts is deprecated.
    dplyr::select("string", dplyr::everything())

  if (all){
    return(tb)
  }
  # Drop rows and columns that are entirely zero. The leading 1L stands in for the `string`
  # column so that it is always retained.
  sum_rows <- rowSums(tb[, -1])
  sum_cols <- c(1L, colSums(tb[, -1]))
  tb <- tb[sum_rows > 0, sum_cols > 0]
  if (sum(sum_rows) == 0){
    message("No matches in the data.")
    return(invisible(tb))
  }
  tb
}
jedwards24/edwards documentation built on Sept. 2, 2023, 8:16 a.m.