# R/margeffPlot.R

#' Plot marginal effects / partial dependence curves
#'
#' Each selected column of \code{X} is overwritten, for every row, with each
#' value of an evenly spaced grid (numeric columns) or each observed level
#' (factor columns). \code{predfun} is called on the modified data and its
#' predictions are averaged. The averages are then centred, using as weights
#' the number of observations falling in each grid cell or level, and plotted:
#' a single numeric variable is drawn as a ggplot2 curve (a natural spline
#' through the grid values), a single factor as a bar plot, and a pair of
#' variables as an image with contours. For a pair of variables both weighted
#' main effects are subtracted before the overall centring.
#'
#' @param fit a fitted model object. It is only passed on to \code{predfun}.
#' @param X a data frame of covariates. It should not include the response variable.
#' @param y the response variable. Optional, and used only when plotting a single
#' numeric covariate, where it sets the y-axis limits and supplies the points
#' drawn when \code{points = TRUE}.
#' @param predfun a prediction function that takes the arguments \code{fit} and
#' \code{newdata}, and returns a vector of predicted values.
#' @param wch which variables to plot. Either a single column index (or column
#' name), or a vector of two to plot the predicted values as a function of both
#' variables. With two variables the second must be numeric; only the first may
#' be a factor.
#' @param res the resolution of the plot, in other words, how many grid points
#' along each numeric variable define the curve or surface. Must be at least 2.
#' It is not used for factor variables.
#' @param points if TRUE, and y is not NULL, the data points are plotted with the
#' curve (single numeric covariate only).
#' @return Called for its side effect of drawing a plot. The value returned is
#' that of the last plotting call and is not intended to be used.
#' @export
#'
#' @examples
#' fit <- KIR(prog ~ .-sex, diabetes)
#' predfun <- function(fit, newdata) predict(fit, newdata, type = "response")
#' margeffPlot(fit = fit, X = diabetes[,-c(2, 11)], predfun = predfun, wch = 4)
#'
margeffPlot <- function (fit, X, y = NULL, predfun, wch = 1, res = 100, points = FALSE){
  if (!is.function(predfun)) {
    stop("predfun must be a function taking the arguments fit and newdata",
         call. = FALSE)
  }
  if (!(length(wch) %in% c(1L, 2L))) {
    stop("wch must be a vector of length one or two", call. = FALSE)
  }
  # Column names are accepted as a convenience and mapped to positions.
  if (is.character(wch)) {
    wch <- match(wch, colnames(X))
  }
  if (!is.numeric(wch) || anyNA(wch) || any(wch < 1 | wch > ncol(X))) {
    stop("wch must identify existing columns of X", call. = FALSE)
  }
  var_names <- colnames(X)[wch]

  # Extract one column as a plain vector, for data frames and matrices alike.
  get_col <- function(j) {
    if (is.data.frame(X)) X[[j]] else X[, j]
  }

  # Mean prediction after overwriting columns `cols` with the constants in
  # `values` (a list, one entry per column). `[[<-` is used for data frames
  # so that a factor value stays a factor instead of being turned into
  # character on assignment.
  avg_pred <- function(cols, values) {
    X_new <- X
    for (i in seq_along(cols)) {
      if (is.data.frame(X_new)) {
        X_new[[cols[i]]] <- rep(values[[i]], length.out = nrow(X_new))
      } else {
        X_new[, cols[i]] <- values[[i]]
      }
    }
    mean(predfun(fit = fit, newdata = X_new))
  }

  # Evenly spaced grid over the observed range of column j, together with the
  # grid cell each observation falls into. Cell k is (pts[k - 1], pts[k]];
  # the extra leading break makes the first cell hold the minimum itself.
  make_grid <- function(j) {
    if (!is.numeric(res) || length(res) != 1L || is.na(res) || res < 2) {
      stop("res must be a single number of at least 2", call. = FALSE)
    }
    v <- get_col(j)
    lo <- min(v)
    hi <- max(v)
    if (!is.finite(lo) || !is.finite(hi) || lo == hi) {
      stop("numeric columns selected by wch must contain at least two ",
           "distinct, finite, non-missing values", call. = FALSE)
    }
    pts <- seq(lo, hi, length.out = res)
    bins <- cut(v, breaks = c(lo - (pts[2] - pts[1]), pts),
                include.lowest = TRUE)
    list(pts = pts, bins = bins)
  }

  # Observed levels of factor column j. Each level is represented by an
  # actual element of the column, so the full level set (and ordering, if
  # any) of the original factor is what predfun sees.
  make_levels <- function(j) {
    v <- get_col(j)
    kept <- droplevels(v)
    lev <- levels(kept)
    list(kept = kept, labels = lev,
         values = lapply(lev, function(l) v[match(l, v)]))
  }

  # Remove both weighted main effects from a two-way surface, then centre the
  # remainder with the joint cell counts `b`.
  center_2d <- function(fJ, b1, b2, b) {
    fJ1 <- colSums(t(fJ) * b2) / sum(b2)  # one value per row of fJ
    fJ2 <- colSums(fJ * b1) / sum(b1)     # one value per column of fJ
    fJ <- sweep(sweep(fJ, 1, fJ1), 2, fJ2)
    fJ - sum(fJ * b) / sum(b)
  }

  if (length(wch) == 1L) {
    xo <- get_col(wch)
    if (is.numeric(xo)) {
      grid <- make_grid(wch)
      fJ <- vapply(grid$pts, function(v) avg_pred(wch, list(v)), numeric(1))
      counts <- as.numeric(table(grid$bins))
      fJ <- fJ - sum(fJ * counts) / sum(counts)
      df <- data.frame(x = grid$pts, fJ = fJ)
      # The curve drawn is a natural spline through the centred grid values.
      yfunc <- splinefun(x = df$x, y = df$fJ, method = "natural")
      funcplot <- ggplot(df, aes(x = x, y = fJ))
      if (!is.null(y) && points) {
        funcplot <- funcplot +
          geom_point(aes(x = xo, y = y),
                     data = data.frame(xo = xo, y = y),
                     color = "#1e0536", alpha = 0.95)
      }
      funcplot <- funcplot +
        stat_function(fun = yfunc, color = "#d92602", size = 1.5, alpha = 0.90, geom = "line") +
        ylab(label = paste0("f", "(", var_names, ")")) +
        xlab(label = var_names)
      if (!is.null(y)) {
        funcplot <- funcplot + ylim(min(y), max(y))
      }
      plot(funcplot)
    } else if (is.factor(xo)) {
      lv <- make_levels(wch)
      prob <- as.numeric(table(lv$kept))
      prob <- prob / sum(prob)
      fJ <- vapply(lv$values, function(v) avg_pred(wch, list(v)), numeric(1))
      fJ <- fJ - sum(fJ * prob)
      barplot(fJ, names.arg = lv$labels, xlab = var_names,
              ylab = paste0("f", "(", var_names, ")"), las = 3)
    } else {
      stop("class(X[,wch]) must be either factor or numeric or integer",
           call. = FALSE)
    }
  } else {
    # Stop here rather than carrying on with a non-numeric second variable.
    if (!is.numeric(get_col(wch[2]))) {
      stop("X[,wch[2]] must be numeric or integer. Only X[,wch[1]] can be a factor",
           call. = FALSE)
    }
    first <- get_col(wch[1])
    if (is.factor(first)) {
      lv <- make_levels(wch[1])
      grid2 <- make_grid(wch[2])
      K1 <- length(lv$labels)
      n2 <- length(grid2$pts)
      fJ <- matrix(0, K1, n2)
      for (k1 in seq_len(K1)) {
        for (k2 in seq_len(n2)) {
          fJ[k1, k2] <- avg_pred(wch, list(lv$values[[k1]], grid2$pts[k2]))
        }
      }
      fJ <- center_2d(fJ,
                      b1 = as.numeric(table(lv$kept)),
                      b2 = as.numeric(table(grid2$bins)),
                      b = as.matrix(table(lv$kept, grid2$bins)))
      x1.num <- seq_len(K1)
      # The y axis carries the second variable's values, so it is labelled
      # with that variable's name.
      image(x1.num, grid2$pts, fJ, xlab = var_names[1], ylab = var_names[2],
            ylim = range(grid2$pts), yaxs = "i")
      contour(x1.num, grid2$pts, fJ, add = TRUE, drawlabels = TRUE)
      axis(side = 1, labels = lv$labels, at = x1.num, las = 3,
           padj = 1.2)
    } else if (is.numeric(first)) {
      grid1 <- make_grid(wch[1])
      grid2 <- make_grid(wch[2])
      n1 <- length(grid1$pts)
      n2 <- length(grid2$pts)
      fJ <- matrix(0, n1, n2)
      for (k1 in seq_len(n1)) {
        for (k2 in seq_len(n2)) {
          fJ[k1, k2] <- avg_pred(wch, list(grid1$pts[k1], grid2$pts[k2]))
        }
      }
      fJ <- center_2d(fJ,
                      b1 = as.numeric(table(grid1$bins)),
                      b2 = as.numeric(table(grid2$bins)),
                      b = as.matrix(table(grid1$bins, grid2$bins)))
      image(grid1$pts, grid2$pts, fJ, xlab = var_names[1],
            ylab = var_names[2], xlim = range(grid1$pts),
            ylim = range(grid2$pts), xaxs = "i", yaxs = "i")
      contour(grid1$pts, grid2$pts, fJ, add = TRUE, drawlabels = TRUE)
    } else {
      stop("class(X[,wch[1]]) must be either factor or numeric/integer",
           call. = FALSE)
    }
  }
}
# abnormally-distributed/cvreg documentation built on May 3, 2020, 3:45 p.m.