# R/utils.R

# Defines functions .make.covariate.table .generate.ps

# Documented in .generate.ps .make.covariate.table

#' Generate Sample Participation Probabilities
#'
#' This function is designed for use within 'weighting()' and 'assess()'.
#'
#' Fits a model of sample membership on the covariates and returns the
#' predicted probability of being in the sample for each unit. Any sampled
#' unit whose estimated probability is exactly 0 has it replaced by the
#' smallest non-zero probability found among sampled units.
#'
#' @param data data frame comprised of "stacked" sample and target population data
#' @param sample_indicator variable name denoting sample membership (1 = in sample, 0 = out of sample)
#' @param covariates vector of covariate names in data set that predict sample membership
#' @param estimation_method method to estimate the probability of sample membership. Default is logistic regression ("lr"). Other methods supported are Random Forests ("rf") and Lasso ("lasso"). Any other value raises an error.
#' @return sample participation probabilities for each unit in the data frame
#' @importFrom glmnet cv.glmnet
#' @importFrom randomForest randomForest
#' @importFrom stats as.formula glm lm predict quantile
#'

.generate.ps <- function(data,
                         sample_indicator,
                         covariates,
                         estimation_method = "lr") {

  supported_methods <- c("lr", "rf", "lasso")

  # Fail early with a readable message instead of the "object 'ps' not found"
  # error that an unrecognised method would otherwise trigger further down.
  if (!is.character(estimation_method) ||
      length(estimation_method) != 1L ||
      !estimation_method %in% supported_methods) {
    stop(
      "`estimation_method` must be one of ",
      paste0("\"", supported_methods, "\"", collapse = ", "),
      ".",
      call. = FALSE
    )
  }

  # Right-hand side shared by the formula-based estimators.
  rhs <- paste(covariates, collapse = "+")

  ps <- switch(
    estimation_method,

    # Logistic Regression
    lr = {
      model_formula <- stats::as.formula(
        paste(sample_indicator, rhs, sep = "~")
      )
      fit <- stats::glm(model_formula, data = data, family = "quasibinomial")
      stats::predict(fit, type = "response")
    },

    # Random Forest
    rf = {
      # The outcome is wrapped in as.factor() so a classification forest is fit.
      model_formula <- stats::as.formula(
        paste(paste("as.factor(", sample_indicator, ")"), rhs, sep = "~")
      )
      forest <- randomForest::randomForest(
        model_formula,
        data = data,
        na.action = stats::na.omit
      )
      # predict() without new data is called with type = "prob"; the column
      # named "1" holds the probability of sample membership.
      class_probs <- as.data.frame(stats::predict(forest, type = "prob"))
      dplyr::pull(class_probs, "1")
    },

    # Lasso
    lasso = {
      design <- stats::model.matrix(
        ~ -1 + .,
        data = dplyr::select(data, tidyselect::all_of(covariates))
      )
      outcome <- data[[sample_indicator]]

      cv_fit <- glmnet::cv.glmnet(x = design, y = outcome, family = "binomial")
      as.numeric(
        stats::predict(cv_fit, newx = design, s = "lambda.1se", type = "response")
      )
    }
  )

  ### Set any participation probabilities of 0 in the sample to the minimum non-zero value ###

  in_sample <- data[[sample_indicator]] == 1
  zero_in_sample <- which(in_sample & ps == 0)

  if (length(zero_in_sample) > 0L) {
    ps[zero_in_sample] <- min(ps[which(in_sample & ps != 0)], na.rm = TRUE)
  }

  ps
}

#' Create Covariate Balance Table
#'
#' This function is designed for use within \code{weighting()} and \code{assess()}.
#'
#' Rows with missing values in any of \code{covariates} are dropped first. For
#' each covariate the function then computes the unweighted and weighted sample
#' means, the population mean and standard deviation, and the absolute
#' standardized mean difference (ASMD = |sample mean - population mean| /
#' population SD). Numeric results are rounded to three digits.
#'
#' @param data Dataframe comprised of "stacked" sample and target population data
#' @param sample_indicator Binary variable denoting sample membership (1 = in sample, 0 = out of sample)
#' @param covariates Vector of covariates in dataframe that predict sample membership
#' @param sample_weights Name of column in dataframe holding weights for calculating weighted sample means of covariates in dataframe. If NULL, sample means are unweighted. Population rows are always given a weight of 1.
#' @param estimation_method Method to estimate the probability of sample membership. Default is logistic regression ("lr"). Other methods supported are Random Forests ("rf") and Lasso ("lasso"). Not used in the body of this function.
#' @param disjoint_data Logical. Defaults to TRUE. If TRUE, then sample and population data are considered disjoint. If FALSE, every row (sample rows included) is additionally counted as a population row with weight 1.
#' @return Invisibly, a list with elements \code{covariate_table} (data frame of
#'   balance statistics), \code{covariate_kable} (the same table formatted with
#'   \code{kableExtra::kbl()}), \code{cov_dist_facet_plot} (faceted density plot
#'   of at most the first 40 covariates) and \code{cov_dist_plots} (list of
#'   per-covariate density plots, named by covariate).
#' @importFrom stats model.matrix weighted.mean
#' @importFrom dplyr funs

.make.covariate.table <- function(data,
                                  sample_indicator,
                                  covariates,
                                  sample_weights = NULL,
                                  estimation_method = "lr",
                                  disjoint_data = TRUE) {

  data <- data %>%
    tidyr::drop_na(tidyselect::all_of(covariates)) %>%
    as.data.frame()

  # Normalise the weights into a column literally called `sample_weights`;
  # population rows always get weight 1.
  if (is.null(sample_weights)) {
    data$sample_weights <- 1
  } else {
    is_population <- data[[sample_indicator]] == 0
    data <- dplyr::rename(data, sample_weights = !!rlang::sym(sample_weights))
    data$sample_weights <- ifelse(is_population, 1, data$sample_weights)
  }

  # Expand the covariates (plus weights) through model.matrix() without an
  # intercept, then put the sample indicator back as the first column.
  expanded.data <- data.frame(
    data[, sample_indicator],
    stats::model.matrix(
      ~ -1 + .,
      data = data[, c(covariates, "sample_weights"), drop = FALSE]
    )
  )

  names(expanded.data)[1] <- sample_indicator

  # Non-disjoint data: keep the sample rows and append a copy of every row
  # relabelled as population with weight 1.
  if (!disjoint_data) {
    expanded.data <- expanded.data %>%
      dplyr::filter(!!rlang::sym(sample_indicator) == 1) %>%
      rbind(expanded.data %>%
              dplyr::mutate(!!rlang::sym(sample_indicator) := 0,
                            sample_weights = 1))
  }

  # Recover the covariate name from a summary column such as "x_mean_weighted".
  get_covariate <- function(name) {
    dplyr::case_when(
      stringr::str_detect(name, "_mean_weighted$") ~ stringr::str_remove(name, "_mean_weighted$"),
      stringr::str_detect(name, "_mean$") ~ stringr::str_remove(name, "_mean$"),
      stringr::str_detect(name, "_var_weighted$") ~ stringr::str_remove(name, "_var_weighted$"),
      stringr::str_detect(name, "_var$") ~ stringr::str_remove(name, "_var$")
    )
  }

  # Recover the statistic label from a summary column such as "x_mean_weighted".
  get_statistic <- function(name) {
    dplyr::case_when(
      stringr::str_detect(name, "_mean_weighted$") ~ "mean_weighted",
      stringr::str_detect(name, "_mean$") ~ "mean",
      stringr::str_detect(name, "_var_weighted$") ~ "var_weighted",
      stringr::str_detect(name, "_var$") ~ "var"
    )
  }

  # One row per group (population / sample) with mean, weighted mean and
  # variance of every covariate.
  tab <- expanded.data %>%
    dplyr::group_by(!!rlang::sym(sample_indicator)) %>%
    dplyr::summarise(dplyr::across(
      tidyselect::all_of(covariates),
      list(mean = mean,
           mean_weighted = ~ stats::weighted.mean(.x, sample_weights),
           var = stats::var)
    ))

  # Reshape the summary row of one group into one row per covariate, keeping
  # only the statistics in `stat_cols` and renaming them to `new_names`.
  summarise_group <- function(group_value, stat_cols, new_names) {
    tab %>%
      dplyr::filter(!!rlang::sym(sample_indicator) == !!group_value) %>%
      tidyr::pivot_longer(cols = -!!rlang::sym(sample_indicator)) %>%
      dplyr::mutate(covariate = get_covariate(name),
                    statistic = get_statistic(name)) %>%
      tidyr::pivot_wider(names_from = statistic,
                         values_from = value) %>%
      dplyr::group_by(covariate) %>%
      # After widening, each statistic is non-missing on a single row per
      # covariate; na.omit collapses the rows to that value.
      dplyr::summarise(dplyr::across(tidyselect::all_of(stat_cols), stats::na.omit)) %>%
      `colnames<-`(c("covariate", new_names))
  }

  tab_pop <- summarise_group(
    group_value = 0,
    stat_cols = c("mean", "var"),
    new_names = c("pop_mean", "pop_var")
  )

  tab_sample <- summarise_group(
    group_value = 1,
    stat_cols = c("mean", "mean_weighted"),
    new_names = c("sample_mean_unweighted", "sample_mean_weighted")
  )

  tab_merged <- merge(tab_pop, tab_sample, by = "covariate") %>%
    dplyr::mutate(pop_sd = sqrt(pop_var),
                  ASMD_unweighted = abs((sample_mean_unweighted - pop_mean) / pop_sd),
                  ASMD_weighted = abs((sample_mean_weighted - pop_mean) / pop_sd)) %>%
    # Lambda form: passing extra arguments through across(...) is deprecated.
    dplyr::mutate(dplyr::across(tidyselect::where(is.numeric),
                                ~ round(.x, digits = 3)))

  if (!is.null(sample_weights)) {

    covariate_table <- tab_merged %>%
      dplyr::select(covariate, sample_mean_unweighted, sample_mean_weighted,
                    pop_mean, pop_sd, ASMD_unweighted, ASMD_weighted)

    # Display "weighted [unweighted]" and colour the ASMD cell green when
    # weighting lowered the ASMD, red otherwise.
    covariate_kable <- covariate_table %>%
      dplyr::mutate(sample_mean = paste0(sample_mean_weighted,
                                         " [",
                                         sample_mean_unweighted,
                                         "]"),
                    ASMD = paste0(ASMD_weighted,
                                  " [",
                                  ASMD_unweighted,
                                  "]")) %>%
      dplyr::mutate(ASMD = ASMD %>%
                      kableExtra::cell_spec(color = ifelse(ASMD_weighted < ASMD_unweighted,
                                                           "green",
                                                           "red"))) %>%
      dplyr::select(covariate, sample_mean, pop_mean, pop_sd, ASMD) %>%
      dplyr::rename(Covariate = covariate,
                    `Sample Mean` = sample_mean,
                    `Population Mean` = pop_mean,
                    `Population SD` = pop_sd)

    # Attach footnote markers to the "Sample Mean" and "ASMD" headers.
    names(covariate_kable)[2] <- paste0(names(covariate_kable)[2],
                                        kableExtra::footnote_marker_symbol(1))

    names(covariate_kable)[5] <- paste0(names(covariate_kable)[5],
                                        kableExtra::footnote_marker_symbol(2))

    covariate_kable <- covariate_kable %>%
      kableExtra::kbl(caption = "Covariate Table",
                      align = "l",
                      escape = FALSE) %>%
      kableExtra::kable_styling(c("striped", "hover"), fixed_thead = TRUE) %>%
      kableExtra::column_spec(1, bold = TRUE, border_right = TRUE, color = "black", background = "lightgrey") %>%
      kableExtra::footnote(symbol = c("Weighted Sample Mean [Unweighted Sample Mean]",
                                      "Calculated Using Weighted Sample Mean [Calculated Using Unweighted Sample Mean]"))

  } else {

    covariate_table <- tab_merged %>%
      dplyr::select(covariate, sample_mean_unweighted, pop_mean, pop_sd, ASMD_unweighted) %>%
      dplyr::rename(sample_mean = sample_mean_unweighted,
                    ASMD = ASMD_unweighted)

    covariate_kable <- covariate_table %>%
      dplyr::rename(Covariate = covariate,
                    `Sample Mean` = sample_mean,
                    `Population Mean` = pop_mean,
                    `Population SD` = pop_sd) %>%
      kableExtra::kbl(caption = "Covariate Table",
                      align = "l") %>%
      kableExtra::kable_styling(c("striped", "hover"), fixed_thead = TRUE) %>%
      kableExtra::column_spec(1, bold = TRUE, border_right = TRUE, color = "black", background = "lightgrey")
  }

  # Force five facet rows once there are more than 25 covariates; otherwise
  # let facet_wrap() choose the layout.
  facet_rows <- if (length(covariates) > 25) 5 else NULL

  # Only the first 40 covariates are shown in the faceted plot. head() avoids
  # the NA padding that `covariates[1:40]` produces for shorter vectors.
  facet_covariates <- utils::head(covariates, 40)

  cov_dist_facet_plot <- expanded.data %>%
    tidyr::pivot_longer(cols = tidyselect::all_of(facet_covariates),
                        names_to = "covariate") %>%
    ggplot() +
    facet_wrap(~covariate,
               scales = "free",
               nrow = facet_rows) +
    geom_density(aes(x = value, fill = factor(!!rlang::sym(sample_indicator))),
                 alpha = 0.7) +
    scale_x_continuous(expand = c(0, 0),
                       n.breaks = 3) +
    scale_y_continuous(expand = c(0, 0)) +
    scale_fill_discrete(name = NULL,
                        labels = c("Population", "Sample")) +
    ggtitle("Covariate Density Plots") +
    theme_minimal() +
    theme(axis.ticks.x = element_line(),
          axis.text.y = element_blank(),
          axis.text.x = element_text(angle = 45,
                                     hjust = 1),
          axis.line = element_line(),
          axis.title = element_blank(),
          plot.title = element_text(size = 12))

  # Density plot of a single covariate, sample vs. population.
  make_density_plot <- function(covariate) {
    expanded.data %>%
      ggplot() +
      geom_density(aes(x = !!rlang::sym(covariate), fill = factor(!!rlang::sym(sample_indicator))),
                   alpha = 0.7) +
      scale_x_continuous(expand = c(0, 0)) +
      scale_y_continuous(expand = c(0, 0)) +
      scale_fill_discrete(name = NULL,
                          labels = c("Population", "Sample")) +
      ggtitle(paste(covariate, "Density Plot")) +
      theme_minimal() +
      theme(axis.ticks.x = element_line(),
            axis.text.y = element_blank(),
            axis.line = element_line(),
            axis.title = element_blank(),
            plot.title = element_text(size = 12))
  }

  cov_dist_plots <- stats::setNames(lapply(covariates, make_density_plot),
                                    covariates)

  out <- list(covariate_table = covariate_table,
              covariate_kable = covariate_kable,
              cov_dist_facet_plot = cov_dist_facet_plot,
              cov_dist_plots = cov_dist_plots)

  invisible(out)
}

# Register the column names and ggplot2 helpers that .make.covariate.table
# refers to without a namespace prefix, since they are otherwise not
# recognized as defined variables. utils::globalVariables() is only called on
# R >= 2.15.1.
if (getRversion() >= "2.15.1") {
  utils::globalVariables(
    c(
      ":=",
      "name",
      "statistic",
      "value",
      "across",
      "pop_var",
      "sample_mean_unweighted",
      "pop_mean",
      "pop_sd",
      "sample_mean_weighted",
      "ASMD_unweighted",
      "ASMD_weighted",
      "ASMD",
      "sample_mean",
      "facet_wrap",
      "geom_density",
      "scale_x_continuous",
      "scale_fill_discrete",
      "ggtitle",
      "theme_minimal",
      "element_line"
    )
  )
}
# NUstat/generalizeR documentation built on Feb. 7, 2024, 6:39 p.m.