R/dependence.R

Defines functions print.part_depend plot.part_depend dependence.beset_rf dependence.randomForest dependence.nested_elnet dependence.beset dependence.nested dependence

Documented in dependence dependence.beset dependence.beset_rf dependence.nested plot.part_depend

#' Partial-Dependence Effects
#'
#' Generate data for partial dependence plots.
#'
#' @param object A "beset" model object for which partial dependence plots
#' are desired.
#'
#' @param x (Optional) \code{character} vector giving name(s) of predictors
#' for which partial dependence effects should be estimated. If omitted,
#' partial dependence effects will be estimated for all predictors in the
#' model.
#'
#' @param cond (Optional) named \code{list} of the values to use for one or more
#' predictors when estimating the partial effect of the predictor(s) named in
#' \code{x}. Any predictor omitted from this list will be assigned the mean
#' observed value for continuous variables, or the modal level for factors.
#'
#' @param x_lab (Optional) \code{character} vector giving replacement labels
#' for \code{x}. Labels should be listed in the same order as the predictor
#' names in \code{x}, or, if \code{x} is omitted, the same column order as the
#' predictors in the model data frame. If omitted, plots will be labeled with
#' the names of the predictors as they appear in the model object.
#'
#' @param y_lab (Optional) \code{character} string giving replacement label
#' for the response variable in the model. If omitted, plots will be labeled
#' with the name of the response as it appears in the model object.
#'
#' @param ... Additional arguments affecting the plots produced.
#'
#' @param order If plotting a grid, order in which partial dependence plots
#' should appear. Options are \code{"delta"} (default), which arranges plots
#' based on the magnitude of change in \code{y} over the range of each \code{x},
#' and \code{"import"}, which arranges plots based on the variable importance
#' scores returned by the object's \code{\link{importance}} method.
#'
#' @param p_max Maximum number of partial dependence effects to plot.
#' Default is 16, resulting in a 4 X 4 grid.
#'
#' @inheritParams summary.beset
#' @inheritParams beset_glm

#' @export
dependence <- function(object, ...) {
  # S3 generic: the method is chosen from the class of `object`; any further
  # arguments are handed on to that method unchanged.
  UseMethod("dependence", object)
}

#' @rdname dependence
#' @export
dependence.nested <- function(object, x = NULL, cond = NULL,
                              alpha = NULL, lambda = NULL, n_pred = NULL,
                              metric = "auto", oneSE = TRUE,
                              x_lab = NULL, y_lab = NULL, ...,
                              parallel_type = NULL, n_cores = NULL, cl = NULL){
  # Only nested elastic-net fits have a partial dependence implementation;
  # every other "nested" object is rejected with an informative error.
  if(!inherits(object, "elnet")){
    stop(
      paste("Partial dependence method not yet implemented,",
            "for `beset_lm` and `beset_glm`."),
      call. = FALSE
    )
  }
  # Forward the already-evaluated arguments directly. Rebuilding the call from
  # `match.call()` would re-evaluate the caller's expressions in this frame,
  # which fails when they refer to variables local to the calling function.
  # `n_pred` is not forwarded: `dependence.nested_elnet` has no such parameter.
  dependence.nested_elnet(
    object, x = x, cond = cond, alpha = alpha, lambda = lambda,
    metric = metric, oneSE = oneSE, x_lab = x_lab, y_lab = y_lab, ...,
    parallel_type = parallel_type, n_cores = n_cores, cl = cl
  )
}

#' @rdname dependence
#' @export
dependence.beset <- function(
  object, x = NULL, cond = NULL, alpha = NULL, lambda = NULL, n_pred = NULL,
  metric = "auto", oneSE = TRUE, x_lab = NULL, y_lab = NULL, ...,
  parallel_type = NULL, n_cores = NULL, cl = NULL
){
  # Random-forest fits are delegated to their dedicated method.
  if(inherits(object, "rf")){
    return(dependence.beset_rf(object, x, cond, x_lab, y_lab, ...,
      parallel_type = parallel_type, n_cores = n_cores, cl = cl
    ))
  }
  # If the object carries no data of its own, fall back to the `data`,
  # `terms`, `contrasts` and `xlevels` supplied through `...`
  # (`dependence.nested_elnet` passes them this way).
  model_data <- list(...)
  if(is.null(object$data)){
    object$data <- model_data$data
    object$terms <- model_data$terms
    object$contrasts <- model_data$contrasts
    object$xlevels <- model_data$xlevels
  }
  data <- object$data
  if(is.null(data)){
      stop(
        paste("Partial dependence cannot be calculated from skinny objects.",
              "Please rerun `beset_elnet` and set `skinny` to `FALSE`.")
      )
  }
  # Logical predictors are treated as factors from here on.
  data <- mutate_if(data, is.logical, factor)
  if(is.null(x)) x <- labels(object$terms)
  if(is.null(x_lab)) x_lab <- x
  # The response is whatever variable in the terms is not a predictor label.
  y <- setdiff(all.vars(object$terms), labels(object$terms))
  if(is.null(y_lab)) y_lab <- y
  # Several predictors: recurse once per predictor and return a plain list of
  # single-predictor results (`object` already carries the data at this point).
  if(length(x) > 1){
    return(map2(x, x_lab,
                ~ dependence.beset(object, x = .x, cond = cond,
                                   alpha = alpha, lambda = lambda,
                                   metric = metric, oneSE = oneSE,
                                   n_pred = n_pred, x_lab = .y, y_lab = y_lab)))
  }
  # Hold every other predictor fixed: modal level for factors, mean otherwise.
  xes <- setdiff(labels(object$terms), x)
  covars <- data[xes]
  covars <- map_df(covars, function(col){
    if(is.factor(col))
      factor(levels(forcats::fct_infreq(col))[1], levels = levels(col)) else
        mean(col, na.rm = TRUE)
  })
  # User-supplied conditioning values override the defaults above.
  if(!is.null(cond)){
    for(varname in names(cond)) covars[, varname] <- cond[[varname]]
  }
  # Grid of values for the focal predictor: all levels of a factor, the
  # observed unique values of a character vector or of an integer vector with
  # fewer than 100 distinct values, otherwise 100 evenly spaced points
  # spanning the observed range.
  x_obs <- data[[x]]
  x_plot <- if(is.factor(x_obs)){
    factor(levels(x_obs), levels = levels(x_obs))
  } else if(is.character(x_obs) ||
            (is.integer(x_obs) && n_distinct(x_obs) < 100)){
    sort(unique(x_obs))
  } else {
    seq(min(x_obs, na.rm = TRUE), max(x_obs, na.rm = TRUE), length.out = 100)
  }
  # Repeat the fixed covariate row once per grid value and prepend the grid.
  plot_data <- vector("list", length(x_plot))
  plot_data <- map_df(plot_data, ~ return(covars))
  plot_data <- bind_cols(x = x_plot, plot_data)
  names(plot_data)[1] <- x
  y_hat <- predict(
    object, newdata = plot_data, newoffset = object$parameters$offset,
    alpha = alpha, lambda = lambda, metric = metric, oneSE = oneSE
  )
  # Effect size: range of the predictions relative to the range of the
  # observed response (factor responses are recoded to integer codes - 1).
  y_obs <- object$parameters$y
  if(is.factor(y_obs)) y_obs <- as.integer(y_obs) - 1
  impact <- (max(y_hat) - min(y_hat)) / (max(y_obs) - min(y_obs))
  plot_data <- tibble(
    x = x_plot,
    y = y_hat
  )
  partial_effects <- tibble(
    variable = x_lab,
    delta = impact
  )
  structure(
    list(plot_data = plot_data, partial_effects = partial_effects),
    class = "part_depend"
  )
}

# Partial dependence for a nested elastic-net object.
#
# Computes partial dependence for every model stored in `object$beset`
# (optionally in parallel) and overlays the per-model curves with a summary
# curve, one plot per predictor in `x`.
#
# Returns a "part_depend" object with
#   partial_dependence:  list of ggplot objects named by `x_lab` (a list of
#                        length one when a single predictor is requested), and
#   variable_importance: the object's importance table restricted to `x`,
#                        joined with the median and the 2.5% / 97.5% quantiles
#                        of the per-model `delta`.
dependence.nested_elnet <- function(
  object, x = NULL, cond = NULL, alpha = NULL, lambda = NULL,
  metric = "auto", oneSE = TRUE, x_lab = NULL, y_lab = NULL, ...,
  parallel_type = NULL, n_cores = NULL, cl = NULL
){
  if(is.null(object$data)){
    stop(
      paste("Partial dependence cannot be calculated from skinny objects.",
            "Please rerun `beset_elnet` and set `skinny` to `FALSE`.")
    )
  }
  if(is.null(x)) x <- labels(object$terms)
  y <- setdiff(all.vars(object$terms), labels(object$terms))
  if(is.null(x_lab)) x_lab <- x
  if(is.null(y_lab)) y_lab <- y
  variable_importance <- importance(
    object, alpha = alpha, lambda = lambda, metric = metric, oneSE = oneSE
  )
  # Keep the importance rows whose name matches a requested predictor
  # (pattern match via `grep`), then relabel the matches with `x_lab`.
  keep <- map(x, ~ grep(.x, variable_importance$variable)) %>% as_vector
  variable_importance <- variable_importance[keep,]
  matching_labels <- map(x, ~ grep(.x, variable_importance$variable))
  # `<<-` writes into this function's `variable_importance`.
  walk2(x_lab, matching_labels,
       function(x, i) variable_importance$variable[i] <<- x)
  # Rows that now share a label are collapsed to their column-wise maximum.
  variable_importance <- variable_importance %>% group_by(variable) %>%
    summarize_all(max)
  if(is.null(n_cores) || n_cores > 1){
    parallel_control <- setup_parallel(
      parallel_type = parallel_type, n_cores = n_cores, cl = cl)
    have_mc <- parallel_control$have_mc
    n_cores <- parallel_control$n_cores
    cl <- parallel_control$cl
  }
  # One `dependence()` call per stored model; the shared data, terms,
  # contrasts and xlevels are passed along explicitly with each call.
  pd <- if(n_cores > 1L){
    if(is.null(cl)){
      parallel::mclapply(object$beset, dependence, x = x, cond = cond,
                         alpha = alpha, lambda = lambda, metric = metric,
                         oneSE = oneSE, x_lab = x_lab, y_lab = y_lab,
                         data = object$data, terms = object$terms,
                         contrasts = object$contrasts,
                         xlevels = object$xlevels, mc.cores = n_cores)
    } else {
      parallel::parLapply(cl, object$beset, dependence, x = x,
                          cond = cond, alpha = alpha, lambda = lambda,
                          metric = metric, oneSE = oneSE,
                          x_lab = x_lab, y_lab = y_lab,
                          data = object$data, terms = object$terms,
                          contrasts = object$contrasts,
                          xlevels = object$xlevels)
    }
  } else {
    lapply(object$beset, dependence, x = x, cond = cond,
           alpha = alpha, lambda = lambda, metric = metric,
           oneSE = oneSE,  x_lab = x_lab, y_lab = y_lab,
           data = object$data, terms = object$terms,
           contrasts = object$contrasts,
           xlevels = object$xlevels)
  }
  if(!is.null(cl)) parallel::stopCluster(cl)
  # Bring the results into one shape: a list with one element per predictor,
  # each holding the per-model results. With several predictors each model
  # returns a list (one entry per predictor) that must be transposed; with a
  # single predictor each model returns the result itself, so the per-model
  # list is wrapped. Previously the single-predictor case produced `impact`
  # named by model instead of by predictor, which broke the join below, and
  # returned a bare plot that `plot.part_depend` cannot handle.
  pd <- if(length(x) > 1) transpose(pd) else list(pd)
  names(pd) <- x_lab
  impact <- map(pd, ~ map_dbl(.x, ~.x$partial_effects$delta))
  # Stack the per-model curves, tagging each row with its model index/name.
  plot_data <- map(pd, ~ imap_dfr(.x, function(x, i){
    df <- x$plot_data; df$model <- i; df}))
  # Faint line per model plus a highlighted summary: mean line for factor
  # predictors, natural-spline smooth otherwise.
  partial_dependence <- imap(
    plot_data, ~ ggplot(data = .x, aes(x = x, y = y)) +
    geom_line(aes(group = model), alpha = .1) +
    (if(is.factor(.x$x)) {
      stat_summary(
        aes(group=1), fun = mean, geom = "line", colour="tomato2", size = 1.5
      )
    } else {
      geom_smooth(
        method = "lm", formula = y ~ splines::ns(x,2), size = 1.5,
        color = "tomato2"
      )
    }) + xlab(.y) + ylab(y_lab) + theme_bw()
  )
  # Summarise the per-model effect sizes and attach them by predictor label.
  variable_importance <- variable_importance %>%
    inner_join(
      tibble(
        variable = names(impact),
        delta = map_dbl(impact, median),
        delta_low = map_dbl(impact, ~ quantile(.x, .025)),
        delta_high = map_dbl(impact, ~ quantile(.x, .975))
      ), by = "variable")
  structure(
    list(partial_dependence = partial_dependence,
         variable_importance = variable_importance),
    class = "part_depend"
  )
}

#' @export
dependence.randomForest <- function(
  object, data = NULL, x = NULL, y = NULL, cond = NULL, x_lab = NULL,
  y_lab = NULL, make_plot = TRUE, ...
){
  # Defaults: all predictors named in the model terms, and the remaining
  # variable in the terms as the response.
  if(is.null(x)) x <- labels(object$terms)
  if(is.null(y)) y <- setdiff(all.vars(object$terms), labels(object$terms))
  if(is.null(x_lab)) x_lab <- x
  if(is.null(y_lab)) y_lab <- y
  # Without explicit data, re-evaluate the `data` argument of the stored call.
  if(is.null(data)){
    rf_par <- as.list(object$call)
    data <- eval(rf_par$data)
  }
  # Several predictors: recurse once per predictor, then regroup the results
  # so that each component holds one entry per predictor.
  if(length(x) > 1){
    out <- map2(x, x_lab, ~ dependence.randomForest(
      object, data = data, x = .x, y = y, cond = cond, x_lab = .y,
      y_lab = y_lab, make_plot = make_plot))
    names(out) <- x_lab
    out <- transpose(out)
    out$variable_importance <- reduce(out$variable_importance, rbind)
    class(out) <- "part_depend"
    return(out)
  }
  # Hold every other column of `data` fixed: modal level for factors, mean
  # otherwise.
  xes <- setdiff(names(data), x)
  x_obs <- as_vector(data[x])
  covars <- data[xes]
  covars <- map_df(covars, function(col){
    if(is.factor(col))
      factor(levels(forcats::fct_infreq(col))[1], levels = levels(col)) else
        mean(col, na.rm = TRUE)
  })
  # User-supplied conditioning values override the defaults above.
  if(!is.null(cond)){
    for(varname in names(cond)) covars[, varname] <- cond[[varname]]
  }
  # Grid of values for the focal predictor: all levels of a factor, the
  # observed unique values of a character vector or of an integer vector with
  # fewer than 100 distinct values, otherwise 100 evenly spaced points
  # spanning the observed range.
  x_plot <- if(is.factor(x_obs)){
    factor(levels(x_obs), levels = levels(x_obs))
  } else if(is.character(x_obs) ||
            (is.integer(x_obs) && n_distinct(x_obs) < 100)){
    sort(unique(x_obs))
  } else {
    seq(min(x_obs, na.rm = TRUE), max(x_obs, na.rm = TRUE), length.out = 100)
  }
  # Repeat the fixed covariate row once per grid value and prepend the grid.
  plot_data <- vector("list", length(x_plot))
  plot_data <- map_df(plot_data, ~ return(covars))
  plot_data <- bind_cols(x = x_plot, plot_data)
  names(plot_data)[1] <- x
  # Classification forests are plotted as probabilities; when the prediction
  # is a matrix, the column labelled "1" is used.
  type <- if(object$type == "classification") "prob" else "response"
  y_hat <- predict(object, newdata = plot_data, type = type)
  if(inherits(y_hat, "matrix")) y_hat <- y_hat[, "1"]
  # Effect size: range of the predictions relative to the range of the
  # observed response (factor responses are recoded to integer codes - 1).
  y_obs <- object$y
  if(is.factor(y_obs)) y_obs <- as.integer(y_obs) - 1
  impact <- (max(y_hat) - min(y_hat)) / (max(y_obs) - min(y_obs))
  plot_data <- tibble(
    x = x_plot,
    y = y_hat
  )
  # Either a finished plot or the raw curve, depending on `make_plot`.
  partial_dependence <- if(make_plot){
    ggplot(data = plot_data, aes(x = x, y = y)) +
    geom_line(aes(group = 1)) + theme_classic() + xlab(x_lab) + ylab(y_lab)
  } else plot_data
  variable_importance <- tibble(
    variable = x_lab,
    delta = impact
  )
  structure(
    list(partial_dependence = partial_dependence,
         variable_importance = variable_importance),
    class = "part_depend"
  )
}

#' @export
#' @rdname dependence
dependence.beset_rf <- function(
  object, x = NULL, cond = NULL, x_lab = NULL, y_lab = NULL, ...,
  parallel_type = NULL, n_cores = NULL, cl = NULL
){
  # Configure parallel execution unless the caller explicitly asked for a
  # single core.
  if(is.null(n_cores) || n_cores > 1){
    parallel_control <- setup_parallel(
      parallel_type = parallel_type, n_cores = n_cores, cl = cl)
    have_mc <- parallel_control$have_mc
    n_cores <- parallel_control$n_cores
    cl <- parallel_control$cl
  }
  # Default predictors are the names found in the `x` argument of the first
  # forest's stored call; the response is whatever column of the data is not
  # among those names.
  if(is.null(x)) x <- names(object$forests[[1]]$call$x)
  y <- setdiff(names(object$data), names(object$forests[[1]]$call$x))
  if(is.null(x_lab)) x_lab <- x; if(is.null(y_lab)) y_lab <- y
  # Keep the importance rows whose name matches a requested predictor
  # (pattern match via `grep`), then relabel the matches with `x_lab`.
  variable_importance <- importance(object)
  keep <- map(x, ~ grep(.x, variable_importance$variable)) %>% as_vector
  variable_importance <- variable_importance[keep,]
  matching_labels <- map(x, ~ grep(.x, variable_importance$variable))
  # `<<-` writes into this function's `variable_importance`.
  walk2(x_lab, matching_labels,
        function(x, i) variable_importance$variable[i] <<- x)
  # Drop the "terms" attribute from the copy of the data handed to the
  # per-forest computation.
  data <- object$data
  attr(data, "terms") <- NULL
  # One `dependence.randomForest()` call per forest, returning raw curves
  # (`make_plot = FALSE`) so that they can be combined across forests below.
  pd <- if(n_cores > 1L){
    if(have_mc){
      parallel::mclapply(object$forests, dependence.randomForest,
                         data = data, x = x, y = y, cond = cond,
                         x_lab = x_lab, y_lab = y_lab, make_plot = FALSE,
                         mc.cores = n_cores)
    } else {
      parallel::parLapply(cl, object$forests, dependence.randomForest,
                          data = data, x = x, y = y, cond = cond,
                          x_lab = x_lab, y_lab = y_lab, make_plot = FALSE)
    }
  } else {
    lapply(object$forests, dependence.randomForest, data = data,
           x = x, y = y, cond = cond, x_lab = x_lab, y_lab = y_lab,
           make_plot = FALSE)
  }
  if(!is.null(cl)) parallel::stopCluster(cl)
  # Regroup from "per forest" to "per component", then collect the per-forest
  # `delta` values into one vector per predictor.
  pd <- transpose(pd)
  impact <- pd$variable_importance %>% map("delta") %>% transpose %>%
    simplify_all
  names(impact) <- x_lab
  # Stack the per-forest curves for each predictor, tagging each row with its
  # forest index/name; the result is always a list named by `x_lab`.
  plot_data <- pd$partial_dependence
  if(length(x) > 1){
    plot_data <- transpose(plot_data)
    names(plot_data) <- x_lab
    plot_data <- map(plot_data, ~ imap_dfr(.x, function(x, i){
      df <- x; df$model <- i; df}))
  } else {
    plot_data <- list(imap_dfr(plot_data, function(x, i){
      df <- x; df$model <- i; df}))
    names(plot_data) <- x_lab
  }
  # Faint line per forest plus a highlighted summary: mean line for factor
  # predictors, loess smooth otherwise.
  partial_dependence <- imap(
    plot_data, ~ ggplot(data = .x, aes(x = x, y = y)) +
      geom_line(aes(group = model), alpha = .1) + (
        if(is.factor(.x$x)) {
          stat_summary(
            aes(group=1), fun = mean, geom = "line", colour="tomato2",
            size = 1.5
          )
        } else {
          geom_smooth(method = "loess", size = 1.5, color = "tomato2")
        }
      ) + xlab(.y) + ylab(y_lab) + theme_bw()
    )
  # Attach the median and 2.5% / 97.5% quantiles of the per-forest effect
  # sizes. NOTE(review): this is positional, so it assumes the rows of
  # `variable_importance` are in the same order as `impact` - confirm.
  variable_importance <- variable_importance %>%
    mutate(delta = map_dbl(impact, median),
           delta_low = map_dbl(impact, ~ quantile(.x, .025)),
           delta_high = map_dbl(impact, ~ quantile(.x, .975)))
  structure(
    list(partial_dependence = partial_dependence,
         variable_importance = variable_importance),
    class = "part_depend"
  )
}

#' @export
#' @rdname dependence
plot.part_depend <- function(x, order = "delta", p_max = 16, ...){
  plot_data <- x$partial_dependence
  # A single plot may be stored directly rather than inside a list (as
  # `dependence.randomForest` does for one predictor); draw it as is.
  if(inherits(plot_data, "ggplot")) return(plot(plot_data))
  n_vars <- length(plot_data)
  if(n_vars == 1) return(plot(plot_data[[1]]))
  # Rank the predictors by importance score or by effect size.
  varimp <- x$variable_importance
  varimp <- if(order == "import"){
    arrange(varimp, desc(importance))
  } else {
    arrange(varimp, desc(delta))
  }
  # Common y-axis limits across all panels so that effects are comparable.
  min_y <- plot_data %>% map(~.x$data$y) %>% map_dbl(min) %>% min
  max_y <- plot_data %>% map(~.x$data$y) %>% map_dbl(max) %>% max
  # Never ask for more panels than there are plots or ranked variables;
  # `seq_len()` keeps the selection empty when the limit is zero.
  p_max <- min(length(plot_data), nrow(varimp), p_max)
  plots <- map(plot_data[varimp$variable[seq_len(p_max)]],
               ~.x + ylab("") +
                 coord_cartesian(ylim = c(min_y, max_y)) +
                 theme(panel.grid.major = element_blank(),
                       panel.grid.minor = element_blank(),
                       axis.line = element_line(size = .5),
                       axis.text.x = element_text(size = 10),
                       axis.text.y = element_text(size = 10),
                       axis.title.x = element_text(size = 12),
                       axis.title.y = element_text(size = 12))
  )
  # Panels drop their own y label; the shared label is placed on the left of
  # the grid instead.
  gout <- suppressWarnings(
    gridExtra::arrangeGrob(grobs = plots,
                           left = x$partial_dependence[[1]]$labels$y)
  )
  plot(gout)
}

#' @export
print.part_depend <- function(x, ...){
  # Printing a partial dependence object draws it; further arguments are
  # passed on to the plot method.
  plot(x, ...)
  # Follow the `print()` convention of returning the object invisibly.
  invisible(x)
}
jashu/beset documentation built on April 20, 2023, 5:28 a.m.