R/summarize.R

Defines functions summarize_contributions get_variable_predictions summarize_error summarize_predictions_in_parallel summarize_predictions summarize_models

Documented in get_variable_predictions summarize_contributions summarize_error summarize_models summarize_predictions summarize_predictions_in_parallel

#' Summarize models
#'
#' This function generates a table with summary stats of each model.
#' Every score vector stored on a model (`scores`, `scores_R2`, `scores_rmse`
#' and the discrete `scores_d_*` entries) is averaged, `r.squared` is the mean
#' of the squared `scores`, and `n_terms` is the length of
#' `important_features`. For models with at least one important feature,
#' `error` is the mean of the squared `predictions_error`, ignoring `NA`s.
#' @param models A named list of model objects generated by make_xgb_models
#' @return A tibble with one row per model, sorted by decreasing `r`, with the
#'   columns `brd`, `r`, `r.squared`, `R2`, `rmse`, `d_sensitivity`,
#'   `d_specificity`, `d_FPR`, `d_PPV`, `d_NPV`, `d_accuracy` and `n_terms`.
#'   An `error` column is appended only when at least one model has
#'   `n_terms > 0`; models without important features get `NA` there.
#' @keywords scores
#' @import glue
#' @export
#' @examples
#' summarize_models(my_models)
summarize_models <- function(models) {

  # Build a two-column (brd, <value_name>) table holding, for each model, the
  # mean of `prepare(model[[element]])`. Extra arguments are passed to mean().
  mean_table <- function(model_list, element, value_name,
                         prepare = identity, ...) {
    map(model_list, element) %>%
      map(prepare) %>%
      map(mean, ...) %>%
      as_tibble() %>%
      pivot_longer(everything(), names_to = "brd", values_to = value_name)
  }

  square <- function(x) x^2

  n_term_table <- map(models, "important_features") %>%
    map(length) %>%
    as_tibble() %>%
    pivot_longer(everything(), names_to = "brd", values_to = "n_terms")

  # Models that ended up with at least one important feature
  good_models <- n_term_table %>%
    filter(n_terms > 0) %>%
    pull(brd)

  # Output column name -> model element that is averaged as-is.
  # The order here is the column order of the result.
  plain_means <- c(
    R2            = "scores_R2",
    rmse          = "scores_rmse",
    # Discrete scores
    d_sensitivity = "scores_d_sensitivity",
    d_specificity = "scores_d_specificity",
    d_FPR         = "scores_d_fpr",
    d_PPV         = "scores_d_ppv",
    d_NPV         = "scores_d_npv",
    d_accuracy    = "scores_d_accuracy"
  )

  mini_tables <- c(
    list(
      mean_table(models, "scores", "r"),
      mean_table(models, "scores", "r.squared", prepare = square)
    ),
    imap(plain_means, function(element, value_name) {
      mean_table(models, element, value_name)
    }),
    list(n_term_table)
  )

  # Merge all mini tables into one large scores table
  scores_table <- mini_tables %>%
    reduce(left_join, by = "brd") %>%
    arrange(desc(r))

  # The error column is only computed for (and only added when there are)
  # models with important features
  if (length(good_models) > 0) {
    error_table <- mean_table(
      models[good_models], "predictions_error", "error",
      prepare = square, na.rm = TRUE
    )
    scores_table <- scores_table %>% left_join(error_table, by = "brd")
  }

  return(scores_table)

}


#' Summarize predictions
#'
#' Collects the `new_data$predictions` entry of every model into a single
#' table with one column per model.
#' @param models A list of model objects generated by make_xgb_models and with predictions added using add_predictions.
#' @return A data frame with samples as row names and one column per model,
#'   named after `names(models)`. Tables are combined with successive left
#'   joins on the sample name, so the rows are those of the first model.
#' @keywords predictions
#' @import purrr
#' @export
#' @examples
#' summarize_predictions(my_models)
summarize_predictions <- function(models){

  # Pull the predictions stored under each model's new_data
  new_data_list <- map(models, "new_data")
  prediction_list <- map(new_data_list, "predictions")

  # Row names become an explicit "sample" key so the tables can be joined
  keyed_predictions <- map(prediction_list, as_tibble, rownames = "sample")
  joined <- reduce(keyed_predictions, left_join, by = "sample")

  wide_predictions <- column_to_rownames(joined, "sample")
  colnames(wide_predictions) <- names(models)

  return(wide_predictions)

}



#' Summarize predictions in parallel
#'
#' Collects the `new_data$predictions` entry of every model into a single
#' table with one column per model, running the per-model steps through
#' `furrr::future_map`.
#' @param models A list of model objects generated by make_xgb_models and with predictions added using add_predictions.
#' @return A data frame with samples as row names and one column per model,
#'   named after `names(models)`. Tables are combined with successive left
#'   joins on the sample name, so the rows are those of the first model.
#' @keywords predictions
#' @import purrr furrr
#' @export
#' @examples
#' summarize_predictions_in_parallel(my_models)
summarize_predictions_in_parallel <- function(models){

  # Per-model extraction and conversion go through futures
  new_data_list <- furrr::future_map(models, "new_data")
  prediction_list <- furrr::future_map(new_data_list, "predictions")
  keyed_predictions <- furrr::future_map(
    prediction_list, as_tibble, rownames = "sample"
  )

  # The join itself is a sequential fold over the per-model tables
  joined <- reduce(keyed_predictions, left_join, by = "sample")

  wide_predictions <- column_to_rownames(joined, "sample")
  colnames(wide_predictions) <- names(models)

  return(wide_predictions)

}


#' Summarize prediction error
#'
#' Collects the `new_data$predictions_error` entry of every model into a
#' single table with one column per model.
#' @param models A list of model objects generated by make_xgb_models and with predictions added using add_predictions.
#' @return A data frame with samples as row names and one column per model,
#'   named after `names(models)`. Tables are combined with successive left
#'   joins on the sample name, so the rows are those of the first model.
#' @keywords predictions error
#' @import purrr
#' @export
#' @examples
#' summarize_error(my_models)
summarize_error <- function(models){

  # One keyed (sample, error) table per model
  per_model_errors <- map(
    map(map(models, "new_data"), "predictions_error"),
    as_tibble,
    rownames = "sample"
  )

  # Fold the per-model tables together on the sample key
  merged_errors <- reduce(per_model_errors, left_join, by = "sample")

  result <- column_to_rownames(merged_errors, "sample")
  colnames(result) <- names(models)

  return(result)

}

#' Find most variable predictions
#'
#' This function generates a vector of perturbation names with the most variable predictions, given a predictions table.
#' The variance of each selected column is computed across samples and the
#' columns with the highest variance are returned.
#' @param predictions_table A predictions table generated with summarize_predictions
#' @param n Number of variable predictions returned (default = 50).
#' @param prefix Column-name prefix identifying the perturbation columns to
#'   rank (default = "ko_"). Columns without this prefix are ignored.
#' @return A character vector of column names ordered by decreasing variance.
#'   Ties at the cut-off are kept (see `dplyr::top_n`), so more than `n`
#'   names can be returned.
#' @keywords scores
#' @import glue
#' @export
#' @examples
#' get_variable_predictions(summarize_predictions(my_models), n = 10)
get_variable_predictions <- function(predictions_table, n = 50, prefix = "ko_"){

  var_preds <- predictions_table %>%
    rownames_to_column("sample") %>%
    pivot_longer(
      starts_with(prefix),
      names_to = "brd",
      values_to = "prediction"
    ) %>%
    group_by(brd) %>%
    # Named `variance` so the column does not shadow stats::var()
    summarize(variance = var(prediction)) %>%
    arrange(desc(variance)) %>%
    top_n(n, wt = variance) %>%
    pull(brd)

  return(var_preds)

}




#' Summarize SHAP values
#'
#' This function generates one consolidated matrix with SHAP values for each perturbation.
#' The per-model `feature_contribution` tables are full-joined on their term
#' column, so a term missing from a perturbation is `NA` in that row.
#' @param models A list of model objects generated by make_xgb_models and with predictions added using add_predictions.
#' @param use_new_data If `TRUE`, read `feature_contribution` from each
#'   model's `new_data` entry instead of from the model itself
#'   (default = FALSE).
#' @param min_abs_contribution Features are kept only if their absolute
#'   contribution exceeds this value in at least one perturbation
#'   (default = 0.05).
#' @return A matrix with one row per perturbation (named after
#'   `names(models)`) and one column per retained feature. The result is
#'   always a matrix, even when a single feature is retained.
#' @keywords SHAP contribution
#' @import purrr
#' @export
#' @examples
#' summarize_contributions(my_models)
summarize_contributions <- function(models, use_new_data = FALSE, min_abs_contribution = 0.05){

  if (use_new_data){
    shap_table <- map(models, "new_data") %>% map("feature_contribution")
  } else {
    shap_table <- map(models, "feature_contribution")
  }

  # Each contribution table is relabelled (term, <perturbation name>) so the
  # value columns stay distinct after joining.
  # NOTE(review): assumes every table has exactly two columns - confirm.
  for (this_ko in names(shap_table)){
    colnames(shap_table[[this_ko]]) <- c("term", this_ko)
  }

  shap_table <- shap_table %>%
    reduce(full_join, by = "term") %>%
    column_to_rownames("term") %>%
    as.matrix() %>%
    t()

  # Remove features with little contribution in all perturbations.
  # na.rm = TRUE: the full join leaves NA where a term is absent from a
  # perturbation; without it the column count is NA and the logical index
  # yields NA columns. drop = FALSE keeps a matrix when one feature survives.
  keep_feature <- colSums(
    abs(shap_table) > min_abs_contribution,
    na.rm = TRUE
  ) > 0
  shap_table <- shap_table[, keep_feature, drop = FALSE]

  return(shap_table)

}
Mushriq/mixmap documentation built on Jan. 28, 2024, 7:22 p.m.