#' Use a model to make predictions on new data
#'
#' This function takes new data, and produces predictions, error estimates, and
#' SHAP values. The results are attached to the supplied model object under
#' `model$new_data`.
#'
#' @param model The model object. `model$model` is used for the predictions and
#'   the SHAP values, and its `feature_names` select the columns of `new_data`;
#'   `model$error_model` is used for the error estimates.
#' @param name The name of the perturbation. Used in the progress report and in
#'   error messages.
#' @param indx Integer index used, for progress report.
#' @param total Integer of the total number of perturbations passed to this
#'   function, for progress report.
#' @param new_data A dataframe of new cases with predictors as columns. Sample
#'   names are row names.
#' @return The input `model` with its `new_data` element populated:
#'   `data` (the new data restricted to the model features), `predictions` and
#'   `predictions_error` (named by sample), `feature_contribution` (a tibble of
#'   per-feature summed absolute SHAP values, largest first),
#'   `important_features` (features with a contribution above zero) and
#'   `shap_values` (per-sample SHAP values as a data frame).
#' @keywords model predictions
#' @import xgboost purrr fastshap
#' @export
#' @examples
#' make_new_data_predictions(my_model,"ko_ctnnb1",1,1,my_new_data)
make_new_data_predictions <- function(model, name, indx, total, new_data) {
  # Computes the Shapley values for the new data and summarises them per
  # feature.
  get_xgb_shap_pred <- function(model, data) {
    pfun <- function(object, newdata) {
      predict(object, newdata = newdata)
    }
    shap_obj <- fastshap::explain(
      model,
      exact = TRUE,
      X = data,
      pred_wrapper = pfun,
      adjust = TRUE
    )
    # colnames() rather than names(): names() is NULL for a matrix, and
    # fastshap::explain() may return either a matrix or a data frame depending
    # on the installed release, so colnames() covers both.
    contrib <- tibble::tibble(
      term = colnames(shap_obj),
      value = apply(shap_obj, MARGIN = 2, FUN = function(x) sum(abs(x)))
    )
    # Largest total contribution first; order() is stable, so ties keep their
    # original column order.
    contrib <- contrib[order(-contrib$value), ]
    # which() drops NA comparisons instead of propagating them as NA terms.
    pos_terms <- contrib$term[which(contrib$value > 0)]
    shap_obj <- as.data.frame(shap_obj)
    rownames(shap_obj) <- rownames(data)
    list(shap_values = shap_obj, shap_table = contrib, good_terms = pos_terms)
  }

  # Fail early, before any work or progress output, with a message that names
  # the perturbation and the columns that are absent.
  feature_names <- model$model$feature_names
  missing_features <- setdiff(feature_names, colnames(new_data))
  if (length(missing_features) > 0) {
    stop(
      "New data for '", name, "' is missing features required by the model: ",
      paste(missing_features, collapse = ", "),
      call. = FALSE
    )
  }

  cat(glue::glue("[{lubridate::now('US/Eastern')}] Making predictions for {name} ({indx} of {total}) .."),sep="\n")
  flush.console()

  # Keep only the features needed by the model. drop = FALSE keeps a
  # data frame even when the model uses a single feature.
  new_data <- new_data[, feature_names, drop = FALSE]
  new_data_matrix <- as.matrix(new_data)

  # Convert to DMatrix
  new_data_dm <- xgb.DMatrix(new_data_matrix)

  # Make predictions and error estimates for each sample
  predictions <- predict(model$model, new_data_dm)
  error <- predict(model$error_model, new_data_dm)

  # Explain the predictions
  shap <- get_xgb_shap_pred(model$model, new_data_matrix)

  # Attach new data outputs to the original model
  sample_names <- rownames(new_data)
  model$new_data$data <- new_data
  model$new_data$predictions <- predictions
  model$new_data$predictions_error <- error
  names(model$new_data$predictions) <- sample_names
  names(model$new_data$predictions_error) <- sample_names
  model$new_data$feature_contribution <- shap$shap_table
  model$new_data$important_features <- shap$good_terms
  model$new_data$shap_values <- shap$shap_values
  model
}
#' Use a batch of models to make predictions on new data
#'
#' This function takes a list of models and makes predictions on new data by
#' calling `make_new_data_predictions()` once per model.
#'
#' @param models A list with model objects generated by make_xgb_models.
#' @param new_data A dataframe of new cases with predictors as columns. Sample
#'   names are row names.
#' @param models_to_use Optional vector with subset of names of models to use.
#'   When `NULL` or empty, all models are used.
#' @return A list with one element per selected model, each being the value
#'   returned by `make_new_data_predictions()` for that model.
#' @keywords model predictions
#' @import xgboost purrr fastshap
#' @export
#' @examples
#' add_predictions(my_models, my_new_data)
#' add_predictions(my_models, my_new_data, models_to_use = c("ko_ctnnb1"))
add_predictions <- function(models, new_data, models_to_use = NULL) {
  # Subset to only needed models if provided
  if (!is.null(models_to_use) && length(models_to_use) > 0) {
    if (is.character(models_to_use)) {
      # Subsetting a list by an unknown name yields a NULL entry that would
      # only fail later, deep inside the prediction step; reject it here.
      unknown <- setdiff(models_to_use, names(models))
      if (length(unknown) > 0) {
        stop(
          "Unknown model name(s) in `models_to_use`: ",
          paste(unknown, collapse = ", "),
          call. = FALSE
        )
      }
    }
    models <- models[models_to_use]
  }

  # Labels for the progress report; fall back to positions when the list is
  # unnamed so that every call still receives a `name`.
  model_names <- names(models)
  if (is.null(model_names)) {
    model_names <- as.character(seq_along(models))
  }

  # Generate an input list; element names match the argument names of
  # make_new_data_predictions() so pmap() passes them by name.
  inputs <- list(
    model = models,
    name = model_names,
    indx = seq_along(models)
  )

  purrr::pmap(
    inputs,
    make_new_data_predictions,
    total = length(inputs$indx),
    new_data = new_data
  )
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.