R/gbm-interactions.r

Defines functions compute_preds_for_all_var_combinations table_of_unique_values check_and_set_variables_indices check_and_set_num_trees interact.GBMFit interact

Documented in interact interact.GBMFit

#' Estimate the strength of interaction effects
#' 
#' Computes Friedman's H-statistic to assess the strength of variable
#' interactions.
#' 
#' \code{interact.GBMFit} computes Friedman's H-statistic to assess the relative
#' strength of interaction effects in non-linear models. H is on the scale of
#' [0-1] with higher values indicating larger interaction effects. To connect
#' to a more familiar measure, if \eqn{x_1} and \eqn{x_2} are independent
#' covariates with mean 0 and variance 1 and the model is of the form
#' \deqn{y=\beta_0+\beta_1x_1+\beta_2x_2+\beta_3x_1x_2} then
#' \deqn{H=\frac{|\beta_3|}{\sqrt{\beta_1^2+\beta_2^2+\beta_3^2}}}
#' 
#' Note that if the main effects are weak, the estimated H will be unstable.
#' For example, if (in the case of a two-way interaction) neither main effect
#' is in the selected model (relative influence is zero), the result will be
#' 0/0. Also, with weak main effects, rounding errors can result in values of H
#' > 1, which are not possible; such values are returned as \code{NaN}.
#' 
#' @usage interact(gbm_fit_obj, data, var_indices=1, num_trees=gbm_fit_obj$params$num_trees)
#' 
#' @param gbm_fit_obj a \code{GBMFit} object fitted using a call to \code{\link{gbmt}}.
#' 
#' @param data a data.frame or matrix: the dataset used to construct
#' \code{gbm_fit_obj}. If the original dataset is large, a random subsample may
#' be used to accelerate the computation.
#' 
#' @param var_indices a vector of indices or the names of the variables for
#' which to compute the interaction effect. At least two distinct variables are
#' needed, and no more than the \code{interaction_depth} the model was fitted
#' with. If using indices, the variables are indexed in the same order that they
#' appear in the initial \code{gbmt} formula.
#' 
#' @param num_trees the number of trees used to compute the statistic. Only the
#' first \code{num_trees} trees will be used; a value larger than the number of
#' trees in the model is reduced to that number with a warning.
#' 
#' @return Returns the value of \code{H} (a single number), or \code{NaN} when
#' the estimate is undefined or exceeds 1.
#' 
#' @author Greg Ridgeway \email{gregridgeway@@gmail.com}
#' @seealso \code{\link{gbmt}}
#' @references J.H. Friedman and B.E. Popescu (2005). \dQuote{Predictive
#' Learning via Rule Ensembles.} Section 8.1
#' @keywords methods
#' @export 
#'

interact <- function(gbm_fit_obj, data, var_indices=1, num_trees=gbm_fit_obj$params$num_trees) {
  UseMethod("interact", gbm_fit_obj)
}

#' @name interact
#' @export
interact.GBMFit <- function(gbm_fit_obj, data, var_indices=1, num_trees = gbm_fit_obj$params$num_trees) {
  # Friedman's H-statistic for the joint interaction of the variables in
  # `var_indices`, computed from partial dependence functions of the fitted
  # model evaluated at the distinct observed value combinations in `data`.

  # ---- Input checks and set up ----
  if (!is.data.frame(data) && !is.matrix(data)) {
    stop("data argument should be a data.frame or matrix")
  }

  # Accept whole-number indices or variable names, with no missing values.
  # Types are tested directly rather than by coercion: coercing names with
  # as.integer() made every non-numeric name fail, and comparing numbers with
  # their character form let fractional indices such as 1.5 through.
  valid_numeric <- is.numeric(var_indices) && all(is.finite(var_indices)) &&
    all(var_indices == round(var_indices))
  valid_character <- is.character(var_indices) && !anyNA(var_indices)
  if (!is.atomic(var_indices) || length(var_indices) == 0 ||
      !(valid_numeric || valid_character)) {
    stop("Variables indices must be a vector of integers or characters")
  }
  # An interaction involves at least two variables; with one variable the
  # subset enumeration below has nothing to subtract from.
  if (length(var_indices) < 2) {
    stop("At least two variables are required to compute an interaction effect")
  }
  if (gbm_fit_obj$params$interaction_depth < length(var_indices)) {
    stop("interaction_depth < length(variables_indices): too low in model call")
  }
  var_indices <- check_and_set_variables_indices(gbm_fit_obj, var_indices)
  if (anyDuplicated(var_indices)) {
    stop("var_indices must not refer to the same variable more than once")
  }
  num_trees <- check_and_set_num_trees(gbm_fit_obj, num_trees)

  # Work on a plain data.frame: for a tibble `data[, name]` is not a vector
  # (so factors would escape recoding), and matrices break later helpers.
  data <- as.data.frame(data)

  # Recode factors to 0-based integer codes (as.numeric(factor) is 1-based);
  # presumably the coding the fitted trees expect -- TODO confirm.
  for (name in gbm_fit_obj$variables$var_names[var_indices]) {
    if (is.factor(data[, name])) {
      data[, name] <- as.numeric(data[, name]) - 1
    }
  }

  # ---- Enumerate all non-empty subsets of the variables ----
  # Each subset is stored as positions within var_indices. expand.grid()
  # orders the subsets so the full set is last; drop = FALSE keeps a matrix.
  inclusion <- as.matrix(expand.grid(rep(list(c(FALSE, TRUE)), length(var_indices))))
  inclusion <- inclusion[-1, , drop = FALSE]
  all_combinations_vars <- lapply(seq_len(nrow(inclusion)),
                                  function(r) as.numeric(which(inclusion[r, ])))

  # Centred partial dependence, row counts and inclusion-exclusion sign for
  # every subset
  preds_for_comb_vars <- compute_preds_for_all_var_combinations(
    data, gbm_fit_obj, all_combinations_vars, var_indices, num_trees)

  # ---- H-statistic ----
  num_combinations <- length(all_combinations_vars)
  full <- preds_for_comb_vars[[num_combinations]]
  full_preds <- as.numeric(full$preds)

  # Start from the partial dependence on all variables and add every
  # lower-order partial dependence with its sign, evaluated at the rows of the
  # full table that share the subset's values (rows are matched on keys built
  # by pasting the row's values).
  interaction_part <- full_preds
  for (k in seq_len(num_combinations - 1)) {
    comb <- preds_for_comb_vars[[k]]
    full_keys <- apply(full$data[, all_combinations_vars[[k]], drop = FALSE],
                       1, paste, collapse = "\r")
    comb_keys <- apply(comb$data, 1, paste, collapse = "\r")
    rows <- match(full_keys, comb_keys)
    interaction_part <- interaction_part + comb$sign * as.numeric(comb$preds)[rows]
  }

  # H^2 = weighted mean of the squared interaction part divided by the
  # weighted mean of the squared full partial dependence; weights are the
  # number of data rows behind each distinct value combination.
  weights <- full$num_levels_factors
  numerator <- weighted.mean(interaction_part^2, w = weights, na.rm = TRUE)
  denominator <- weighted.mean(full_preds^2, w = weights, na.rm = TRUE)
  H_stat_squared <- numerator / denominator

  # If H > 1, rounding and tiny main effects have messed things up
  H_stat_squared[H_stat_squared > 1] <- NaN

  sqrt(H_stat_squared)
}


#### Helper Functions - Not to be used outside of this interact function ####
check_and_set_num_trees <- function(gbm_fit_obj, num_trees) {
  # Validate the requested tree count and cap it at the number of trees in
  # the fitted model. Only the first element of a longer vector is used.
  if (length(num_trees) > 1) {
    warning("length num_trees > 1: using first element")
    num_trees <- num_trees[1]
  }
  check_if_natural_number(num_trees)

  max_trees <- gbm_fit_obj$params$num_trees
  if (num_trees <= max_trees) {
    return(num_trees)
  }

  # Too many trees requested: fall back to every tree in the model
  warning(paste0("num_trees exceeds the number of trees in the model, ",
                 max_trees, ". Using ", max_trees, " trees."))
  max_trees
}

check_and_set_variables_indices <- function(gbm_fit_obj, variables_indices) {
  # Resolve variable names to positions in the model's var_names and check
  # that all indices refer to variables of the model.
  #
  # gbm_fit_obj       - fitted model; only $variables$var_names is read
  # variables_indices - numeric positions or character variable names
  # Returns the (integer, for names) positions; stops on unknown names or
  # out-of-range indices.
  var_names <- gbm_fit_obj$variables$var_names

  if (is.character(variables_indices)) {
    positions <- match(variables_indices, var_names)
    if (anyNA(positions)) {
      stop("Variables given are not used in gbm model fit: ",
           paste(variables_indices[is.na(positions)], collapse = ", "))
    }
    variables_indices <- positions
  }

  # An out-of-range index would otherwise select a missing or wrong column
  # further down, so fail here with a clear message.
  if (length(variables_indices) == 0 || anyNA(variables_indices) ||
      min(variables_indices) < 1 ||
      max(variables_indices) > length(var_names)) {
    stop("variables_indices must be between 1 and ", length(var_names))
  }

  variables_indices
}

table_of_unique_values <- function(data, variables_indices) {
  # Distinct value combinations of the selected columns, with a count of how
  # many rows of `data` share each combination.
  #
  # data              - data.frame or matrix
  # variables_indices - columns to use (names or positions)
  # Returns a data.frame of the distinct rows (in order of first appearance)
  # plus a `num_levels_factors` column holding the row counts.

  # as.data.frame() so that matrix input supports the `$<-` below
  subset_data <- as.data.frame(data[, variables_indices, drop = FALSE])

  # Uniqueness is decided on the same pasted row keys that interact() uses to
  # match rows. Deduplicating on exact values instead can keep two rows whose
  # keys coincide (values differing beyond the precision of as.character()),
  # which leaves the counts misaligned with the rows.
  row_keys <- apply(subset_data, 1, paste, collapse = "\r")
  first_occurrence <- !duplicated(row_keys)
  unique_keys <- row_keys[first_occurrence]

  unique_vars <- subset_data[first_occurrence, , drop = FALSE]
  unique_vars$num_levels_factors <- tabulate(match(row_keys, unique_keys),
                                             nbins = length(unique_keys))
  unique_vars
}

compute_preds_for_all_var_combinations <- function(data, gbm_fit_obj, all_combinations_vars, variables_indices, num_trees) {
  # For each variable subset in `all_combinations_vars` (positions within
  # `variables_indices`), evaluate the compiled "gbm_plot" routine at every
  # distinct observed combination of those variables.
  #
  # Returns a list with one element per subset, each a list of:
  #   data               - data.frame of the distinct value combinations
  #   num_levels_factors - number of rows of `data` behind each combination
  #                        (used as weights)
  #   preds              - one-column matrix of weight-centred predictions
  #   sign               - +1 or -1, the inclusion-exclusion sign of the subset
  num_vars <- length(variables_indices)
  preds_for_comb_vars <- vector("list", length(all_combinations_vars))

  for (vars in seq_along(all_combinations_vars)) {
    comb <- all_combinations_vars[[vars]]
    model_indices <- variables_indices[comb]

    # Distinct rows of the subset's columns plus their counts
    comb_data <- data.frame(
      table_of_unique_values(data, gbm_fit_obj$variables$var_names[model_indices]))
    weights <- as.numeric(comb_data$num_levels_factors)
    comb_data$num_levels_factors <- NULL

    # Predictions for this combination of variables; i.var is converted from
    # R's 1-based indices to 0-based for the C routine
    raw_preds <- .Call("gbm_plot",
                       X = data.matrix(comb_data),
                       i.var = as.integer(model_indices - 1),
                       n.trees = as.integer(num_trees),
                       initF = as.double(gbm_fit_obj$initF),
                       trees = trees(gbm_fit_obj),
                       c.splits = gbm_fit_obj$c.splits,
                       var.type = as.integer(gbm_fit_obj$variables$var_type),
                       PACKAGE = "gbm3")
    raw_preds <- as.numeric(raw_preds)

    # Centre with the same weights used for the H-statistic. The matrix is
    # built after centring so it stays a one-column matrix even when there is
    # a single distinct row (apply() would return a bare vector there, which
    # breaks the caller's `preds[i, ]` indexing).
    centred <- raw_preds - weighted.mean(raw_preds, weights, na.rm = TRUE)

    preds_for_comb_vars[[vars]] <- list(
      data = comb_data,
      num_levels_factors = weights,
      preds = matrix(centred, ncol = 1),
      # +1 when an even number of variables is excluded ("same parity"), else -1
      sign = if ((length(comb) %% 2) == (num_vars %% 2)) 1 else -1
    )
  }

  preds_for_comb_vars
}
gbm-developers/gbm3 documentation built on April 28, 2024, 10:04 p.m.