R/add_predictions.R

Defines functions add_predictions make_new_data_predictions

Documented in add_predictions make_new_data_predictions

#' Use a model to make predictions on new data
#'
#' This function takes new data and produces predictions, error estimates, and
#' SHAP values using the booster in `model$model` and the companion error model
#' in `model$error_model`. The results are attached to the model object under
#' `model$new_data`.
#' @param model The model object. Must contain `model$model` (an xgboost
#'   booster whose `feature_names` select the predictor columns) and
#'   `model$error_model`.
#' @param name The name of the perturbation (used in the progress report and in
#'   error messages).
#' @param indx Integer index used, for progress report.
#' @param total Integer of the total number of perturbations passed to this function, for progress report.
#' @param new_data A dataframe of new cases with predictors as columns. Sample names are row names.
#' @return The input `model` with a `new_data` list element containing:
#'   `data` (the feature-subset of `new_data`), `predictions` and
#'   `predictions_error` (named by the row names of `new_data`),
#'   `feature_contribution` (per-feature sum of absolute SHAP values, sorted
#'   descending), `important_features` (features with a non-zero contribution)
#'   and `shap_values` (a data frame of SHAP values, one row per sample).
#' @keywords model predictions
#' @import xgboost purrr fastshap
#' @export
#' @examples
#' \dontrun{
#' make_new_data_predictions(my_model, "ko_ctnnb1", 1, 1, my_new_data)
#' }
make_new_data_predictions <- function(model, name, indx, total, new_data) {

  # Calculates the Shapley values for `data` (a numeric matrix) and summarises
  # them per feature.
  get_xgb_shap_pred <- function(model, data) {

    pfun <- function(object, newdata) {
      predict(object, newdata = newdata)
    }

    # NOTE(review): with exact = TRUE, fastshap computes the values from the
    # booster itself; pred_wrapper/adjust are likely unused here - confirm
    # against the fastshap::explain() documentation.
    shap_obj <- fastshap::explain(model, exact = TRUE,
                                  X = data,
                                  pred_wrapper = pfun, adjust = TRUE)

    # Work on a matrix view: depending on the fastshap version, explain()
    # returns a tibble or a matrix, and names() is NULL for a matrix whereas
    # colnames() works for both.
    shap_mat <- as.matrix(shap_obj)

    # NOTE(review): tibble()/arrange()/desc()/filter()/pull() come from
    # tibble/dplyr, which are not in this block's @import list - confirm they
    # are imported elsewhere in the package (otherwise filter() may resolve to
    # stats::filter()).
    contrib <- tibble(
      term = colnames(shap_mat),
      value = colSums(abs(shap_mat))
    ) %>% arrange(desc(value))

    # Features that contribute anything at all to at least one prediction.
    pos_terms <- contrib %>% filter(value > 0) %>% pull(term)

    shap_df <- as.data.frame(shap_mat)
    rownames(shap_df) <- rownames(data)

    list(shap_values = shap_df, shap_table = contrib, good_terms = pos_terms)
  }


  cat(glue::glue("[{lubridate::now('US/Eastern')}] Making predictions for {name} ({indx} of {total}) .."),sep="\n")
  flush.console()

  # Fail early with an informative message instead of "undefined columns selected".
  features <- model$model$feature_names
  missing_features <- setdiff(features, colnames(new_data))
  if (length(missing_features) > 0) {
    stop("new_data is missing ", length(missing_features),
         " feature(s) required by model '", name, "': ",
         paste(missing_features, collapse = ", "), call. = FALSE)
  }

  # Keep only the features needed by the model. drop = FALSE keeps a data
  # frame (and its row names) even when the model uses a single feature.
  new_data <- new_data[, features, drop = FALSE]
  new_data_mat <- as.matrix(new_data)

  # Convert to DMatrix
  new_data_dm <- xgb.DMatrix(new_data_mat)

  # Make predictions and error estimates for each sample
  predictions <- predict(model$model, new_data_dm)
  error <- predict(model$error_model, new_data_dm)

  # Explain the predictions
  shap <- get_xgb_shap_pred(model$model, new_data_mat)

  # Attach new data outputs to the original model
  model$new_data$data <- new_data
  model$new_data$predictions <- setNames(predictions, rownames(new_data))
  model$new_data$predictions_error <- setNames(error, rownames(new_data))
  model$new_data$feature_contribution <- shap$shap_table
  model$new_data$important_features <- shap$good_terms
  model$new_data$shap_values <- shap$shap_values

  model
}


#' Use a batch of models to make predictions on new data
#'
#' This function takes a list of models and makes predictions on new data by
#' calling [make_new_data_predictions()] on each model in turn.
#' @param models A (preferably named) list with model objects generated by
#'   make_xgb_models. Names are used in the progress report; unnamed lists
#'   fall back to their positions.
#' @param new_data A dataframe of new cases with predictors as columns. Sample names are row names.
#' @param models_to_use Optional vector with subset of names of models to use.
#'   When given as character names, every name must exist in `models`.
#' @return A list of model objects (carrying the names of `models`), each with a
#'   `new_data` element added by [make_new_data_predictions()].
#' @keywords model predictions
#' @import xgboost purrr fastshap
#' @export
#' @examples
#' \dontrun{
#' add_predictions(my_models, my_new_data)
#' add_predictions(my_models, my_new_data, models_to_use = c("ko_ctnnb1"))
#' }
add_predictions <- function(models, new_data, models_to_use = NULL) {

  # Subset to only needed models if provided
  if (!is.null(models_to_use) && length(models_to_use) > 0) {
    # An unknown name would otherwise yield an NA-named NULL model that only
    # fails later, deep inside the prediction step.
    if (is.character(models_to_use)) {
      unknown <- setdiff(models_to_use, names(models))
      if (length(unknown) > 0) {
        stop("Unknown model name(s) in models_to_use: ",
             paste(unknown, collapse = ", "), call. = FALSE)
      }
    }
    models <- models[models_to_use]
  }

  # pmap() needs every input to have one entry per model, so unnamed model
  # lists get positional names instead of NULL.
  model_names <- names(models)
  if (is.null(model_names)) {
    model_names <- as.character(seq_along(models))
  }

  # Generate an input list; its element names match the arguments of
  # make_new_data_predictions().
  inputs <- list(
    model = models,
    name = model_names,
    indx = seq_along(models)
  )

  pmap(inputs, make_new_data_predictions,
       total = length(models),
       new_data = new_data)
}
Mushriq/mixmap documentation built on Jan. 28, 2024, 7:22 p.m.