R/psrwe_powerprior_watt.R

Defines functions get_stan_data_wattcon rwe_ana_con rwe_ana_bin rwe_ana get_stan_data_watt psrwe_powerp_watt

Documented in psrwe_powerp_watt

#' Get posterior samples based on PS-power prior approach (WATT)
#'
#' Draw posterior samples of the parameters of interest for the PS-power prior
#' approach with weights of ATT (WATT)
#'
#'
#' @param dta_psbor A class \code{PSRWE_BOR} object generated by
#'     \code{\link{psrwe_borrow}}. Only a single stratum
#'     (\code{nstrata == 1}) is supported.
#' @param v_outcome Column name corresponding to the outcome.
#' @param outcome_type Type of outcomes: \code{continuous} or \code{binary}.
#' @param mcmc_method MCMC sampling via either \code{rstan}, \code{analytic},
#'     or \code{wattcon}. \code{wattcon} is for continuous outcomes only.
#' @param tau0_method Method for estimating SD0 via either \code{Wang2019} or
#'     \code{weighted} for continuous outcomes only.
#' @param ipw_method Method for IPW via either \code{Heng.Li} or
#'     \code{Xi.Ada.Wang}. \code{Xi.Ada.Wang} is for binary outcomes with
#'     \code{mcmc_method = "rstan"} only.
#' @param seed Random seed. When given, the global random number generator
#'     state found on entry is restored when the function exits, including
#'     when it exits with an error.
#' @param ... extra parameters for calling function \code{\link{rwe_stan}}
#'     (\code{rstan} and \code{wattcon}) or the internal analytic sampler
#'     (\code{analytic}), e.g. \code{n_resample}.
#'
#' @return A class \code{PSRWE_RST} list with the following objects
#'
#' \describe{
#'     \item{Observed}{Observed mean and SD of the outcome by group, arm
#'     and stratum}
#'
#'     \item{Control}{A list of estimated mean and SD of the outcome by stratum
#'     in the control arm}
#'
#'     \item{Treatment}{A list of estimated mean and SD of the outcome by
#'     stratum in the treatment arm for RCT}
#'
#'     \item{Effect}{A list of estimated mean and SD of the treatment effect by
#'     stratum for RCT}
#'
#'     \item{Borrow}{Borrowing information from \code{dta_psbor}}
#'
#'     \item{stan_rst}{Result from STAN sampling}
#' }
#'
#' @examples
#'
#' \donttest{
#' data(ex_dta)
#' dta_ps <- psrwe_est(ex_dta,
#'        v_covs = paste("V", 1:7, sep = ""),
#'        v_grp = "Group",
#'        cur_grp_level = "current",
#'        nstrata = 1)
#' ps_borrow <- psrwe_borrow(total_borrow = 30, dta_ps)
#' rst <- psrwe_powerp_watt(ps_borrow, v_outcome = "Y_Bin", seed = 123)}
#'
#' @export
#'
psrwe_powerp_watt <- function(dta_psbor, v_outcome = "Y",
                              outcome_type = c("continuous", "binary"),
                              mcmc_method = c("rstan", "analytic", "wattcon"),
                              tau0_method = c("Wang2019", "weighted"),
                              ipw_method = c("Heng.Li", "Xi.Ada.Wang"),
                              ..., seed = NULL) {

    ## check
    stopifnot(inherits(dta_psbor,
                       what = get_rwe_class("PSDIST")))

    type        <- match.arg(outcome_type)
    stopifnot(v_outcome %in% colnames(dta_psbor$data))

    mcmc_method <- match.arg(mcmc_method)
    stopifnot(dta_psbor$nstrata == 1)
    tau0_method <- match.arg(tau0_method)
    ipw_method  <- match.arg(ipw_method)

    if (mcmc_method == "wattcon" && type != "continuous") {
        stop("The 'wattcon' is for continuous outcomes only.")
    }

    ## compare against the matched scalar `type`; the raw `outcome_type`
    ## argument is still the full choice vector when left at its default
    if (ipw_method == "Xi.Ada.Wang" &&
        (type != "binary" || mcmc_method != "rstan")) {
        stop("The 'Xi.Ada.Wang' is for binary outcomes and rstan only.")
    }

    ## save the seed from global if any then set random seed; the restore
    ## is registered with on.exit so it also runs when sampling fails
    if (!is.null(seed)) {
        old_seed <- NULL
        if (exists(".Random.seed", envir = .GlobalEnv)) {
            old_seed <- get(".Random.seed", envir = .GlobalEnv)
        }
        on.exit({
            if (!is.null(old_seed)) {
                assign(".Random.seed", old_seed, envir = .GlobalEnv)
            } else if (exists(".Random.seed", envir = .GlobalEnv,
                              inherits = FALSE)) {
                ## no seed existed before the call: remove the one set here
                rm(list = ".Random.seed", envir = .GlobalEnv)
            }
        }, add = TRUE)
        set.seed(seed)
    }

    ## observed
    rst_obs <- get_observed(dta_psbor$data, v_outcome)

    ## prepare stan data and choose the model
    if (mcmc_method %in% c("rstan", "analytic")) {
        lst_dta <- get_stan_data_watt(dta_psbor, v_outcome,
                                      tau0_method = tau0_method,
                                      ipw_method = ipw_method)

        stan_mdl <- if (type == "continuous") {
            "powerps"
        } else {
            "powerpsbinary"
        }
    } else {
        ## get_stan_data_wattcon() takes no ipw_method argument; the checks
        ## above only allow "Xi.Ada.Wang" together with rstan
        lst_dta  <- get_stan_data_wattcon(dta_psbor, v_outcome)
        stan_mdl <- "powerps_wattcon"
    }

    ## run stan or get from analytical solution
    is_rct     <- dta_psbor$is_rct
    trt_post   <- NULL
    trt_thetas <- NULL

    if (mcmc_method %in% c("rstan", "wattcon")) {
        ctl_post   <- rwe_stan(lst_data = lst_dta$ctl,
                               stan_mdl = stan_mdl,
                               ...)
        ctl_thetas <- extract(ctl_post, "thetas")$thetas
        ctl_thetas <- matrix(ctl_thetas, ncol = 1)

        if (is_rct) {
            trt_post <- rwe_stan(lst_data = lst_dta$trt,
                                 stan_mdl = stan_mdl, ...)
            trt_thetas <- extract(trt_post, "thetas")$thetas
            trt_thetas <- matrix(trt_thetas, ncol = 1)
        }
    } else {
        ctl_post   <- rwe_ana(lst_data = lst_dta$ctl,
                              outcome_type = type, ...)
        ctl_thetas <- ctl_post$thetas

        if (is_rct) {
            trt_post <- rwe_ana(lst_data = lst_dta$trt,
                                outcome_type = type, ...)
            trt_thetas <- trt_post$thetas
        }
    }

    ## summary
    rst_trt    <- NULL
    rst_effect <- NULL
    if (is_rct) {
        rst_trt    <- get_post_theta(trt_thetas, dta_psbor$Borrow$N_Cur_TRT)
        rst_effect <- get_post_theta(trt_thetas - ctl_thetas,
                                     dta_psbor$Borrow$N_Current)
        n_ctl      <- dta_psbor$Borrow$N_Cur_CTL
    } else {
        n_ctl      <- dta_psbor$Borrow$N_Current
    }
    rst_ctl <- get_post_theta(ctl_thetas, n_ctl)

    ## return
    rst <-  list(stan_rst = list(ctl_post = ctl_post,
                                 trt_post = trt_post),
                 Observed  = rst_obs,
                 Control   = rst_ctl,
                 Treatment = rst_trt,
                 Effect    = rst_effect,
                 Borrow    = dta_psbor$Borrow,
                 Total_borrow  = dta_psbor$Total_borrow,
                 Method        = "ps_pp",
                 Method_weight = "WATT",
                 Outcome_type  = type,
                 Prior_type    = "fixed",
                 MCMC_method   = mcmc_method,
                 tau0_method   = tau0_method,
                 is_rct        = is_rct)

    class(rst) <- get_rwe_class("ANARST")
    rst
}


#' Get data for STAN watt
#'
#' Builds the data lists handed to the power prior models: one for the
#' control arm and, when \code{dta_psbor$is_rct} is true, one for the
#' treatment arm. External outcomes are summarised with weights
#' \code{e / (1 - e)} computed from the \code{_ps_} column; under
#' \code{ipw_method = "Xi.Ada.Wang"} those weights are multiplied by
#' \code{e} once more.
#'
#' @param dta_psbor Borrowing object; \code{data}, \code{is_rct},
#'     \code{Total_borrow} and \code{Borrow$Proportion} are read.
#' @param v_outcome Column name of the outcome.
#' @param tau0_method \code{"Wang2019"} or \code{"weighted"}; selects how
#'     \code{SD0} is computed.
#' @param ipw_method \code{"Heng.Li"} or \code{"Xi.Ada.Wang"}.
#'
#' @return A list with elements \code{ctl} and \code{trt} (\code{NULL} when
#'     not an RCT).
#'
#' @noRd
#'
get_stan_data_watt <- function(dta_psbor, v_outcome,
                               tau0_method = "Wang2019",
                               ipw_method = "Heng.Li") {
    ## summary statistics of one arm within stratum `inx`
    arm_summary <- function(inx, y_cur, y_ext = NULL, w_ext = NULL,
                            tau0_method = "Wang2019") {
        if (is.null(y_ext)) {
            ## nothing to borrow from
            ext_part <- c(N0    = 0,
                          YBAR0 = 0,
                          SD0   = 0)
        } else {
            if (is.null(w_ext)) {
                w_ext <- rep(1, length(y_ext))
            }

            tau0 <- tau0_method[1]
            if (tau0 == "Wang2019") {
                sd_ext <- sd(y_ext)
            } else if (tau0 == "weighted") {
                ## sd of Y_i scaled by sqrt(sum w_i^2) / sum w_i, then by
                ## sqrt(n0) to cancel out n0 in stan "powerps"
                sd_ext <- sd(y_ext) *
                          sqrt(sum(w_ext^2)) /
                          sum(w_ext) *
                          sqrt(length(y_ext))
            } else {
                stop("The tau0_method is not implemented.")
            }

            ext_part <- c(N0    = length(y_ext),
                          YBAR0 = sum(y_ext * w_ext) / sum(w_ext),
                          SD0   = sd_ext)
        }

        list(stan_d = c(N1    = length(y_cur),
                        YBAR1 = mean(y_cur),
                        YSUM1 = sum(y_cur),
                        ext_part),
             y1     = y_cur,
             inx1   = rep(inx, length(y_cur)))
    }

    is_rct  <- dta_psbor$is_rct
    data    <- dta_psbor$data
    data    <- data[!is.na(data[["_strata_"]]), ]

    strata  <- levels(data[["_strata_"]])
    nstrata <- length(strata)

    ## per-stratum summaries for both arms
    by_stratum <- lapply(seq_len(nstrata), function(s) {
        y_arms <- get_cur_d(data, strata[s], v_outcome)
        ps_ext <- get_cur_d(data, strata[s], "_ps_")$cur_d0
        w_att  <- ps_ext / (1 - ps_ext)

        ### overwrite watt for ipw_method of Xi.Ada.Wang
        if (ipw_method[1] == "Xi.Ada.Wang") {
            w_att <- ps_ext * w_att
        }

        ctl_part <- arm_summary(s, y_arms$cur_d1, y_arms$cur_d0, w_att,
                                tau0_method = tau0_method[1])
        trt_part <- NULL
        if (is_rct) {
            trt_part <- arm_summary(s, y_arms$cur_d1t,
                                    tau0_method = tau0_method[1])
        }

        list(ctl = ctl_part, trt = trt_part)
    })

    ## pull one field of one arm out of every stratum
    gather <- function(arm, field) {
        lapply(by_stratum, function(cur) cur[[arm]][[field]])
    }

    ## assemble the list expected by the stan models
    build_stan_list <- function(arm, n_borrow) {
        stat_mat <- do.call(rbind, gather(arm, "stan_d"))
        y_all    <- do.call(c, gather(arm, "y1"))
        inx_all  <- do.call(c, gather(arm, "inx1"))

        list(S     = nstrata,
             A     = n_borrow,
             RS    = as.array(dta_psbor$Borrow$Proportion),
             FIXVS = 1,
             N0    = as.array(stat_mat[, "N0"]),
             N1    = as.array(stat_mat[, "N1"]),
             YBAR0 = as.array(stat_mat[, "YBAR0"]),
             SD0   = as.array(stat_mat[, "SD0"]),
             TN1   = length(y_all),
             Y1    = y_all,
             INX1  = inx_all,
             YBAR1 = as.array(stat_mat[, "YBAR1"]),
             YSUM1 = as.array(stat_mat[, "YSUM1"]))
    }

    ctl_lst_data <- build_stan_list("ctl", dta_psbor$Total_borrow)

    ## the treatment arm never borrows (A = 0)
    trt_lst_data <- NULL
    if (is_rct) {
        trt_lst_data <- build_stan_list("trt", 0)
    }

    list(ctl = ctl_lst_data,
         trt = trt_lst_data)
}


#' RWE analytical posterior
#'
#' Dispatches to the analytic posterior sampler matching the outcome type.
#'
#' @param lst_data Data list passed on to the sampler.
#' @param outcome_type \code{"binary"} or \code{"continuous"}; only the first
#'     element is used.
#' @param ... Passed on to \code{rwe_ana_bin} or \code{rwe_ana_con}.
#'
#' @return The list returned by the selected sampler.
#'
#' @noRd
#'
rwe_ana <- function(lst_data, outcome_type, ...) {
    kind <- outcome_type[1]

    if (kind == "binary") {
        return(rwe_ana_bin(lst_data, ...))
    }

    if (kind == "continuous") {
        return(rwe_ana_con(lst_data, ...))
    }

    stop("The outcome_type is not implemented.")
}


#' RWE binary analytical posterior
#'
#' Beta posterior per stratum for a binary outcome under a power prior with
#' fixed power parameter. The power parameter of a stratum is
#' \code{A * RS / N0} capped at 1, and is set to 1 where \code{N0} is 0.
#'
#' @param lst_data List with elements \code{S} (number of strata), \code{A},
#'     \code{RS}, \code{N0}, \code{N1}, \code{YBAR0} and \code{YBAR1}; the
#'     per-stratum elements have one entry per stratum.
#' @param n_resample Number of posterior draws per stratum.
#' @param beta_a_init,beta_b_init Parameters of the initial Beta prior.
#' @param ... Ignored.
#'
#' @return A list with \code{post_dsn} (prior and posterior Beta parameters,
#'     posterior mean and variance) and \code{thetas}, an
#'     \code{n_resample} by \code{S} matrix of posterior draws.
#'
#' @noRd
#'
rwe_ana_bin <- function(lst_data,
                        n_resample = 4000,
                        beta_a_init = 0.01,
                        beta_b_init = 0.01,
                        ...) {
    ns <- lst_data$S
    n0 <- lst_data$N0
    n1 <- lst_data$N1
    p0 <- lst_data$YBAR0  # watt
    p1 <- lst_data$YBAR1

    ## power parameter, decided stratum by stratum so that a vector of N0
    ## (more than one stratum) is never used as an `if` condition
    has_ext <- which(n0 > 0)
    alpha0  <- rep(1, ns)
    alpha0[has_ext] <- pmin((lst_data$A * lst_data$RS / n0)[has_ext], 1)

    ## posterior
    beta_a0   <- alpha0 * n0 * p0 + beta_a_init
    beta_b0   <- alpha0 * n0 * (1 - p0) + beta_b_init
    beta_a    <- n1 * p1 + beta_a0
    beta_b    <- n1 * (1 - p1) + beta_b0
    post_mean <- beta_a / (beta_a + beta_b)
    post_var  <- beta_a * beta_b / ((beta_a + beta_b)^2 *
                                    (beta_a + beta_b + 1))
    post_dsn  <- list(beta_a0 = beta_a0,
                      beta_b0 = beta_b0,
                      beta_a  = beta_a,
                      beta_b  = beta_b,
                      mean    = post_mean,
                      var     = post_var)

    ## posterior samples: parameters recycle down the rows (strata) of the
    ## ns x n_resample matrix, which is then transposed
    thetas <- matrix(rbeta(ns * n_resample, post_dsn$beta_a, post_dsn$beta_b),
                     nrow = ns, ncol = n_resample)
    thetas <- t(thetas)

    ## return
    list(post_dsn = post_dsn,
         thetas   = thetas)
}


#' RWE continuous analytical posterior
#'
#' Normal posterior per stratum for a continuous outcome under a power prior
#' with fixed power parameter. The power parameter of a stratum is
#' \code{A * RS / N0} capped at 1; strata with \code{N0} equal to 0 use the
#' current data only. The current-data SD is \code{sd(Y1)}, computed over all
#' of \code{Y1}.
#'
#' @param lst_data List with elements \code{S} (number of strata), \code{A},
#'     \code{RS}, \code{N0}, \code{N1}, \code{YBAR0}, \code{YBAR1},
#'     \code{SD0} and \code{Y1}; the per-stratum elements have one entry per
#'     stratum.
#' @param n_resample Number of posterior draws per stratum.
#' @param ... Ignored.
#'
#' @return A list with \code{post_dsn} (inputs, posterior mean and variance)
#'     and \code{thetas}, an \code{n_resample} by \code{S} matrix of
#'     posterior draws.
#'
#' @noRd
#'
rwe_ana_con <- function(lst_data,
                        n_resample = 4000,
                        ...) {
    ns     <- lst_data$S
    n0     <- lst_data$N0
    n1     <- lst_data$N1
    ybar0  <- lst_data$YBAR0  # watt
    ybar1  <- lst_data$YBAR1
    sigma0 <- lst_data$SD0    # watt
    sigma1 <- sd(lst_data$Y1)

    ## power parameter, decided stratum by stratum so that a vector of N0
    ## (more than one stratum) is never used as an `if` condition
    has_ext <- which(n0 > 0)
    alpha0  <- rep(1, ns)
    alpha0[has_ext] <- pmin((lst_data$A * lst_data$RS / n0)[has_ext], 1)

    ## precision of the current data, and precision and precision-weighted
    ## mean contributed by the external data; both external terms stay 0
    ## where N0 is 0, which also avoids 0 / 0 when SD0 is 0 there
    prec1  <- n1 / sigma1^2
    prec0  <- rep(0, ns)
    wsum0  <- rep(0, ns)
    prec0[has_ext] <- (alpha0 * n0 / sigma0^2)[has_ext]
    wsum0[has_ext] <- (alpha0 * n0 / sigma0^2 * ybar0)[has_ext]

    ## posterior
    post_var  <- 1 / (prec1 + prec0)
    post_mean <- (prec1 * ybar1 + wsum0) * post_var

    post_dsn  <- list(n0     = n0,
                      n1     = n1,
                      ybar0  = ybar0,
                      ybar1  = ybar1,
                      sigma0 = sigma0,
                      sigma1 = sigma1,
                      mean   = post_mean,
                      var    = post_var)

    ## posterior samples: parameters recycle down the rows (strata) of the
    ## ns x n_resample matrix, which is then transposed
    thetas <- matrix(rnorm(ns * n_resample, post_dsn$mean, sqrt(post_dsn$var)),
                     nrow = ns, ncol = n_resample)
    thetas <- t(thetas)

    ## return
    list(post_dsn = post_dsn,
         thetas   = thetas)
}


#' Get data for STAN watt for continuous outcomes
#'
#' Builds the data lists for the \code{wattcon} model from the first stratum
#' level only. External outcomes enter through \code{Y0Tilde}, their mean
#' under weights \code{e / (1 - e)} normalised to sum to one, where \code{e}
#' comes from the \code{_ps_} column.
#'
#' @param dta_psbor Borrowing object; \code{data}, \code{is_rct} and
#'     \code{Total_borrow} are read.
#' @param v_outcome Column name of the outcome.
#'
#' @return A list with elements \code{ctl} and \code{trt} (\code{NULL} when
#'     not an RCT).
#'
#' @noRd
#'
get_stan_data_wattcon <- function(dta_psbor, v_outcome) {
    is_rct <- dta_psbor$is_rct

    ## drop rows without a stratum and keep the first stratum level
    data   <- dta_psbor$data
    data   <- data[!is.na(data[["_strata_"]]), ]
    strata <- levels(data[["_strata_"]])

    ## outcomes by arm, and propensity scores of the external subjects
    y_arms <- get_cur_d(data, strata[1], v_outcome)
    ps_ext <- get_cur_d(data, strata[1], "_ps_")$cur_d0

    ## ATT weights, normalised to sum to one
    w_att  <- ps_ext / (1 - ps_ext)
    w_norm <- w_att / sum(w_att)

    y_cur_ctl <- y_arms$cur_d1
    y_ext     <- y_arms$cur_d0

    ## N0, Y0 and A_WATT_DI are not part of the data lists
    ctl_lst_data <- list(A       = dta_psbor$Total_borrow,
                         Y0Tilde = sum(w_norm * y_ext),
                         SD0     = sd(y_ext),
                         N1      = length(y_cur_ctl),
                         Y1      = as.array(y_cur_ctl))

    ## the treatment arm never borrows (A = 0)
    trt_lst_data <- NULL
    if (is_rct) {
        y_cur_trt    <- y_arms$cur_d1t
        trt_lst_data <- list(A       = 0,
                             Y0Tilde = 0,
                             SD0     = 0,
                             N1      = length(y_cur_trt),
                             Y1      = as.array(y_cur_trt))
    }

    list(ctl = ctl_lst_data,
         trt = trt_lst_data)
}
olssol/psrwe documentation built on July 17, 2024, 4:06 p.m.