R/A1-weights_med.R

Defines functions .get_m.smd .get_c.smd .get_smd.med .plot_balance.med .plot_med .compute_weights.med .check_plot.med .clean_weights.med .prep_med weights_med

Documented in .check_plot.med .clean_weights.med .compute_weights.med .get_smd.med .plot_balance.med .plot_med .prep_med weights_med

#### OK  weights_med ###########################################################


#' Estimate weights for natural (in)direct effects estimation
#'
#' Estimates inverse probability weights that form pseudo treated and control samples and cross-world weights that form (one or two) pseudo cross-world samples.
#' @inheritParams estimate_effects
#' @param a.c.form Formula for the P(A|C) model (the propensity score model).
#' @param a.cm.form Formula for the P(A|C,M) model.
#' @param max.stabilized.wt Max stabilized weight allowed. Larger weights are truncated to this level. Default is 30.
#' @param plot Whether to output weight distribution and balance plots. Defaults to TRUE.
#' @param c.order Order in which covariates are to be plotted. If not specify, use the order that appears in \code{a.c.form}.
#' @param m.order Order in which mediators are to be plotted. If not specify, use the order that appears in \code{a.cm.form}.
#' @param c.std Covariates whose mean differences are to be standardized in balance plot. Ignore if \code{plot==FALSE}.
#' @param m.std Mediators whose mean differences are to be standardized in balance plot. Ignore if \code{plot==FALSE}.
#' @return A list including\itemize{
#' \item{w.dat}{A data frame for the pseudo samples with estimated weights.}
#' \item{plot.wts}{A plot of the distributions of the weights.}
#' \item{plot.balance}{A plot of the balance in covariates and mediators of the pseudo samples.}
#' }
#' @family weighting schemes
#' @export

weights_med <- function(
    data,
    s.wt.var = NULL,
    cross.world = "10", # options: "10", "01", "both"
    a.c.form,
    a.cm.form,

    max.stabilized.wt = 30,

    plot = TRUE,
    c.order = NULL,
    m.order = NULL,
    c.std = NULL,
    m.std = NULL
) {

    # CLEAN INPUTS

    c.vars <- m.vars <- NULL

    .prep_med()


    # COMPUTE WEIGHTS

    w.dat <- .compute_weights.med(data              = data,
                                   cross.world       = cross.world,
                                   a.c.form          = a.c.form,
                                   a.cm.form         = a.cm.form,
                                   max.stabilized.wt = max.stabilized.wt)


    # MAKE PLOTS

    if (plot)  plots <- .plot_med(w.dat = w.dat,
                                  c.vars = c.vars,
                                  m.vars = m.vars,
                                  c.std = c.std,
                                  m.std = m.std)



    # OUTPUT

    if (plot)  return(list(w.dat = w.dat, plots  = plots))

    return(w.dat)

}



#### OK  .prep_med #############################################################

#' Internal functions: preparing inputs before main computing
#'
#' A set of internal functions to do the prep work for the main functions before the key computing takes place.
#' @name dot-prep
#' @keywords internal
NULL

#' @rdname dot-prep
#' @order 1

.prep_med <- function() {

    top.env <- parent.frame()

    .setup_data(top.env)

    .clean_cross.world(top.env)

    .clean_weights.med(top.env)

    if (top.env$plot) .check_plot.med(top.env)
}







#### OK  .clean_weights.med ####################################################

#' Internal functions: get treatment variable and clean inputs for weighting
#'
#' Called by \code{.prep_} functions
#' @inheritParams env-block
#' @name dot-clean_weights
#' @keywords internal
NULL



#' @rdname dot-clean_weights
#' @order 2
#' @details \code{.clean_weights.med()} is used by \code{.prep_med()}.

.clean_weights.med <- function(env) {

    if (!is.numeric(env$max.stabilized.wt))
        stop("max.stabilized.wt must be a numeric value.")


    a.c  <- env$a.c.form
    a.cm <- env$a.cm.form


    if (!formula(a.c)[[2]]==formula(a.cm)[[2]])
        stop("Treatment variable not the same in a.c.form and a.cm.form.")

    if (!all(all.vars(formula(a.c)[[3]]) %in%
             all.vars(formula(a.cm)[[3]])))
        stop("Some C variable(s) in a.c.form but not in a.cm.form")

    a.var  <- all.vars(formula(a.c)[[2]])
    c.vars <- all.vars(formula(a.c)[[3]])
    m.vars <- setdiff(all.vars(formula(a.cm)[[3]]),
                      all.vars(formula(a.c)[[3]]))

    stray.vars <- setdiff(c(a.var, c.vars, m.vars), names(env$data))

    if (length(stray.vars)>0)
        stop(paste("Variable(s)", paste(stray.vars, collapse = ", "), "not found in dataset."))

    if (!is_binary01(env$data[, a.var]))
        stop(paste("Treatment variable (", a.var, ") must be numeric and in binary 0/1 form."))

    env$data$.a <- env$data[, a.var]

    env$c.vars <- c.vars
    env$m.vars <- m.vars

}




#### OK  .check_plot.med #######################################################

#' Internal functions: check inputs for plotting
#'
#' Functions called by \code{.prep_} functions within \code{weights_} an \code{estimate_} functions.
#' @inheritParams env-block
#' @name dot-check_plot
#' @keywords internal
NULL

#' @rdname dot-check_plot
#' @order 1
#' @details \code{.check_plot.med()} is called by \code{.prep_med()}.

.check_plot.med <- function(env) {

    c.vars <- env$c.vars
    m.vars <- env$m.vars
    c.order <- env$c.order
    m.order <- env$m.order
    c.std <- env$c.std
    m.std <- env$m.std


    if (!is.null(c.order)) {

        if (!setequal(c.vars, c.order)) {
            warning("Variables in c.order do not match covariates from a.c.form. Ignoring c.order.")
        } else
            env$c.vars <- c.vars <- c.order
    }


    if (!is.null(m.order)) {

        if (!setequal(m.vars, m.order)) {
            warning("Variables in m.order do not match mediators obtained from the combination of a.c.form and a.cm.form. Ignoring m.order.")
        } else
            env$m.vars <- m.vars <- m.order
    }



    if (is.null(c.std) && is.null(m.std)) {

        maybe.cont <- sapply(c(c.vars, m.vars),
                             function(z) maybe_continuous(env$data[, z]))

        if (any(maybe.cont))
            message(paste("Consider whether the balance plot should use standardized mean differences for numeric covariate/mediators",
                          paste(c(c.vars, m.vars)[which(maybe.cont)],
                                collapse = ", "),
                          "(if they are continuous variables).",
                          "To turn off this message, specify c.std=\"\", m.std=\"\"."))

        return()
    }


    if (length(c.std==1) && c.std=="" &&
        length(m.std==1) && m.std=="")
        return()




    c.std <- setdiff(c.std, "")
    m.std <- setdiff(m.std, "")

    if (length(setdiff(c.std, c.vars))>0)
        stop("Variables specified in c.std are not all contained in model formula a.c.form.")

    if (length(setdiff(m.std, m.vars))>0)
        stop("Variables specified in m.std are not all part of M variables based on formulas a.c.form and a.cm.form.")



    vars.std <- c(c.std, m.std)

    ok.std <- sapply(vars.std, function(z) maybe_continuous(env$data[, z]))

    if (!all(ok.std))
        stop(paste("Check variable(s)",
                   paste(vars.std[which(!ok.std)], collapse = ", "),
                   "before proceeding. Only include continuous variables in c.std and m.std."))

}





#### OK  .compute_weights.med ##################################################

#' (For maintainer) Compute weights
#'
#' Internal functions that compute weights using cleaned inputs.
#' @param data Prepared internal dataset (e.g., result of \code{.pred_data()}).
#' @param cross.world Cleaned cross.world (e.g., from \code{.clean_cross.world()}).
#' @param a.c.form,a.cm.form Cleaned formulas (e.g., from \code{.clean_a.forms()}).
#' @param max.stabilized.wt A scalar.
#' @return A data frame for the pseudo samples with estimated weights, in stacked format.
#' @name dot-compute_weights
#' @keywords internal
NULL

#' @rdname dot-compute_weights
#' @order 1

.compute_weights.med <- function(
    data,
    cross.world,
    a.c.form,
    a.cm.form,
    max.stabilized.wt

) {

    a.c.fu <- glm(formula = a.c.form,
                  data    = data,
                  weights = data$.s.wt,
                  family  = quasibinomial)
    a.cm.fu <- glm(formula = a.cm.form,
                   data    = data,
                   weights = data$.s.wt,
                   family  = quasibinomial)



    max.wt <- lapply(list(control=0, treat=1), function(z) {
        max.stabilized.wt * (sum(data$.s.wt) / sum(data$.s.wt * (data$.a==z)))
    })




    p00 <- data[data$.a==0, ];  p00$.samp <- "p00"
    p11 <- data[data$.a==1, ];  p11$.samp <- "p11"

    p00$.w.wt <- 1 + exp( predict(a.c.fu, newdata = p00, type = "link"))
    p11$.w.wt <- 1 + exp(-predict(a.c.fu, newdata = p11, type = "link"))

    p00$.w.wt <- .trunc_right(p00$.w.wt, max.wt$control)
    p11$.w.wt <- .trunc_right(p11$.w.wt, max.wt$treat)

    p00$.f.wt <- p00$.s.wt * p00$.w.wt
    p11$.f.wt <- p11$.s.wt * p11$.w.wt

    out <- rbind(p00, p11)



    if ("10" %in% cross.world) {

        p10 <- data[data$.a==1, ];  p10$.samp <- "p10"

        p10$.w.wt <-
            exp(-predict(a.cm.fu, newdata = p10, type = "link")) *    # odds
            (1 + exp(predict(a.c.fu, newdata = p10, type = "link")))  # inv.prob

        p10$.w.wt <- .trunc_right(p10$.w.wt, max.wt$treat)

        p10$.f.wt <- p10$.s.wt * p10$.w.wt

        out <- rbind(out, p10)
    }


    if ("01" %in% cross.world) {

        p01 <- data[data$.a==0, ];  p01$.samp <- "p01"

        p01$.w.wt <-
            exp(predict(a.cm.fu, newdata = p01, type = "link")) *     # odds
            (1 + exp(-predict(a.c.fu, newdata = p01, type = "link"))) # inv.prob

        p01$.w.wt <- .trunc_right(p01$.w.wt, max.wt$control)

        p01$.f.wt <- p01$.s.wt * p01$.w.wt

        out <- rbind(out, p01)
    }

    out

}




#### OK  .plot_med #############################################################

#' (For maintainer) Plot weight distributions and mean balance
#'
#' @name dot-plot_w.dat
#' @keywords internal
NULL


#' @rdname dot-plot_w.dat
#' @order 1
#' @param w.dat Weighted data
#' @param c.vars Names of covariates, already cleaned.
#' @param m.vars Names of mediators, already cleaned.
#' @param c.std Names of covariates whose mean differences are to be standardized, already cleaned.
#' @param m.std Names of mediators whose mean differences are to be standardized, already cleaned.
#' @param key.balance Whether all balance components are crucial to the estimator using the weighting method, defaults to FALSE. See relevance in \bold{Value} section.

.plot_med <- function(w.dat,
                      c.vars,
                      m.vars,
                      c.std,
                      m.std,
                      key.balance = FALSE) {


    out <- .plot_wt_dist(w.dat)


    if (key.balance) { bal.name <- "key.balance"
    } else           { bal.name <- "balance"
    }

    out[[bal.name]] <- .plot_balance.med(w.dat = w.dat,
                                         c.vars = c.vars,
                                         m.vars = m.vars,
                                         c.std = c.std,
                                         m.std = m.std,
                                         key.balance = key.balance)

    out
}





#### OK  .plot_balance.med #####################################################

#' (For maintainer) Plot balance
#'
#' Internal functions that makes balance plots for weighted data.
#' @inheritParams .plot_med
#' @importFrom ggplot2 ggplot aes geom_vline geom_point scale_color_manual scale_shape_manual labs theme_bw facet_wrap xlim
#' @importFrom rlang .data
#' @return Plot of balance on the means of covariates and mediators between relevant pseudo samples and full sample. If \code{estimate_wtd==TRUE}, add "(anchor)" and "(for <effect>)" notes to plot labels to draw attention to how each balance matters to the estimator.
#' @name dot-plot_balance
#' @keywords internal
NULL

#' @rdname dot-plot_balance
#' @order 1

.plot_balance.med <- function(w.dat,
                              c.vars,
                              m.vars,
                              c.std,
                              m.std,
                              key.balance = FALSE) {

    smd.dat <- .get_smd.med(w.dat = w.dat,
                            c.vars = c.vars,
                            m.vars = m.vars,
                            c.std = c.std,
                            m.std = m.std)

    if (key.balance)
        smd.dat$contrast <-
        factor(smd.dat$contrast,
               levels = c("p11 - full", "p00 - full", "p11 - p00",
                          "p10 - full", "p11 - p10", "p10 - p00",
                          "p01 - full", "p11 - p01", "p01 - p00"),
               labels = c("p11 - full  (anchor)",
                          "p00 - full  (anchor)",
                          "p11 - p00  (for TE)",
                          "p10 - full  (anchor)",
                          "p11 - p10  (for NIE1/NDE0)",
                          "p10 - p00  (for NIE1/NDE0)",
                          "p01 - full  (anchor)",
                          "p11 - p01  (for NIE0/NDE1)",
                          "p01 - p00  (for NIE0/NDE1)"))

    ggplot(data = smd.dat,
           aes(x = .data$mean.diff,
               y = factor(.data$variable,
                          levels = rev(levels(.data$variable))))) +
        geom_vline(xintercept = 0,
                   color = "gray60") +
        geom_point(aes(color = .data$var.type,
                       shape = .data$contrast.type),
                   fill = "white",
                   size = 1.5,
                   stroke = .5) +
        labs(x = "differences in means",
             y = "") +
        scale_color_manual(name = "", values = c("black", "magenta")) +
        scale_shape_manual(name = "", values = c(21, 19)) +
        theme_bw() +
        xlim(min(c(-.3, smd.dat$mean.diff)),
             max(c( .3, smd.dat$mean.diff))) +
        facet_wrap(~.data$contrast, ncol = 3)

}



#### OK  .get_smd.med ##########################################################

#' (For maintainer) Compute  (standardized) mean differences
#'
#' Internal functions called by \code{.plot_balance.} functions to compute (standardized) mean differences to be plotted.
#' @return A data frame containing the (standardized) mean differences.
#' @keywords internal
#' @name dot-get_smd
#' @keywords internal
NULL

#' @rdname dot-get_smd
#' @order 1
#' @param w.dat Dataset for pseudo samples, already cleaned and dummy coded.
#' @param c.vars Names of covariates, checked and dummied.
#' @param m.vars Names of mediators, checked and dummied.
#' @param c.std Covariates to be standardized, already checked.
#' @param m.std Mediators to be standardized, already checked.

.get_smd.med <- function(w.dat,
                         c.vars,
                         m.vars,
                         c.std,
                         m.std) {

    tmp <- .dummies_2sets(data = w.dat,
                          columns1 = c.vars,
                          columns2 = m.vars)

    w.dat <- tmp$data
    c.vars <- tmp$columns1
    m.vars <- tmp$columns2;  rm(tmp)

    c.smd <- .get_c.smd(w.dat      = w.dat,
                        vars        = c.vars,
                        standardize = c.std)

    m.smd <- .get_m.smd(w.dat      = w.dat,
                        vars        = m.vars,
                        standardize = m.std)

    smd <- rbind(c.smd, m.smd)

    c.vars <- ifelse(c.vars %in% c.std, paste0("*", c.vars), c.vars)
    m.vars <- ifelse(m.vars %in% m.std, paste0("*", m.vars), m.vars)

    smd$variable <- factor(smd$variable, levels = c(c.vars, m.vars))

    smd
}




#### .get_c.smd & .get_m.smd ###################################################

#' Covariate or mediator (standardized) mean differences
#'
#' Internal functions called by \code{.plot_balance()} or \code{.plot_balance.Ypred()} to compute covariate/mediator (standardized) mean differences required for balance plotting.
#' @param w.dat Data for pseudo samples, e.g., output of \code{.compute_weights.med()} or \code{.compute_weights.te()} that has had categorical variables dummy coded.
#' @param vars Names of numeric variables on which to compute mean differences.
#' @param w.dat For \code{.get_smd.Ypred()}: data for (pseudo) subsamples, e.g., output of \code{.compute_weights.Ypred()} that has had categorical variables dummy coded.
#' @param m.vars For \code{.get_smd.Ypred()}: names of mediators.
#' @param standardize Names of variables for which the mean differences are to be standardized.
#' @return A data frame holding the (standardized) mean differences.
#' @return With \code{.get_c.smd()}, these are covariate mean differences between pseudo samples and full sample and across pseudo samples.
#' @return With \code{.get_m.smd()}, these are mediator mean differences between relevant pseudo samples (i.e., between \code{p10} and \code{p00} and/or between \code{p01} and \code{p11}).
#' @return With \code{.get_smd.Ypred()}, these are covariate and mediator mean differences between pseudo cross-world SUBsamples and the original subsamples they mimic (i.e., between \code{s10} and \code{s00} and/or between \code{s01} and \code{s11}).
#' @noRd


.get_c.smd <- function(w.dat,
                       vars,
                       standardize) {

    w.dat <- w.dat[, c(".samp", ".s.wt", ".f.wt", vars)]

    yes.p10 <- (sum(w.dat$.samp=="p10") > 0)
    yes.p01 <- (sum(w.dat$.samp=="p01") > 0)



    # separate pseudo samples and full sample
    p11 <- w.dat[w.dat$.samp=="p11", ]
    p00 <- w.dat[w.dat$.samp=="p00", ]

    if (yes.p10) p10 <- w.dat[w.dat$.samp=="p10", ]
    if (yes.p01) p01 <- w.dat[w.dat$.samp=="p01", ]

    rm(w.dat)

    full <- rbind(p11, p00)
    full$.f.wt <- full$.s.wt



    # denominator for standardization
    std.denom <- sapply(vars, function(z) {

        if (!z %in% standardize) return(1)

        .get_sd.pooled(variable = z, dat1 = p11, dat0 = p00)
    })



    # compute (standardized) mean differences
    smd <- lapply(list(unw = ".s.wt", wtd = ".f.wt"), function(w) {

        means.fu  <- sapply(vars, function(z) .wtd_mean(full[, z], full[, w]))
        means.p11 <- sapply(vars, function(z) .wtd_mean(p11[, z],  p11[, w]))
        means.p00 <- sapply(vars, function(z) .wtd_mean(p00[, z],  p00[, w]))

        diff <- cbind(p11.full = (means.p11 - means.fu) / std.denom,
                      p00.full = (means.p00 - means.fu) / std.denom,
                      p11.p00  = (means.p11 - means.p00) / std.denom)


        if (yes.p10) {

            means.p10 <- sapply(vars,
                                function(z) .wtd_mean(p10[, z], p10[, w]))

            diff <- cbind(diff,
                          p10.full = (means.p10 - means.fu) / std.denom,
                          p11.p10  = (means.p11 - means.p10) / std.denom,
                          p10.p00  = (means.p10 - means.p00) / std.denom)
            rm(means.p10)
        }

        if (yes.p01) {

            means.p01 <- sapply(vars,
                                function(z) .wtd_mean(p01[, z], p01[, w]))

            diff <- cbind(diff,
                          p01.full = (means.p01 - means.fu) / std.denom,
                          p11.p01  = (means.p11 - means.p01) / std.denom,
                          p01.p00  = (means.p01 - means.p00) / std.denom)
            rm(means.p01)
        }

        diff.names <- colnames(diff)

        diff <- data.frame(diff, row.names = NULL)
        diff$variable <- vars

        diff <- .reshape_gather(data     = diff,
                               columns  = diff.names,
                               key      = "contrast",
                               value    = "mean.diff",
                               wide.row = FALSE)

        if (w==".f.wt") diff$contrast.type <- "weighted"
        if (w==".s.wt") diff$contrast.type <- "pre-weighting"

        diff
    })

    smd$unw <- smd$unw[!smd$unw$contrast %in% c("p11.p10", "p01.p00"), ]

    smd <- do.call("rbind", smd)

    rownames(smd) <- NULL


    smd$contrast <-
        factor(smd$contrast,
               levels = c("p11.full", "p00.full", "p11.p00",
                          "p10.full", "p11.p10", "p10.p00",
                          "p01.full", "p11.p01", "p01.p00"),
               labels = c("p11 - full", "p00 - full", "p11 - p00",
                          "p10 - full", "p11 - p10",  "p10 - p00",
                          "p01 - full", "p11 - p01",  "p01 - p00"))

    smd$var.type <- "covariate"

    smd <- smd[,c("var.type", "variable", "contrast.type", "contrast", "mean.diff")]

    smd$variable <- ifelse(smd$variable %in% standardize,
                           paste0("*", smd$variable),
                           smd$variable)

    smd

}


#' @noRd
.get_m.smd <- function(w.dat,
                       vars,
                       standardize) {

    w.dat <- w.dat[, c(".samp", ".s.wt", ".f.wt", vars)]

    yes.p10 <- (sum(w.dat$.samp=="p10") > 0)
    yes.p01 <- (sum(w.dat$.samp=="p01") > 0)


    # separate pseudo samples
    p00 <- w.dat[w.dat$.samp=="p00", ]
    p11 <- w.dat[w.dat$.samp=="p11", ]

    if (yes.p10) p10 <- w.dat[w.dat$.samp=="p10", ]
    if (yes.p01) p01 <- w.dat[w.dat$.samp=="p01", ]

    rm(w.dat)


    # denominator for standardization
    std.denom <- sapply(vars, function(z) {

        if (!z %in% standardize) return(1)

        .get_sd.pooled(variable = z, dat1 = p11, dat0 = p00)
    })



    # compute mean differences
    smd <- lapply(list(unw = ".s.wt", wtd = ".f.wt"), function(w) {

        diff <- NULL

        if (yes.p10) {

            means.p10 <- sapply(vars,
                                function(z) .wtd_mean(p10[, z], p10[, w]))
            means.p00 <- sapply(vars,
                                function(z) .wtd_mean(p00[, z],  p00[, w]))

            diff <- cbind(diff,
                          p10.p00 = (means.p10 - means.p00) / std.denom)
        }

        if (yes.p01) {

            means.p11 <- sapply(vars,
                                function(z) .wtd_mean(p11[, z],  p11[, w]))
            means.p01 <- sapply(vars,
                                function(z) .wtd_mean(p01[, z], p01[, w]))

            diff <- cbind(diff,
                          p11.p01 = (means.p11 - means.p01) / std.denom)
        }

        diff.names <- colnames(diff)

        diff <- data.frame(diff, row.names = NULL)
        diff$variable <- vars

        diff <- .reshape_gather(data     = diff,
                               columns  = diff.names,
                               key      = "contrast",
                               value    = "mean.diff",
                               wide.row = FALSE)

        if (w==".f.wt") diff$contrast.type <- "weighted"
        if (w==".s.wt") diff$contrast.type <- "pre-weighting"

        diff
    })

    smd <- do.call("rbind", smd)

    rownames(smd) <- NULL


    smd$contrast <-
        factor(smd$contrast,
               levels = c("p10.p00", "p11.p01"),
               labels = c("p10 - p00", "p11 - p01"))

    smd$var.type <- "mediator"

    smd <- smd[,c("var.type", "variable", "contrast.type", "contrast", "mean.diff")]

    smd$variable <- ifelse(smd$variable %in% standardize,
                           paste0("*", smd$variable),
                           smd$variable)

    smd

}




#### OK  param env #############################################################

#' @param env An environment
#' @name env-block
#' @keywords internal
NULL
trangnguyen74/tnMediation documentation built on May 3, 2023, 6:58 a.m.