R/plot_me.R

Defines functions plot_me

Documented in plot_me

#' @title Visualizing Marginal Effects
#' @description
#' \code{plot_me} produces a heatmap of the marginal effects of variable \code{x} plotted against the combinations of \code{x} and \code{over} and
#' adds a scatterplot representing the joint distribution of these two variables in the given sample. The size of the markers represents the number of observations.
#' @param x a character string representing the name of the main variable of interest. Marginal effects will be computed for this variable.
#' @param over a character string representing the name of the conditioning variable. DAME will be computed for the bins along the range of this variable.
#' @param model fitted model object. The package works best with GLM objects and will extract the formula, dataset, family, coefficients, and
#' the QR components of the design matrix if arguments \code{formula}, \code{data}, \code{link}, \code{coefficients}, and/or
#' \code{vcov} are not explicitly specified.
#' @param data the dataset to be used to compute marginal effects (if not specified, it is extracted from the fitted model object).
#' @param formula the formula used in estimation (if not specified, it is extracted from the fitted model object).
#' @param link the name of the link function used in estimation (if not specified, it is extracted from the fitted model object).
#' @param coefficients the named vector of coefficients produced during the estimation (if not specified, it is extracted from the fitted model object).
#' @param vcov the variance-covariance matrix to be used for computing standard errors (if not specified, it is extracted from the fitted model object).
#' @param discrete logical. If TRUE, the function will compute the effect of a discrete change in \code{x}. If FALSE, the function will compute the partial derivative of \code{x}.
#' @param discrete_step The size of a discrete change in \code{x} used in computations (used only if \code{discrete=TRUE}).
#' @param at an optional named list of values of independent variables. These variables will be set to these values before computations.
#' The remaining numeric variables (except \code{x} and \code{over}) will be set to their means. The remaining factor variables will be set
#' to their modes.
#' @param mc logical. If TRUE, the standard errors and confidence intervals will be computed using simulations.
#' If FALSE (default), the delta method will be used.
#' @param iter the number of iterations used in Monte-Carlo simulations. Default = 1,000.
#' @param weights an optional vector of sampling weights.
#' @param heatmap_dim a numeric vector containing the number of rows and columns used for drawing the heatmap. Default = 100 each.
#' @param p the significance level for the marginal effects. Default = 0.05.
#' @author \code{plot_me} visualizes ME procedure described in Zhirnov, Moral, and Sedashov (2021) using the tools from \code{ggplot2} package.
#' @references
#' Zhirnov, Andrei, Mert Moral, and Evgeny Sedashov (2021). ``Taking Distributions Seriously: On the Interpretation
#' of the Estimates of Interactive Nonlinear Models.'' Working paper.
#' @details \code{plot_me} provides a convenient way to interpret two-way interactions using heatmaps.
#' It returns a ggplot object, which allows users to customize the plot using functions and layers from the \code{ggplot2} package.
#' @examples
#' ##Poisson regression with 2 variables and interaction between them
#' \dontrun{
#' data <- data.frame(y = rpois(10000, 10), x2 = rpois(10000, 5), x1 = rpois(10000, 3))
#' y <- glm(y ~ x1 + x2 + x1*x2, data = data, family = "poisson")
#' ## A contour-plot with 4 areas
#' library(ggplot2)
#' plot_me(model = y, data = data, x = "x1", over = "x2") +
#'     scale_fill_steps(low="yellow", high="red", n.breaks=4)
#' ## A heatmap with smooth transition of colors
#' plot_me(model = y, data = data, x = "x1", over = "x2") +
#'     scale_fill_gradient(low="yellow", high="red")
#' ## A heatmap with histograms at the edges
#' library(ggExtra)
#' g <- plot_me(model = y, data = data, x = "x1", over = "x2")
#' gt <- g + theme(legend.position="left") + scale_fill_gradient(low="yellow", high="red")
#' ggExtra::ggMarginal(gt, type="histogram", data=data, x=z, y=x)
#' ## if more control over the histograms needed:
#' nbins <- sapply(data[c("x1","x2")], grDevices::nclass.FD)
#' ggExtra::ggMarginal(gt, type="histogram", data=data, x=z, y=x,
#'    xparams=list(bins=nbins['x2']), yparams=list(bins=nbins['x1']))
#' }
#' \dontrun{
#' ## logit
#' m <- glm(any_dispute ~ flows.ln*polity2 + gdp_pc, data=strikes, family="binomial")
#' summary(m)
#' plot_me(model = m, x = "flows.ln", over = "polity2") +
#'     scale_fill_gradient(low="yellow", high="red") +
#'     labs(x="Polity", y="ln(FDI flows)")
#'}
#' @export

plot_me <- function(x, over, model = NULL, data = NULL,
                    link = NULL, formula = NULL, coefficients = NULL, vcov = NULL,
                    discrete = FALSE, discrete_step = 1,
                    at = NULL, mc = FALSE, iter = 1000,
                    heatmap_dim = c(100,100),
                    p = 0.05, weights = NULL) {
  # Builds a ggplot heatmap of the marginal effect of `x` over a regular grid
  # of (`x`, `over`) values and overlays markers whose size reflects the
  # (weighted) number of observations near each grid point and whose shape
  # reflects whether the interval [lb, ub] returned by me() excludes zero.
  # Returns the ggplot object so callers can add their own scales and layers.

  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("Package \"ggplot2\" needed for this function to work. Please install it.",call. = FALSE)
  }

  # Snap each value to the nearest of the two grid points that bracket it
  # (ties go to the upper point). Values at or beyond the ends of the grid are
  # compared against the real end points, so the extremes snap to the first or
  # last grid point rather than to their inner neighbours.
  snap_to_grid <- function(values, grid_points) {
    if (length(grid_points) < 2L) {
      return(rep(grid_points, length.out = length(values)))
    }
    # Intervals are (lower, upper]; all.inside keeps the index in 1..(k - 1).
    idx <- findInterval(values, grid_points, left.open = TRUE, all.inside = TRUE)
    lower <- grid_points[idx]
    upper <- grid_points[idx + 1L]
    snapped <- upper
    closer_to_lower <- abs(upper - values) > abs(lower - values)
    snapped[closer_to_lower] <- lower[closer_to_lower]
    snapped
  }

  # Forward to me() only the arguments the caller explicitly supplied and that
  # me() accepts. The values are read from this function's own frame, where the
  # argument promises are evaluated in the caller's environment; evaluating the
  # matched call's expressions via lapply(..., eval) would instead look them up
  # in lapply()'s frame and miss variables local to the caller.
  args <- as.list(match.call())
  supplied <- intersect(names(formals(me)), names(args))
  obj <- mget(supplied, envir = environment())

  if (is.null(formula)) formula <- stats::formula(model)
  # Drop the response only from a two-sided formula; for a one-sided formula
  # the second element is the right-hand side and must be kept.
  if (inherits(formula, "formula") && length(formula) == 3L) formula[[2L]] <- NULL
  check.required("formula","formula")
  allvars <- all.vars(formula)

  check.required("x","character")
  check.required("over","character")

  if (is.null(data)) data <- model[["data"]]
  check.required("data","data.frame")

  outside.formula <- setdiff(c(x,over,names(at)),allvars)
  if (length(outside.formula)>0) stop(paste("Failed to find the following variables in the formula:",outside.formula,collapse="\n"), call. = FALSE)
  outside.data <- setdiff(c(x,over),names(data))
  if (length(outside.data)>0) stop(paste("Failed to find the following variables in the dataset:",outside.data,collapse="\n"), call. = FALSE)

  # Fall back to equal weights when none (or unusable ones) are given, but do
  # not discard user-supplied weights silently.
  if (length(weights) != nrow(data)) {
    if (!is.null(weights)) {
      warning("The length of 'weights' does not match the number of rows in the data; using equal weights instead.", call. = FALSE)
    }
    weights <- rep(1,nrow(data))
  }

  # Variables in the formula that are neither plotted nor fixed through `at`
  # are set to a central value computed by find.central().
  tomeans <- setdiff(allvars, c(x,over,names(at)))
  names(tomeans) <- tomeans

  obj[["at"]] <- as.list(at)
  if (length(at)>0) {
    for (v in names(obj[["at"]])) {
      # Character values in `at` are converted to factors using the levels
      # stored in the model or, failing that, the values observed in the data.
      if (is.character(obj[["at"]][[v]]) && !is.factor(obj[["at"]][[v]])) {
        xle <-  model[["xlevels"]][[v]]
        if (is.null(xle)) xle <- sort(unique(data[[v]]))
        if (is.null(xle)) {
          stop("Please convert the character variables in the 'at' list into factors", call. = FALSE)
        }
        if (any(!obj[["at"]][[v]] %in% xle)) {
          stop(paste0("Could not find all listed values of ",v," in the model"), call. = FALSE)
        }
        obj[["at"]][[v]] <- factor(obj[["at"]][[v]], levels=xle)
      }
    }
  }
  if (length(tomeans)>0) obj[["at"]] <- c(obj[["at"]], lapply(tomeans, find.central, data=data, weights=weights))

  # Percentiles of the interval bounds implied by the significance level `p`.
  obj[["pct"]] <- 100*c(p/2, (1-p/2))
  names(obj[["pct"]]) <- c("lb","ub")
  obj[["discrete"]] <- discrete[1L]
  if (obj[["discrete"]]) {
    obj[["discrete_step"]] <- discrete_step[1L]
  }

# data for heatmaps
  # Regular grid spanning the observed ranges of both variables.
  grid.li <- list(
    x = seq(from = min(data[[x]], na.rm=TRUE), to = max(data[[x]], na.rm=TRUE), length.out = heatmap_dim[1]),
    over = seq(from = min(data[[over]], na.rm=TRUE), to = max(data[[over]], na.rm=TRUE), length.out = heatmap_dim[2])
    )
  obj[["data"]] <- grid <- expand.grid(grid.li)
  # me() receives the grid under the real variable names; the plotting data
  # keeps the generic names "x" and "over".
  colnames(obj[["data"]]) <- c(x,over)
  plotdata.hm <- do.call("me",obj)
  plotdata.hm <- merge(grid, plotdata.hm, by.x=c("x","over"), by.y=c(x,over))
  # Bounds of the same sign give 0 or 2 positives (level 0, "p<"); bounds of
  # opposite sign give exactly one positive (level 1, "p>").
  plotdata.hm[["sig"]] <- factor(rowSums(plotdata.hm[c("lb","ub")]>0) %% 2, levels=c(0,1), labels=paste0(c("p<","p>"),p))

# data for scatterplots
  # Total weight for every observed (x, over) combination, then moved onto the
  # nearest grid point so that the markers line up with the raster cells.
  temp <- stats::aggregate(list("nobs" = weights), by = list(x=data[[x]], over=data[[over]]), FUN = sum, na.action=NULL, na.rm=TRUE)
  for (j in c("x","over")) {
    temp[[j]] <- snap_to_grid(temp[[j]], grid.li[[j]])
  }
  # Several observed combinations may land on the same grid point.
  temp <- stats::aggregate(nobs ~ x + over, data=temp, FUN = sum, na.action=NULL, na.rm=TRUE)
  plotdata <- merge(plotdata.hm, temp, by=c("x","over"), all=TRUE)

# plot
  # NOTE(review): aes_string() is reported as deprecated by recent ggplot2
  # releases; it is kept here so the mappings stay exactly as before.
  ggplot2::ggplot(data = plotdata, ggplot2::aes_string(x = "over", y = "x")) +
    ggplot2::geom_raster(ggplot2::aes_string(fill = "est"), interpolate=FALSE) +
    ggplot2::geom_point(ggplot2::aes_string(size = "nobs", shape = "sig"), color = "black", data=plotdata[which(!is.na(plotdata$nobs)),]) +
    ggplot2::scale_shape_manual(values=c(16L,1L), drop=FALSE) +
    ggplot2::guides(fill = ggplot2::guide_colourbar(order = 1L), shape = ggplot2::guide_legend(order = 2L), size = "none") +
    ggplot2::labs(fill="Effect Size", shape=ggplot2::element_blank(), x=over, y=x) +
    ggplot2::theme_bw() +
    ggplot2::theme(panel.grid.minor = ggplot2::element_blank(),
           panel.grid.major = ggplot2::element_blank(), panel.border = ggplot2::element_rect(colour = "black"),
           aspect.ratio = 1, legend.position = "right")
}
andreizhirnov/DAME documentation built on Nov. 21, 2022, 5:55 p.m.