# R/partDependence.R
#
# Defines functions plot.partDependence print.partDependence partDependence
#
# Documented in partDependence plot.partDependence print.partDependence

#' Partial dependence
#' 
#' This function calculates the partial dependence of a model on a single variable.
#' For that predictions are made for all observations in the dataset while varying 
#' the value of the variable of interest. The overall partial effect is the average
#' of all predictions. \insertCite{Friedman2001GreedyMachine}{SDModels}
#' @importFrom Rdpack reprompt
#' @references
#'   \insertAllCited{}
#' @author Markus Ulmer
#' @param object A model object that has a predict method that takes newdata as argument 
#' and returns predictions.
#' @param j The variable for which the partial dependence should be calculated.
#' Either the column index of the variable in the dataset or the name of the variable.
#' @param X The dataset on which the partial dependence should be calculated.
#' Should contain the same variables as the dataset used to train the model.
#' If NULL, tries to extract the dataset from the model object.
#' @param subSample Number of samples to draw (without replacement) from the original
#' data for the empirical partial dependence. Values larger than the number of
#' observations are capped at that number. If NULL, all the observations are used.
#' @param mc.cores Number of cores to use for parallel computation. 
#' Parallel computing is only supported for unix.
#' @param n_grid Number of equally spaced values, between the observed minimum and
#' maximum of the variable of interest, at which predictions are made.
#' @return An object of class \code{partDependence} containing
#' \item{preds_mean}{The average prediction for each value of the variable of interest.}
#' \item{x_seq}{The sequence of values for the variable of interest.}
#' \item{preds}{The predictions for each value of the variable of interest for each
#' observation: one row per value of \code{x_seq}, one column per observation
#' (assuming \code{predict} returns one value per row of \code{newdata}).}
#' \item{j}{The variable of interest as supplied (name or column index).}
#' \item{xj}{The values of the variable of interest in the (possibly subsampled) dataset.}
#' @examples
#' set.seed(1)
#' x <- rnorm(100)
#' y <- sign(x) * 3 + rnorm(100)
#' model <- SDTree(x = x, y = y, Q_type = 'no_deconfounding')
#' pd <- partDependence(model, 1, X = x, subSample = 10)
#' plot(pd)
#' @seealso \code{\link{SDForest}}, \code{\link{SDTree}}
#' @export
partDependence <- function(object, j, X = NULL, subSample = NULL, mc.cores = 1,
                           n_grid = 100){
  # Keep j exactly as supplied (name or index); it is used for labelling later.
  j_name <- j

  if(is.null(X)){
    X <- object$X
    if(is.null(X)) stop('X must be provided if it is not part of the object')
  }
  X <- data.frame(X)

  if(is.character(j)){
    # match() yields NA (instead of integer(0)) for unknown names, which lets
    # us report the problem explicitly rather than failing in a later if().
    j <- match(j, names(X))
    if(anyNA(j)) stop('variable ', j_name, ' not found in X')
  }

  if(!is.numeric(j)) stop('j must be a numeric or character')
  if(length(j) != 1 || is.na(j)) stop('j must be a single variable')
  if(j > ncol(X)) stop('j must be smaller than p')
  if(j < 1) stop('j must be larger than 0')
  if(length(n_grid) != 1 || is.na(n_grid) || n_grid < 1){
    stop('n_grid must be a positive number')
  }

  if(!is.null(subSample)){
    if(length(subSample) != 1 || is.na(subSample) || subSample < 1){
      stop('subSample must be a positive number')
    }
    # Cap at the number of rows: sample() without replacement cannot draw more.
    n_sub <- min(subSample, nrow(X))
    # drop = FALSE keeps a one-column dataset a data.frame, so its column name
    # (which predict() may rely on) is not replaced by a generic one.
    X <- X[sample(seq_len(nrow(X)), n_sub), , drop = FALSE]
  }

  if(any(is.na(X))) stop('X must not contain missing values')

  x_seq <- seq(min(X[, j]), max(X[, j]), length.out = n_grid)

  # Predictions for every observation with variable j fixed at `value`.
  predict_at <- function(value){
    X_new <- X
    X_new[, j] <- value
    predict(object, newdata = X_new)
  }

  if(mc.cores > 1){
    preds <- parallel::mclapply(x_seq, predict_at, mc.cores = mc.cores)
    # mclapply() returns worker errors as 'try-error' objects instead of
    # signalling them; surface the first one instead of failing obscurely later.
    failed <- vapply(preds, inherits, logical(1), what = 'try-error')
    if(any(failed)) stop('prediction failed: ', preds[[which(failed)[1]]])
  }else{
    preds <- pbapply::pblapply(x_seq, predict_at)
  }

  # Stack to one row per grid value and one column per observation.
  preds <- do.call(rbind, preds)
  preds_mean <- rowMeans(preds)

  res <- list(preds_mean = preds_mean, x_seq = x_seq, preds = preds,
              j = j_name, xj = X[, j])
  class(res) <- 'partDependence'

  res
}

#' Print partDependence
#' 
#' Print contents of the partDependence.
#' @author Markus Ulmer
#' @param x Fitted object of class \code{partDependence}.
#' @param ... Further arguments passed to or from other methods.
#' @return \code{x}, invisibly; called for its side effect of printing a short summary.
#' @seealso \code{\link{partDependence}}, \code{\link{plot.partDependence}}
#' @method print partDependence
#' @examples
#' set.seed(1)
#' x <- rnorm(10)
#' y <- sign(x) * 3 + rnorm(10)
#' model <- SDTree(x = x, y = y, Q_type = 'no_deconfounding', cp = 0.5)
#' pd <- partDependence(model, 1, X = x)
#' print(pd)
#' @export
print.partDependence <- function(x, ...){
  cat("Partial dependence of covariate: ", x$j, "\n")
  # Terminate the line so the console prompt does not follow on the same line.
  cat("Plot to analyze!\n")
  # Print methods conventionally return their argument invisibly.
  invisible(x)
}

#' Plot partial dependence
#' 
#' This function plots the partial dependence of a model on a single variable.
#' @author Markus Ulmer
#' @param x An object of class \code{partDependence} returned by \code{\link{partDependence}}.
#' @param n_examples Number of individual prediction curves (for randomly chosen
#' observations) to plot in addition to the average prediction. Capped at the
#' number of observations; 0 plots only the average.
#' @param ... Further arguments passed to or from other methods (currently ignored).
#' @return A ggplot object.
#' @seealso \code{\link{partDependence}}
#' @method plot partDependence
#' @examples
#' set.seed(1)
#' x <- rnorm(10)
#' y <- sign(x) * 3 + rnorm(10)
#' model <- SDTree(x = x, y = y, Q_type = 'no_deconfounding', cp = 0.5)
#' pd <- partDependence(model, 1, X = x)
#' plot(pd)
#' @export
plot.partDependence <- function(x, n_examples = 19, ...){
  preds <- as.matrix(x$preds)
  x_seq <- x$x_seq
  n_obs <- ncol(preds)

  # Randomly choose which observations' prediction curves to draw; clamp to
  # [0, n_obs] so sample() never receives an invalid size.
  n_show <- max(0, min(n_examples, n_obs))
  sample_examples <- sample(seq_len(n_obs), n_show)

  ggdep <- ggplot2::ggplot() + ggplot2::theme_bw()

  if(length(sample_examples) > 0){
    # A single long data frame with one group per observation draws all
    # individual curves in one layer instead of one layer per curve.
    example_data <- data.frame(
      x = rep(x_seq, times = length(sample_examples)),
      y = as.vector(preds[, sample_examples, drop = FALSE]),
      id = rep(sample_examples, each = length(x_seq))
    )
    ggdep <- ggdep + ggplot2::geom_line(data = example_data,
                                        ggplot2::aes(x = .data$x, y = .data$y,
                                                     group = .data$id),
                                        col = 'grey')
  }

  ggdep <- ggdep + ggplot2::geom_line(data = data.frame(x = x_seq, y = x$preds_mean), 
                                      ggplot2::aes(x = .data$x, y = .data$y), col = '#08cbba', 
                                      linewidth = 1.5)

  # The rug is drawn on the bottom axis; its y only needs to lie within the
  # plotted range. Including preds_mean keeps it finite when no examples are shown.
  rug_y <- min(c(x$preds_mean, preds[, sample_examples]))
  ggdep <- ggdep + ggplot2::geom_rug(data = data.frame(x = x$xj, y = rug_y), 
                                     ggplot2::aes(x = .data$x, y = .data$y), 
                                     sides = 'b', col = '#949494')
  ggdep <- ggdep + ggplot2::ylab('f(x)') + ggplot2::ggtitle('Partial dependence')

  # Use the variable name when j was given by name, otherwise label it "x<index>".
  x_label <- if(is.character(x$j)) x$j else paste0('x', x$j)
  ggdep + ggplot2::xlab(x_label)
}

# Try the SDModels package in your browser
#
# Any scripts or data that you put into this service are public.
#
# SDModels documentation built on April 11, 2025, 5:50 p.m.