R/treatment-effects-posterior.R

Defines functions is_treated counter_factual fitted_with_counter_factual_draws treatment_effects.default treatment_effects

Documented in treatment_effects treatment_effects.default

#' Get (individual) treatment effect draws from posterior
#'
#' CTE = Conditional Treatment Effects (usually used to generate (C)ATE or ATT)
#' \code{newdata} specifies the conditions, if unspecified it defaults to the original data.
#' Assumes treated column is either a integer column of 1's (treated) and 0's (nontreated) or logical indicating treatment if TRUE.
#'
#' @param model A supported Bayesian model fit that can provide fits and predictions.
#' @param treatment A character string specifying the name of the treatment variable.
#' @param newdata Data frame to generate fitted values from. If omitted, defaults to the data used to fit the model.
#' @param subset Either "treated", "nontreated", or "all". Default is "all".
#' @param common_support_method Either "sd", or "chisq". Default is unspecified, and no common support calculation is done.
#' @param cutoff Cutoff for common support (if in use).
#' @param ... Arguments to be passed to \code{tidybayes::fitted_draws} typically scale for \code{BART} models.
#'
#' @return A tidy data frame (tibble) with treatment effect values.
#' @export
#'

treatment_effects <- function(model, treatment, newdata, subset = "all", common_support_method, cutoff, ...) {
  UseMethod("treatment_effects")
}

#' Get treatment effect draws from posterior
#'
#' CTE = Conditional Treatment Effects (or CATE, the average effects)
#' \code{newdata} specifies the conditions, if unspecified it defaults to the original data.
#' Assumes treated column is either a integer column of 1's (treated) and 0's (nontreated) or logical indicating treatment if TRUE.
#'
#' @inheritParams treatment_effects
#'
#' @return A tidy data frame (tibble) with treatment effect values.
#' @export
#'
treatment_effects.default <- function(model, treatment, newdata, subset = "all", common_support_method, cutoff, ...) {
  stopifnot(
    !missing(treatment),
    is.character(treatment),
    length(treatment) == 1
  )

  if (missing(newdata)) {
    check_method(model,
      method = "model.matrix",
      helper = "Please specify 'newdata' argument = data from model fitting."
    )

    modeldata <- stats::model.matrix(model)
  } else {
    modeldata <- newdata
  }

  posterior_fit_with_cf <- fitted_with_counter_factual_draws(model, modeldata, treatment, subset, ...)

  posterior_treatment <- dplyr::select(
    dplyr::mutate(posterior_fit_with_cf, cte = (2L * as.integer(!!rlang::sym(treatment)) - 1L) * (.data$observed - .data$cfactual)), # equivalent to treatment - non_treatment
    -.data$observed, -.data$cfactual
  )

  # add boolean for common support
  if (!missing(common_support_method)) {
    stopifnot(
      !missing(cutoff),
      # should use model data only, unless needs to be specified (e.g. for BART models).
      missing(newdata) | !has_tidytreatment_methods(model)
    )

    if (!missing(newdata)) message("Note: Argument 'newdata' must be original dataset when calculating common support.")

    common_supp <-
      calc_common_support_from_fitted_and_cf(
        fitted_and_cf = posterior_fit_with_cf,
        modeldata = modeldata,
        treatment = treatment,
        method = common_support_method,
        cutoff = cutoff
      )

    posterior_treatment <- dplyr::left_join(posterior_treatment, common_supp, by = ".row")
  }


  return(posterior_treatment)
}

fitted_with_counter_factual_draws <- function(model, newdata, treatment, subset, ...) {
  stopifnot(
    has_tidytreatment_methods(model) | !missing(newdata)
  )

  if (missing(newdata)) {
    newdata <- stats::model.matrix(model)
  }

  use_subset <- match.arg(subset, c("all", "treated", "nontreated"))

  stopifnot(
    treatment %in% colnames(newdata),
    is.data.frame(newdata)
  )

  stopifnot(
    is_01_integer_vector(newdata[, treatment]) | is.logical(newdata[, treatment])
  )

  obs_fitted <- tidybayes::fitted_draws(
    model = model, value = "observed",
    newdata = newdata,
    include_newdata = FALSE,
    ...
  )

  cfactual_fitted <- tidybayes::fitted_draws(
    model = model, value = "cfactual",
    newdata = dplyr::mutate(newdata, !!treatment := counter_factual(!!rlang::sym(treatment))),
    include_newdata = FALSE,
    ...
  )

  obs_fitted <- dplyr::left_join(
    obs_fitted,
    dplyr::mutate(dplyr::select(newdata, !!treatment), .row = 1:dplyr::n()),
    by = c(".row")
  )

  out <- dplyr::left_join(
    obs_fitted,
    cfactual_fitted,
    by = c(".row", ".chain", ".iteration", ".draw")
  )

  if (use_subset == "treated") {
    out <- dplyr::filter(out, is_treated(!!rlang::sym(treatment)))
  } else if (use_subset == "nontreated") {
    out <- dplyr::filter(out, !is_treated(!!rlang::sym(treatment)))
  }

  return(out)
}

counter_factual <- function(x) {
  if (is.integer(x)) {
    return(1L - x)
  } else if (is.logical(x)) {
    return(!x)
  } else {
    return(rep(NA, times = length(x)))
  }
}

is_treated <- function(x) {
  if (is.integer(x)) {
    return(x == 1L)
  } else if (is.logical(x)) {
    return(x)
  } else {
    return(rep(NA, times = length(x)))
  }
}

Try the tidytreatment package in your browser

Any scripts or data that you put into this service are public.

tidytreatment documentation built on March 18, 2022, 6:30 p.m.