R/get_synth_draws.R

Defines functions .get_draws3d .get_synth_draws3d .get_synth_draws_predictor_match .get_synth_draws

Documented in .get_synth_draws .get_synth_draws3d .get_synth_draws_predictor_match

## Copyright 2021 Google LLC
##
## Licensed under the Apache License, Version 2.0 (the "License");
## you may not use this file except in compliance with the License.
## You may obtain a copy of the License at
##
##     https://www.apache.org/licenses/LICENSE-2.0
##
## Unless required by applicable law or agreed to in writing, software
## distributed under the License is distributed on an "AS IS" BASIS,
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
## See the License for the specific language governing permissions and
## limitations under the License.
##
#' Get Synthetic Draws in Tidy Format for Single Treated Unit
#'
#' This internal helper function extracts synthetic draws from a Stan fit object,
#' combines them with observed outcome data, and returns a tidy data frame suitable
#' for further analysis or plotting. This function is specifically designed for
#' scenarios with a single treated unit.
#'
#' Draws of `y_sim` are labelled with the time values of `pre_data` and draws
#' of `y_pred` with those of `post_data`, by matching the draw index `idx` to
#' the row position in the respective data frame. The labelled draws are then
#' full-joined with the observed outcomes on the time column.
#'
#' @param fit A Stan fit object containing the model results.
#' @param pre_data A data frame with outcome data before the intervention.
#' @param post_data A data frame with outcome data after the intervention.
#' @param time The name of the time period variable (as a string or symbol).
#' @param outcome The name of the outcome variable (as a string or symbol).
#'
#' @return A data frame containing:
#'   * `draw`: The index of the synthetic draw (presumably as produced by
#'     `.get_par_long()`).
#'   * the time column (named after `time`): The time period.
#'   * `y_synth`: The synthetic outcome for the given draw and time period.
#'   * the outcome column (named after `outcome`): The observed outcome for the
#'     given time period.
#'
.get_synth_draws <- function(fit, pre_data, post_data, time, outcome) {
  # Crosswalk from row position (`idx`) to time value for one data frame.
  # seq_len() rather than 1:n() so that a zero-row data frame produces an
  # empty crosswalk instead of the invalid index vector c(1, 0).
  time_xwalk <- function(df) {
    df %>%
      dplyr::mutate(idx = seq_len(dplyr::n())) %>%
      dplyr::select(idx, !!time)
  }

  # In-sample synthetic draws, labelled with pre-intervention time values.
  # NOTE(review): assumes idx == k of y_sim corresponds to row k of pre_data.
  y_sim_draws <- .get_par_long(fit = fit, par = y_sim)
  y_hat <- dplyr::inner_join(y_sim_draws, time_xwalk(pre_data), by = "idx") %>%
    dplyr::rename(y_synth = y_sim)

  # Out-of-sample synthetic draws, labelled with post-intervention time values.
  y_pred_draws <- .get_par_long(fit = fit, par = y_pred)
  y_pred_hat <- dplyr::inner_join(
    y_pred_draws, time_xwalk(post_data),
    by = "idx"
  ) %>%
    dplyr::rename(y_synth = y_pred)

  y_synth <- dplyr::bind_rows(y_hat, y_pred_hat) %>%
    dplyr::select(-idx)

  # Observed outcomes for both periods, keyed by time.
  observed <- dplyr::bind_rows(
    dplyr::select(pre_data, !!outcome, !!time),
    dplyr::select(post_data, !!outcome, !!time)
  )

  dplyr::full_join(y_synth, observed, by = rlang::as_name(time))
}

# TODO(jvives): Unify get_synth_draws functions into one.
#' Get Synthetic Draws in Tidy Format for Single Treated Unit (Predictor Match Model)
#'
#' This internal helper function extracts synthetic draws from a Stan fit object
#' generated by a predictor match model. It combines these draws with observed
#' outcome data and returns a tidy data frame suitable for analysis or plotting.
#' It reads the predictor match model's `X1_sim` (pre-intervention) and
#' `X1_pred` (post-intervention) quantities instead of `y_sim` / `y_pred`.
#'
#' Draws are labelled with time values by matching the draw index `idx` to the
#' row position in `pre_data` / `post_data`, then full-joined with the observed
#' outcomes on the time column.
#'
#' @param fit A Stan fit object containing the model results.
#' @param pre_data A data frame with outcome data before the intervention.
#' @param post_data A data frame with outcome data after the intervention.
#' @param time The name of the time period variable (as a string or symbol).
#' @param outcome The name of the outcome variable (as a string or symbol).
#'
#' @return A data frame containing:
#'   * `draw`: The index of the synthetic draw (presumably as produced by
#'     `.get_par_long()`).
#'   * the time column (named after `time`): The time period.
#'   * `y_synth`: The synthetic outcome for the given draw and time period.
#'   * the outcome column (named after `outcome`): The observed outcome for the
#'     given time period.
.get_synth_draws_predictor_match <- function(fit, pre_data,
                                             post_data, time, outcome) {
  # Crosswalk from row position (`idx`) to time value for one data frame.
  # seq_len() rather than 1:n() so that a zero-row data frame produces an
  # empty crosswalk instead of the invalid index vector c(1, 0).
  time_xwalk <- function(df) {
    df %>%
      dplyr::mutate(idx = seq_len(dplyr::n())) %>%
      dplyr::select(idx, !!time)
  }

  # In-sample synthetic draws, labelled with pre-intervention time values.
  # NOTE(review): assumes idx == k of X1_sim corresponds to row k of pre_data.
  X1_sim_draws <- .get_par_long(fit = fit, par = X1_sim)
  y_hat <- dplyr::inner_join(
    X1_sim_draws, time_xwalk(pre_data),
    by = "idx"
  ) %>%
    dplyr::rename(y_synth = X1_sim)

  # Out-of-sample synthetic draws, labelled with post-intervention time values.
  X1_pred_draws <- .get_par_long(fit = fit, par = X1_pred)
  X1_pred_hat <- dplyr::inner_join(
    X1_pred_draws, time_xwalk(post_data),
    by = "idx"
  ) %>%
    dplyr::rename(y_synth = X1_pred)

  y_synth <- dplyr::bind_rows(y_hat, X1_pred_hat) %>%
    dplyr::select(-idx)

  # Observed outcomes for both periods, keyed by time.
  observed <- dplyr::bind_rows(
    dplyr::select(pre_data, !!outcome, !!time),
    dplyr::select(post_data, !!outcome, !!time)
  )

  dplyr::full_join(y_synth, observed, by = rlang::as_name(time))
}

#' Get Synthetic Draws in Tidy Format for Multiple Treated Units (3D Array)
#'
#' Internal helper for fits whose synthetic draws are stored as 3D arrays
#' (draw x treated unit x time). It extracts the pre-intervention (`y_sim`)
#' and post-intervention (`y_pred`) draws via `.get_draws3d()`, labels them
#' with unit ids and time values, and stacks them into one tidy data frame
#' suitable for analysis or plotting.
#'
#' @param fit A Stan fit object containing the model results.
#' @param data A data frame with the input data, including outcome, time, and
#'   unit identifier.
#' @param id The name of the variable in `data` that identifies units.
#' @param treated_ids A vector of identifiers for the treated units.
#' @param time The name of the time period variable.
#' @param outcome The name of the outcome variable.
#' @param intervention The intervention time point. Rows whose time is
#'   `< intervention` form the pre period and rows with time `>= intervention`
#'   form the post period. (The comparison is evaluated inside
#'   `dplyr::filter()`, so a column of `data` named `intervention` would take
#'   precedence; TODO confirm intended usage.)
#'
#' @return A data frame with pre-period rows followed by post-period rows,
#'   containing:
#'   * `draw`: The index of the synthetic draw.
#'   * `y_hat`: The synthetic outcome for the given draw, unit, and time period.
#'   * `id`: The identifier of the treated unit.
#'   * the time column: The time period.
#'
.get_synth_draws3d <- function(fit, data, id, treated_ids, time, outcome,
                               intervention) {
  # One long data frame per period, extracted in pre -> post order.
  per_period <- lapply(c("pre", "post"), function(which_period) {
    .get_draws3d(
      fit = fit,
      data = data,
      id = id,
      treated_ids = treated_ids,
      time = time,
      outcome = outcome,
      intervention = intervention,
      period = which_period
    )
  })

  dplyr::bind_rows(per_period)
}

# Extract one period's 3D draw array (draw x treated unit x time index) from
# `fit` and return it in long format, labelled with unit ids and time values.
#
# Arguments mirror .get_synth_draws3d(). `period` selects which Stan quantity
# is read and which rows of `data` supply the labels:
#   * "pre":  `y_sim`,  rows with time <  intervention
#   * "post": `y_pred`, rows with time >= intervention
# `time` and `id` may be given as strings, symbols or quosures of symbols.
#
# Returns a tibble with columns `draw`, `y_hat`, `id` (character, taken from
# the pivoted column names) and the time column.
#
# NOTE(review): labels are assigned positionally. The k-th treated unit (in
# order of first appearance in `data`) is mapped to i_idx == k, and the k-th
# time value in ascending order is mapped to t_idx == k. This assumes the
# Stan input was built with the same ordering; TODO confirm.
.get_draws3d <- function(fit, data, id, treated_ids, time, outcome,
                         intervention, period = c("pre", "post")) {
  period <- match.arg(period)

  # Normalise to symbols so they refer to data columns when unquoted inside
  # filter()/arrange(); a bare string would be treated as a constant there.
  time_sym <- rlang::sym(rlang::as_name(time))
  id_sym <- rlang::sym(rlang::as_name(id))

  if (period == "pre") {
    y_sim_draws <- rstan::extract(fit, pars = "y_sim")[[1]]
    data <- data %>%
      dplyr::filter(!!time_sym < intervention)
  } else {
    y_sim_draws <- rstan::extract(fit, pars = "y_pred")[[1]]
    data <- data %>%
      dplyr::filter(!!time_sym >= intervention)
  }

  # seq_len() avoids the c(1, 0) pitfall of seq(1, n) when a dimension is 0.
  draw_dims <- dim(y_sim_draws)
  dimnames(y_sim_draws) <- list(
    "draw" = seq_len(draw_dims[[1]]),
    "i_idx" = seq_len(draw_dims[[2]]),
    "t_idx" = seq_len(draw_dims[[3]])
  )

  y_sim_draws <- y_sim_draws %>%
    cubelyr::as.tbl_cube() %>%
    tibble::as_tibble() %>%
    dplyr::rename(y_hat = 4)

  # One row per time value, one column per treated unit. Rows are sorted by
  # the actual time column (a bare `arrange(time)` would look for a column
  # literally named "time" rather than the column named by `time`).
  wide_df_treated <- data %>%
    dplyr::filter(!!id_sym %in% treated_ids) %>%
    dplyr::select(
      !!id_sym,
      !!time_sym,
      !!outcome
    ) %>%
    tidyr::pivot_wider(names_from = !!id_sym, values_from = !!outcome) %>%
    dplyr::arrange(!!time_sym)

  # Unit crosswalk: column position (after dropping time) -> unit id.
  iXwalk <- wide_df_treated %>%
    dplyr::select(-!!time_sym) %>%
    colnames() %>%
    dplyr::tibble(id = .) %>%
    dplyr::mutate(i_idx = seq_len(dplyr::n()))

  # Time crosswalk: row position -> time value.
  tXwalk <- wide_df_treated %>%
    dplyr::select(!!time_sym) %>%
    dplyr::mutate(t_idx = seq_len(dplyr::n()))

  y_sim_draws %>%
    dplyr::inner_join(iXwalk, by = "i_idx") %>%
    dplyr::inner_join(tXwalk, by = "t_idx") %>%
    dplyr::select(-i_idx, -t_idx)
}

Try the bsynth package in your browser

Any scripts or data that you put into this service are public.

bsynth documentation built on June 25, 2024, 5:08 p.m.