R/trt_effects.R

Defines functions mutau_cor treatment_effect

Documented in mutau_cor treatment_effect

#' Average treatment effect in a baggr model
#'
#' Pulls the posterior draws of the hyper-mean and hyper-SD of the treatment
#' effect out of the fitted Stan object stored in `bg$fit`. Which Stan
#' parameters are read depends on `bg$model` and `bg$pooling`.
#'
#' @param bg a [baggr] model
#' @param summary logical; if TRUE, both elements of the result are passed
#'                through the package's `mint()` helper (called with
#'                `median = TRUE, sd = TRUE`) instead of being returned as
#'                raw draws
#' @param interval uncertainty interval width (numeric between 0 and 1),
#'                 forwarded to `mint()` when `summary = TRUE`
#' @param transform a transformation to apply to the mean effect, should be an
#'                  R function (this is commonly used when calling
#'                  `treatment_effect` from other plotting or printing
#'                  functions); when supplied, `sigma_tau` is returned as `NA`
#' @param message logical; use to disable the message shown when the model
#'                was fitted with `pooling = "none"`
#' @return A list with two elements, `tau` (mean effect) and `sigma_tau` (SD),
#'         holding MCMC samples: vectors for single-effect models, matrices
#'         with one column per effect for the `"quantiles"` and `"sslab"`
#'         models. With `pooling = "none"` both elements are a single `NA`;
#'         with `pooling = "full"` (Rubin / mu-tau / logit family)
#'         `sigma_tau` is `0`. If `summary = TRUE`, both elements are
#'         summarised according to `interval`.
#' @export
#' @importFrom rstan extract
treatment_effect <- function(bg, summary = FALSE,
                             transform = NULL, interval = .95,
                             message = TRUE) {
  check_if_baggr(bg)

  # Without pooling there is no hyper-parameter to report: bail out early.
  if (bg$pooling == "none") {
    if (message) {
      message("There is no treatment effect estimated when pooling = 'none'.")
    }
    return(list(tau = NA_real_, sigma_tau = NA_real_))
  }

  if (bg$model %in% c("rubin", "mutau", "mutau_full", "logit", "rubin_full")) {
    joint_mu_tau <- bg$model %in% c("mutau", "mutau_full")
    flatten_draws <- bg$model %in% c("rubin", "logit")

    tau_draws <- rstan::extract(bg$fit, pars = "mu")[[1]]
    if (flatten_draws) {
      tau_draws <- c(tau_draws)
    }
    if (joint_mu_tau) {
      # the effect on tau is stored at position [1, 2] of each draw of mu
      tau_draws <- tau_draws[, 1, 2]
    }

    if (bg$pooling == "partial") {
      # name of the Stan parameter carrying the hyper-SD of the effect
      sd_par <- if (joint_mu_tau) "hypersd[1,2]" else "tau"
      sigma_draws <- rstan::extract(bg$fit, pars = sd_par)[[1]]
      if (flatten_draws) {
        sigma_draws <- c(sigma_draws)
      }
    } else if (bg$pooling == "full") {
      # by convention the hyper-SD is reported as 0 under full pooling
      sigma_draws <- 0
    }
  } else if (bg$model == "quantiles") {
    # one column per quantile
    tau_draws <- as.matrix(bg$fit, "beta_1")
    # keep only the diagonal of each draw of Sigma_1 ...
    cov_draws <- rstan::extract(bg$fit, "Sigma_1")[[1]]
    var_draws <- t(apply(cov_draws, 1, diag))
    # ... which holds Var(), not SD(), in the model with correlation
    sigma_draws <- sqrt(var_draws)
  } else if (bg$model == "sslab") {
    tau_pars <- c("tau[1]", "tau[2]",
                  "sigma_TE[1]", "sigma_TE[2]",
                  "kappa[1,1,2]", "kappa[1,2,2]")
    sd_pars <- c("hypersd_tau[1]", "hypersd_tau[2]",
                 "hypersd_sigma_TE[1]", "hypersd_sigma_TE[2]",
                 "hypersd_kappa[1,1,2]", "hypersd_kappa[1,2,2]")
    tau_draws <- as.matrix(bg$fit, tau_pars)
    sigma_draws <- as.matrix(bg$fit, sd_pars)
  } else {
    stop("Can't calculate treatment effect for this model.")
  }

  # Multi-effect models: label the columns with the effect names.
  if (length(bg$effects) > 1) {
    effect_names <- bg$effects
    colnames(sigma_draws) <- effect_names
    colnames(tau_draws) <- effect_names
  }

  if (!is.null(transform)) {
    tau_draws <- do.call(transform, list(tau_draws))
    # by convention the SD is set to NA so that nobody transforms the mean
    # and then combines it with an untransformed SD by accident
    sigma_draws <- NA
  }

  if (summary) {
    tau_draws <- mint(tau_draws, int = interval, median = TRUE, sd = TRUE)
    sigma_draws <- mint(sigma_draws, int = interval, median = TRUE, sd = TRUE)
  }

  list(tau = tau_draws, sigma_tau = sigma_draws)
}



#' Correlation between mu and tau in a baggr model
#'
#' Reads the posterior draws of `L_Omega` from `bg$fit` and returns the
#' second column of the resulting draws matrix. A warning is raised when the
#' first column is not equal to 1 (up to floating-point tolerance) in every
#' draw, or contains missing values.
#'
#' @param bg  a [baggr] model where `model = "mutau"`
#' @param summary logical; if TRUE, the draws are passed through the package's
#'                `mint()` helper (called with `median = TRUE, sd = TRUE`)
#'                instead of being returned as raw draws
#' @param interval uncertainty interval width (numeric between 0 and 1),
#'                 forwarded to `mint()` when `summary = TRUE`
#' @return a vector of values: one correlation per posterior draw, or their
#'         summary when `summary = TRUE`
#' @export
mutau_cor <- function(bg,
                      summary = FALSE,
                      interval = 0.95) {
  # One row per posterior draw. The code relies on column 1 being entry (1,1)
  # and column 2 being entry (2,1) of the lower-triangular factor L_Omega.
  chol_draws <- as.matrix(bg$fit, "L_Omega")

  # We want entry (2,1) of L %*% t(L), which for a lower-triangular L is
  # simply (1,1) * (2,1); and (1,1) should be 1 in every draw. Compare with a
  # tolerance rather than `==`, and treat NA draws as a failed check instead
  # of letting `if()` stop on a missing condition.
  unit_tol <- sqrt(.Machine$double.eps)
  diag_is_unit <- all(abs(chol_draws[, 1] - 1) < unit_tol)
  if (!isTRUE(diag_is_unit)) {
    warning("Error with correlation matrix in the mu&tau model. Inspect L_Omega parameters.")
  }

  cor_draws <- chol_draws[, 2]
  if (summary) {
    return(mint(cor_draws, int = interval, median = TRUE, sd = TRUE))
  }
  cor_draws
}

Try the baggr package in your browser

Any scripts or data that you put into this service are public.

baggr documentation built on March 31, 2023, 10:02 p.m.