R/plot.R

Defines functions plot.margins

Documented in plot.margins

#' @title Plot Marginal Effects Estimates
#' @description An implementation of Stata's \samp{marginsplot} as an S3 generic function
#' @param x An object of class \dQuote{margins}, as returned by \code{\link{margins}}.
#' @param pos A numeric vector specifying the x-positions of the estimates (or y-positions, if \code{horizontal = TRUE}), one per element of \code{which}. Default is \code{seq_along(which)}.
#' @param which A character vector specifying which marginal effect estimates to plot, using the column names of \code{marginal_effects(x)} (e.g., \code{"dydx_wt"}); the \code{"dydx_"} prefix is optional. Estimates are drawn in the order given. Default is all.
#' @param labels A character vector specifying the axis labels to use for the marginal effect estimates. Default is the variable names from \code{which} with any \code{"dydx_"} prefix removed.
#' @param horizontal A logical indicating whether to plot the estimates along the x-axis with vertical confidence intervals (the default), or along the y-axis with horizontal confidence intervals.
#' @param xlab A character string specifying the x-axis (or y-axis, if \code{horizontal = TRUE}) label.
#' @param ylab A character string specifying the y-axis (or x-axis, if \code{horizontal = TRUE}) label.
#' @param level A numeric value between 0 and 1 indicating the confidence level to use when drawing error bars.
#' @param pch The point symbol to use for plotting marginal effect point estimates. See \code{\link[graphics]{points}} for details.
#' @param points.col The point (border) and error-bar color to use for plotting marginal effect point estimates. See \code{\link[graphics]{points}} for details.
#' @param points.bg The fill (background) color for filled point symbols (\code{pch = 21:25}). See \code{\link[graphics]{points}} for details.
#' @param las An integer value specifying the orientation of the axis labels. See \code{\link[graphics]{par}} for details.
#' @param cex A numerical value giving the amount by which the plotted point symbols should be magnified relative to the default. See \code{\link[graphics]{par}} for details.
#' @param lwd A numerical value giving the width of error bars (recycled across estimates).
#' @param zeroline A logical indicating whether to draw a line indicating zero. Default is \code{TRUE}.
#' @param zero.col A character string indicating a color to use for the zero line if \code{zeroline = TRUE}.
#' @param \dots Additional arguments passed to \code{\link[graphics]{plot.default}}, such as \code{main}, \code{xlim}, or \code{ylim}. When \code{xlim}/\code{ylim} are supplied they replace the automatically computed axis limits.
#' @details This function is invoked for its side effect: a basic dot plot with error bars displaying marginal effects as generated by \code{\link{margins}}, in the style of Stata's \samp{marginsplot} command. When \code{x} was computed at several \code{at} values, the estimates for each variable are spread over \code{pos - 0.2} to \code{pos + 0.2}.
#' @return The original \dQuote{margins} object \code{x}, invisibly.
#' @examples
#' \dontrun{
#'   require("datasets")
#'   x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
#'   mar <- margins(x)
#'   plot(mar)
#' }
#' 
#' @seealso \code{\link{margins}}, \code{\link{persp.lm}}
#' @keywords graphics
#' @importFrom graphics abline axis plot points segments
#' @export
plot.margins <- 
function(x, 
         pos = seq_along(which),
         which = colnames(marginal_effects(x, with_at = FALSE)), 
         labels = gsub("^dydx_", "", which),
         horizontal = FALSE,
         xlab = "",
         ylab = "Average Marginal Effect",
         level = 0.95,
         pch = 21, 
         points.col = "black",
         points.bg = "black",
         las = 1,
         cex = 1,
         lwd = 2, 
         zeroline = TRUE,
         zero.col = "gray",
         ...) {
    
    pars <- list(...)
    # Honour user-supplied axis limits, then drop them from the dots so they
    # are not passed to plot() a second time (which would be an error).
    user_xlim <- pars[["xlim"]]
    user_ylim <- pars[["ylim"]]
    pars[["xlim"]] <- NULL
    pars[["ylim"]] <- NULL

    summ <- summary(x, level = level, by_factor = TRUE)
    if (!missing(which)) {
        # Keep only the requested estimates, ordered as in `which`; order() is
        # stable, so rows for different `at` values keep their relative order.
        wanted <- gsub("^dydx_", "", which)
        idx <- match(summ[["factor"]], wanted)
        keep <- !is.na(idx)
        if (!any(keep)) {
            stop("None of the estimates named in 'which' were found in 'x'", call. = FALSE)
        }
        summ <- summ[keep, , drop = FALSE]
        summ <- summ[order(idx[keep]), , drop = FALSE]
    }
    MEs <- summ[, "AME", drop = TRUE]
    # The confidence bounds are the last two columns of the summary.
    lower <- summ[, ncol(summ) - 1L]
    upper <- summ[, ncol(summ)]
    r <- max(upper) - min(lower)
    
    # With several `at` values, spread each variable's estimates around its
    # position so they do not overlap.
    at_levels <- unique(summ[, names(attributes(x)[["at"]]), drop = FALSE])
    n_at_levels <- nrow(at_levels)
    if (n_at_levels > 1) {
        pos2 <- rep(pos, each = n_at_levels)
        pos2 <- pos2 + seq(from = -0.2, to = 0.2, length.out = n_at_levels)
    } else {
        pos2 <- pos
    }
    
    # Pad the position axis by 4% of the end positions' magnitude; abs() keeps
    # the padding outward even when positions are zero or negative.
    pos_lim <- c(min(pos2) - 0.04 * abs(min(pos2)), max(pos2) + 0.04 * abs(max(pos2)))
    val_lim <- c(min(lower) - 0.04 * r, max(upper) + 0.04 * r)
    if (isTRUE(horizontal)) {
        xlim <- if (is.null(user_xlim)) val_lim else user_xlim
        ylim <- if (is.null(user_ylim)) pos_lim else user_ylim
    } else {
        xlim <- if (is.null(user_xlim)) pos_lim else user_xlim
        ylim <- if (is.null(user_ylim)) val_lim else user_ylim
    }
    
    if (isTRUE(horizontal)) {
        do.call(plot, c(list(NA, xlim = xlim, ylim = ylim, yaxt = "n",
                             xlab = ylab, ylab = xlab, las = las), pars))
        if (isTRUE(zeroline)) {
            abline(v = 0, col = zero.col)
        }
        points(MEs, pos2, col = points.col, bg = points.bg, pch = pch, cex = cex)
        axis(2, at = pos, labels = as.character(labels), las = las)
        segments(upper, pos2, lower, pos2, col = points.col, lwd = lwd)
    } else {
        do.call(plot, c(list(NA, xlim = xlim, ylim = ylim, xaxt = "n",
                             xlab = xlab, ylab = ylab, las = las), pars))
        if (isTRUE(zeroline)) {
            abline(h = 0, col = zero.col)
        }
        points(pos2, MEs, col = points.col, bg = points.bg, pch = pch, cex = cex)
        axis(1, at = pos, labels = as.character(labels), las = las)
        segments(pos2, upper, pos2, lower, col = points.col, lwd = lwd)
    }
    invisible(x)
}
leeper/margins documentation built on Jan. 26, 2021, 9:12 p.m.