R/make_xgb_models.R

Defines functions: fit_models_in_parallel, fit_models_and_save, fit_depmap_models, make_xgb_model

Documented in: fit_depmap_models, fit_models_and_save, fit_models_in_parallel, make_xgb_model

#' Make predictive models of dependencies
#'
#' This function trains an XGBoost regression model for a single perturbation, evaluates it with
#' repeated k-fold cross-validation, and, if the score is high enough, refits on all data to produce predictions and SHAP values.
#' @param perturbation Column name of the perturbation (e.g. "ko_ctnnb1").
#' @param indx Integer index of this perturbation, used for progress reporting.
#' @param total Total number of perturbations passed to this function, used for progress reporting.
#' @param dataset A dataframe with the perturbation in a column and all other columns as predictors. Sample names are row names.
#' @param response_cutoff The value at or above which a sample is considered sensitive.
#' @param decreasing Default = FALSE. Set to TRUE if sensitivity corresponds to values at or below response_cutoff (used for the binarized metrics).
#' @param weight_cap The maximum weight any single minority case can take when case weights are computed. Set to 0 to disable weighting.
#' @param nfolds The number of folds in k-fold cross-validation.
#' @param nrepeats The number of repeats in k-fold cross-validation.
#' @param nrounds The maximum number of trees in the XGBoost model.
#' @param max_depth Maximum tree depth passed to XGBoost.
#' @param f_subsample Fraction of training rows subsampled per boosting round (passed to XGBoost as subsample).
#' @param min_score The minimum mean r-squared across CV folds for a model to be considered for the next stage (making predictions and calculating SHAP values).
#' @param skip_eval Default = FALSE. If TRUE, k-fold CV will not be conducted and instead all models will be pushed to the next stage.
#' @param shuffle Default = FALSE. If TRUE, the outcome is shuffled when building training matrices, producing a null model.
#' @param n_threads Number of threads passed to XGBoost (nthread).
#' @param xgb_params Optional named list overriding the baseline XGBoost parameters (eta, gamma, alpha, lambda, colsample_bytree, colsample_bylevel, colsample_bynode).
#' @param cor_data Optional table of feature-to-perturbation correlations (a 'feature' column plus one column per perturbation) used to pre-filter predictors.
#' @param cor_n_features Number of top-correlated features to keep when cor_data is provided.
#' @param use_gpu Default = TRUE. Set to FALSE if using CPU.
#' @param gpu_id Id of the GPU device to use when use_gpu = TRUE.
#' @keywords model
#' @import xgboost purrr fastshap
#' @export
#' @examples
#' make_xgb_model("ko_ctnnb1", 1, 1, my_data)
make_xgb_model <- function(perturbation, indx, total, dataset, 
                            response_cutoff = 0.75, decreasing = F,
                            weight_cap = 0.05,
                            nfolds = 3, 
                            nrepeats = 3, 
                            nrounds = 100, 
                            max_depth = 3,
                            f_subsample = 1,
                            min_score = 0.5, 
                            skip_eval = FALSE,
                            shuffle = FALSE,
                            n_threads = 4,
                            xgb_params = NULL,
                            cor_data = NULL, cor_n_features = 1000,
                            use_gpu = TRUE, gpu_id = 0){
  
  cat(glue::glue("[{lubridate::now('US/Eastern')}] Training a model for {perturbation} ({indx} of {total}) .."))
  flush.console()
  
  
  # This keeps one column of dependency scores (renamed 'y_value') plus all predictors
  prepare_model_data <- function(perturbation, data, tag = "ko_", response_cutoff = 0.75, nfolds = 3, nrepeats = 3, cor_data = NULL, cor_num = 1000){
    
    if (is.null(cor_data)){
      
      prepared_data <- data %>%
        mutate(y_value = get(perturbation)) %>% 
        select(-starts_with(tag)) %>%
        na.omit() %>% as_tibble(rownames = "cell_line") 
      
    } else {
      
      # Check if the perturbation is in the correlation matrix or skip
      correlated_features <- NULL
      
      if (perturbation %in% colnames(cor_data)){
        
        correlated_features <- cor_data %>%
          top_n(cor_num, abs(get(perturbation))) %>%
          pull(feature)
        
      }

      
      # If correlated features were found, keep only those; otherwise fall back to all features
      if (!is.null(correlated_features) && length(correlated_features) > 0) {
        
        prepared_data <- data %>%
          mutate(y_value = get(perturbation)) %>% 
          select(-starts_with(tag)) %>%
          select(y_value, any_of(correlated_features))  %>%
          na.omit() %>% as_tibble(rownames = "cell_line") 
        
      } else {
        
        prepared_data <- data %>%
          mutate(y_value = get(perturbation)) %>% 
          select(-starts_with(tag)) %>%
          na.omit() %>% as_tibble(rownames = "cell_line") 
        
      }
      

      
      
    }
 
    
    # Create a response column to help stratify cases
    prepared_data <- prepared_data %>% column_to_rownames("cell_line") %>% mutate(response = y_value > response_cutoff)
    
    # Note: We are using the full data as there is no tuning right now.
    data_folds <- vfold_cv(prepared_data, v = nfolds, strata = y_value, repeats = nrepeats, breaks = 20, pool = 0.05)
    
    output <- list()
    output$original_data <- prepared_data
    output$dfolds <- data_folds
    
    return(output)
    
  }
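  
  # A minimal sketch (hypothetical feature names and values) of the layout `cor_data` is
  # assumed to have: one row per predictor, a `feature` column with the predictor name, and
  # one column per perturbation holding the feature-to-perturbation correlation, e.g.
  #
  #   cor_data <- tibble::tibble(
  #     feature   = c("exp_ctnnb1", "exp_apc", "mut_apc"),
  #     ko_ctnnb1 = c(0.62, -0.48, 0.55),
  #     ko_myod1  = c(0.03, 0.01, -0.02)
  #   )
  #
  # prepare_model_data() then keeps the `cor_num` features with the largest absolute
  # correlation to the current perturbation.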
  
  
  # This creates an object that stores model parameters
  # Ideally this is tuning the parameters but we skip this for now
  prepare_model_params <- function(data, xgb_params){
    
    params <- list()
    params$booster <- "gbtree"
    params$objective <- "reg:squarederror"
    
    
    # These parameters seem to do OK on average for all perturbations
    params$eta <- 0.04 # 0.04
    params$gamma <- 0 # 0.01
    params$alpha <- 0.35
    params$lambda <- 0.7
    #params$max_depth = 3
    # params$subsample = 1
    params$sampling_method = "gradient_based"
    params$colsample_bytree = 1
    params$colsample_bylevel = 0.2 # 0.2
    params$colsample_bynode = 0.8 # 0.8
    
    # If the user provided any of these params, overwrite the baseline values above
    if (!is.null(xgb_params$eta)) params$eta <- xgb_params$eta
    if (!is.null(xgb_params$gamma)) params$gamma <- xgb_params$gamma
    if (!is.null(xgb_params$alpha)) params$alpha <- xgb_params$alpha
    if (!is.null(xgb_params$lambda)) params$lambda <- xgb_params$lambda
    if (!is.null(xgb_params$colsample_bytree)) params$colsample_bytree <- xgb_params$colsample_bytree
    if (!is.null(xgb_params$colsample_bylevel)) params$colsample_bylevel <- xgb_params$colsample_bylevel
    if (!is.null(xgb_params$colsample_bynode)) params$colsample_bynode <- xgb_params$colsample_bynode
    
    return(params)
    
  }
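  
  # Example (hypothetical values): passing a named list via `xgb_params` overrides only the
  # listed entries; everything else keeps the baseline set above, e.g.
  #
  #   make_xgb_model("ko_ctnnb1", 1, 1, my_data,
  #                  xgb_params = list(eta = 0.1, lambda = 1))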
  
  # This calculates weights per case based on response
  get_weights <- function(y_value, response_cutoff, weight_cap = 0.05){
    
    # If all samples are above/below the cutoff, return equal weights
    if (weight_cap == 0 || sum(y_value >= response_cutoff) == 0 || sum(y_value < response_cutoff) == 0){
      return(rep(1/length(y_value),times=length(y_value)))
    }
    
    # We count how many cases we have of each response 'status'
    status_counts <- table(if_else(y_value >= response_cutoff, "A", "B"))
    
    # Decide the majority group
    status_major_count <- if_else(status_counts["A"] > status_counts["B"], status_counts["A"], status_counts["B"])
    
    # We calculate the weight of an individual sensitivity status
    status_weight <- 1/status_counts
    
    # We assign a weight to each observation based on its observed sensitivity status
    weights <- if_else(y_value >= response_cutoff, status_weight["A"], status_weight["B"])
    
    # We normalize so that total weights add up to 1
    weights <- weights/sum(weights)
    
    # We cap each observation's individual weight at weight_cap
    weights <- if_else(weights > weight_cap, weight_cap, weights)
    
    # We redistribute the 'lost' weight from the capping step to the remaining samples
    leftover_weight <- (1 - sum(weights))/status_major_count
    weights <- if_else(weights == weight_cap, weight_cap, weights + leftover_weight)
    
    return(weights)
    
  }
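  
  # Worked example (hypothetical counts): with 5 sensitive and 95 resistant lines and
  # weight_cap = 0.05, normalized weights start at 0.10 per sensitive line and ~0.0053 per
  # resistant line; the sensitive weights are capped at 0.05 and the trimmed 0.25 is spread
  # over the 95 majority-class lines (~0.0079 each), so the minority class still carries
  # ~25% of the total weight without any single line exceeding the cap.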
  
  # This resamples the given data slice     
  get_weighted_set <- function(data, weights){
    
    data <- data %>%
      slice_sample(prop = 1,
                   replace = TRUE, 
                   weight_by = weights
      )
    
    
    return(data)
    
  }
  
  # This puts the data in DMatrix format for xgboost
  # Optional: To generate a null model we can shuffle the outcome here
  get_DMatrix <- function(data, weights = NULL, shuffle = FALSE){
    
    x_values <- data %>% select(-"y_value",-"response") %>% as.matrix()
    y_values <- data %>% pull(y_value)
    if (shuffle) y_values <- sample(y_values)
    if(!is.null(weights)){
      data <- xgb.DMatrix(data = x_values, label = y_values, weight = 1000*weights) 
    } else {
      data <- xgb.DMatrix(data = x_values, label = y_values) 
    }
    
    
    return(data)
  }
  
  # This calculates SHAP values
  get_xgb_shap <- function(model, data){
    
    pfun <- function(object, newdata) {
      predict(object, newdata = newdata)
    }
    
    shap_obj <- fastshap::explain(model, 
                                  exact = TRUE,
                                  X = data %>% select(-"y_value",-"response") %>% as.matrix(), 
                                  pred_wrapper = pfun, 
                                  adjust = TRUE)
    
    contrib <- tibble(
      term = names(shap_obj),
      value = apply(shap_obj, MARGIN = 2, FUN = function(x) sum(abs(x)))
    ) %>% arrange(desc(value))
    
    nonzero_terms <- contrib %>% filter(value > 0) %>% pull(term)
    
    shap_obj <- shap_obj %>% as.data.frame()
    rownames(shap_obj) <- rownames(data)   
    
    shap_output <- list()
    shap_output$shap_values = shap_obj
    shap_output$shap_table = contrib
    shap_output$good_terms = nonzero_terms
    
    return(shap_output)
    
  }
  
  
  # cor() returns NA (with a warning) when either vector has zero variance,
  # so we nudge one value slightly to keep the score numeric
  get_pseudo_cor <- function(x, y){
    
    if (sd(x) == 0 || sd(y) == 0){
      
      x[1] = x[1] + 1e-6
      y[1] = y[1] + 1e-6
      
    }
    
    return(cor(x,y))
    
  }
  
  
  get_rmse <- function(x, y){
    return( sqrt( mean( (x - y)^2  ) ) )
  }
  
  
  get_R2 <- function(x, y){
    
    return( 1 - sum( ( x - y )^2 )/sum( ( y - mean(y) )^2 ) )
    
  }
  
  
  # Binarized scores
  get_discrete_sensitivity <- function(pred, obs, discrete_cut, decreasing = F){
    
    if (decreasing){
      pred_d = if_else(pred <= discrete_cut, T, F)
      obs_d = if_else(obs <= discrete_cut, T, F)
    } else {
      pred_d = if_else(pred >= discrete_cut, T, F)
      obs_d = if_else(obs >= discrete_cut, T, F)
    }

    
    TP = sum(pred_d & obs_d)
    FN = sum(!pred_d & obs_d)
    
    result = TP / (TP + FN)
    return(result)
    
  }
  get_discrete_specificity <- function(pred, obs, discrete_cut, decreasing = F){
    
    if (decreasing){
      pred_d = if_else(pred <= discrete_cut, T, F)
      obs_d = if_else(obs <= discrete_cut, T, F)
    } else {
      pred_d = if_else(pred >= discrete_cut, T, F)
      obs_d = if_else(obs >= discrete_cut, T, F)
    }
    
    TN = sum(!pred_d & !obs_d)
    FP = sum(pred_d & !obs_d)
    
    result = TN / (TN + FP)
    return(result)
  }
  get_discrete_fpr <- function(pred, obs, discrete_cut, decreasing = F){
    
    if (decreasing){
      pred_d = if_else(pred <= discrete_cut, T, F)
      obs_d = if_else(obs <= discrete_cut, T, F)
    } else {
      pred_d = if_else(pred >= discrete_cut, T, F)
      obs_d = if_else(obs >= discrete_cut, T, F)
    }
    
    FP = sum(pred_d & !obs_d)
    TN = sum(!pred_d & !obs_d)
    
    result = FP / (FP + TN)
    return(result)
    
  }
  get_discrete_ppv <- function(pred, obs, discrete_cut, decreasing = F){
    
    if (decreasing){
      pred_d = if_else(pred <= discrete_cut, T, F)
      obs_d = if_else(obs <= discrete_cut, T, F)
    } else {
      pred_d = if_else(pred >= discrete_cut, T, F)
      obs_d = if_else(obs >= discrete_cut, T, F)
    }
    
    TP = sum(pred_d & obs_d)
    FP = sum(pred_d & !obs_d)
    
    result = TP / (TP + FP)
    return(result)
    
  }
  get_discrete_npv <- function(pred, obs, discrete_cut, decreasing = F){
    
    if (decreasing){
      pred_d = if_else(pred <= discrete_cut, T, F)
      obs_d = if_else(obs <= discrete_cut, T, F)
    } else {
      pred_d = if_else(pred >= discrete_cut, T, F)
      obs_d = if_else(obs >= discrete_cut, T, F)
    }
    
    TN = sum(!pred_d & !obs_d)
    FN = sum(!pred_d & obs_d)
    
    result = TN / (TN + FN)
    return(result)
    
  }
  get_discrete_accuracy <- function(pred, obs, discrete_cut, decreasing = F){
    
    if (decreasing){
      pred_d = if_else(pred <= discrete_cut, T, F)
      obs_d = if_else(obs <= discrete_cut, T, F)
    } else {
      pred_d = if_else(pred >= discrete_cut, T, F)
      obs_d = if_else(obs >= discrete_cut, T, F)
    }
    
    TP = sum(pred_d & obs_d)
    FP = sum(pred_d & !obs_d)
    TN = sum(!pred_d & !obs_d)
    FN = sum(!pred_d & obs_d)
    
    result = (TP + TN) / (TP + FP + TN + FN)
    return(result)
  }
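  
  # Quick reference for the binarized metrics above, given the 2x2 confusion matrix
  # built from pred_d and obs_d:
  #   sensitivity = TP / (TP + FN)    specificity = TN / (TN + FP)
  #   fpr         = FP / (FP + TN)    ppv         = TP / (TP + FP)
  #   npv         = TN / (TN + FN)    accuracy    = (TP + TN) / (TP + FP + TN + FN)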
  
  
  # Step 1: Prepare the data (keep only this perturbation's outcome values and split into folds)
  model_data <- prepare_model_data(perturbation = perturbation, 
                                   data = dataset, 
                                   response_cutoff = response_cutoff, 
                                   nfolds = nfolds, nrepeats = nrepeats,
                                   cor_data = cor_data, cor_num = cor_n_features)
  
  # Step 2: Define parameters
  model_params <- prepare_model_params(data = model_data, xgb_params = xgb_params)
  
  # Step 3: Assess current parameters with repeated k-fold CV
  # Ideally we are tuning hyperparameters here
  if(!skip_eval){
    
    data_splits = model_data$dfolds$splits
    
    # Grab the analysis parts; for each, compute case weights (if requested) and build a training DMatrix
    training_sets <- map(data_splits, analysis)
    
    if(weight_cap > 0){    
      training_weights <- training_sets %>% map(pull, y_value) %>% map(get_weights, response_cutoff = response_cutoff, weight_cap = weight_cap)
      # training_matrices <- training_sets %>% map2(training_weights, get_weighted_set) %>% map(get_DMatrix)
      training_matrices <- training_sets %>% map2(training_weights, get_DMatrix, shuffle = shuffle)
    } else {
      training_matrices <- training_sets %>% map(get_DMatrix, shuffle = shuffle)
    }
    
    # Grab the assessment parts, for each: do as above + pull the y_values for later use
    validation_sets <- map(data_splits, assessment)
    
    if(weight_cap > 0){
      validation_weights <- validation_sets %>% map(pull, y_value) %>% map(get_weights, response_cutoff = response_cutoff, weight_cap = weight_cap)
      #  validation_matrices <- validation_sets %>% map2(validation_weights, get_weighted_set) 
      validation_matrices <- validation_sets %>% map2(validation_weights, get_DMatrix) 
    } else {
      validation_matrices <- validation_sets %>% map(get_DMatrix)
    }
    
    validation_y_values <- map(validation_matrices, getinfo, "label") # equivalent to pulling "y_value" from each validation set
    #validation_matrices <- map(validation_matrices, get_DMatrix)
    
    # Use the analysis subsets for creating a model, then use the assessment subset to make predictions and calculate correlation
    score_models <- map(training_matrices, xgboost, 
                        params = model_params,
                        max_depth = max_depth,
                        subsample = f_subsample,
                        nthread = n_threads,
                        max_bin = 64,
                        tree_method = if_else(use_gpu,"gpu_hist","auto"),
                        gpu_id = gpu_id,
                        nrounds = nrounds,
                        early_stopping_rounds = 10, 
                        verbose = 0)
    
    score_predictions <-  score_models %>% map2(validation_matrices, predict)
    
    scores <- score_predictions %>% map2(validation_y_values, get_pseudo_cor) %>% unlist()
    
    scores_rmse <- score_predictions %>% map2(validation_y_values, get_rmse) %>% unlist()
    
    scores_R2 <- score_predictions %>% map2(validation_y_values, get_R2) %>% unlist()
    
    
    # Discrete scores
    scores_d_sensitivity <- score_predictions %>% map2(validation_y_values, get_discrete_sensitivity, response_cutoff, decreasing) %>% unlist()
    scores_d_specificity <- score_predictions %>% map2(validation_y_values, get_discrete_specificity, response_cutoff, decreasing) %>% unlist()
    scores_d_fpr <- score_predictions %>% map2(validation_y_values, get_discrete_fpr, response_cutoff, decreasing) %>% unlist()
    scores_d_ppv <- score_predictions %>% map2(validation_y_values, get_discrete_ppv, response_cutoff, decreasing) %>% unlist()
    scores_d_npv <- score_predictions %>% map2(validation_y_values, get_discrete_npv, response_cutoff, decreasing) %>% unlist()
    scores_d_accuracy <- score_predictions %>% map2(validation_y_values, get_discrete_accuracy, response_cutoff, decreasing) %>% unlist()
    
    
    # Clean up
    rm(score_models)
    rm(data_splits)
    rm(training_matrices)
    rm(validation_matrices)
    
    
  } else {
    
    # CV was skipped; fill in optimistic dummy scores so downstream steps proceed
    n_dummy <- nfolds * nrepeats
    scores <- rep(1, n_dummy)
    scores_R2 <- rep(1, n_dummy)
    scores_rmse <- rep(0, n_dummy)
    scores_d_sensitivity <- rep(1, n_dummy)
    scores_d_specificity <- rep(1, n_dummy)
    scores_d_fpr <- rep(1, n_dummy)
    scores_d_ppv <- rep(1, n_dummy)
    scores_d_npv <- rep(1, n_dummy)
    scores_d_accuracy <- rep(1, n_dummy)
    
  }
  
  
  cat(glue::glue(" R^2 = {round(mean(scores^2),3)} +/- {round(1.96*sd(scores^2),3)} , RMSE = {round(mean(scores_rmse),5)} , (n={length(scores)})"))
  flush.console()
  
  
  # Prepare output
  output <- list()  
  output$perturbation_name <- perturbation
  output$scores <- scores
  output$scores_R2 <- scores_R2
  output$scores_rmse <- scores_rmse
  output$scores_d_sensitivity <- scores_d_sensitivity
  output$scores_d_specificity <- scores_d_specificity
  output$scores_d_fpr <- scores_d_fpr
  output$scores_d_ppv <- scores_d_ppv
  output$scores_d_npv <- scores_d_npv
  output$scores_d_accuracy <- scores_d_accuracy
  
  # If the score is good enough, we proceed with extra steps                                                                              
  if (!is.na(mean(scores)) & mean(scores^2) >= min_score){
    
    # Fit one last model using all data    
    last_params <- model_params # Ideally we have found the best params and we set them here
    last_nrounds <- nrounds # Ideally this has been tuned too
    
    last_weights <- model_data$original_data %>% pull(y_value) %>% get_weights(response_cutoff = response_cutoff, weight_cap = weight_cap)
    
    if (weight_cap > 0){
      #   last_matrix <-  get_weighted_set(model_data$original_data, last_weights) %>% get_DMatrix()
      last_matrix <-  get_DMatrix(model_data$original_data, last_weights, shuffle = shuffle)
    } else {   
      last_matrix <- get_DMatrix(model_data$original_data, shuffle = shuffle)
    }
    
    
    # We create a DMatrix using the original (non-resampled) data
    last_validation <- get_DMatrix(model_data$original_data)
    
    # We fit a last model
    last_model <- xgboost(data = last_matrix, 
                          params = last_params, 
                          max_depth = max_depth,
                          subsample = f_subsample,
                          nthread = n_threads,
                          max_bin = 64,
                          tree_method = if_else(use_gpu,"gpu_hist","auto"),
                          gpu_id = gpu_id,
                          nrounds = last_nrounds, 
                          early_stopping_rounds = 10, verbose = 0)
    
    # We collect the predictions on the same data
    last_predictions <- predict(last_model, newdata = last_validation)
    names(last_predictions) <- rownames(model_data$original_data)
    
    null_prediction <- predict(last_model, newdata = last_matrix) %>% mean()
    
    # Create an error estimate
    errors <- (last_predictions - model_data$original_data$y_value)
    names(errors) <- rownames(model_data$original_data)
    
    # Create a matrix of errors vs features
    error_data <- xgb.DMatrix(data = model_data$original_data %>% select(-"y_value",-"response") %>% as.matrix(),
                              label = errors^2)
    
    # Fit a model on the squared errors (reusing the final model's params)
    error_model <- xgboost(data = error_data, params = last_params,
                           max_depth = max_depth,
                           subsample = f_subsample,
                           nrounds = last_nrounds, early_stopping_rounds = 10, 
                           max_bin = 64,
                           nthread = n_threads,
                           tree_method = if_else(use_gpu,"gpu_hist","auto"),
                           gpu_id = gpu_id,
                           verbose = 0) 
    
    cat(glue::glue(" E = +/- {round(1.96*sqrt(mean(errors^2,na.rm=T)),3)}"), sep = "\n")
    flush.console()
    
    # Get feature contributions                                                                          
    shap <- get_xgb_shap(last_model, model_data$original_data)
    
    
    # Finish preparing outputs
    output$model <- last_model
    output$error_model <- error_model 
    output$null_prediction <- null_prediction
    output$predictions <- last_predictions
    output$predictions_error <- errors
    output$feature_contribution <- shap$shap_table
    output$important_features <- shap$good_terms
    output$shap_values <- shap$shap_values
    output$sample_names <- rownames(model_data$original_data)
    output$feature_names <- setdiff(colnames(model_data$original_data),c("y_value","response"))
    
    # Clean up
    rm(last_model)
    rm(error_model)
    rm(last_matrix)
    rm(shap)
    gc()
    
    # output$data <- model_data$original_data
    
  } else {
    
    # This model isn't good enough, so we save some time and skip this step.
    cat(glue::glue(" Skipped"), sep = "\n")
    flush.console()
    
  }                                                                            
  
  
  
  return(output)
}
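
# A minimal usage sketch (assumes `my_data` is a samples-by-features dataframe whose row
# names are cell lines and whose "ko_*" columns hold dependency scores); CPU-only, with a
# small CV setup for speed:
#
#   fit <- make_xgb_model("ko_ctnnb1", indx = 1, total = 1, dataset = my_data,
#                         nfolds = 3, nrepeats = 1, use_gpu = FALSE)
#   mean(fit$scores^2)             # cross-validated r-squared
#   head(fit$feature_contribution) # top SHAP contributors (only present if min_score was reached)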


#' Make a list of predictive models of dependencies
#'
#' This function creates an XGBoost model for each perturbation given, and returns a list of model objects.
#' @param depmap_data A dataframe with each perturbation in a column and all other columns as predictors. Sample names are row names.
#' @param models_to_make A character vector of perturbation column names to model (e.g. c("ko_ctnnb1", "ko_myod1")).
#' @param response_cutoff The value above which the sample is considered sensitive.
#' @param weight_cap The maximum weight of each minority case when resampling. Set to 0 if no resampling needed.
#' @param nfolds The number of folds in k-fold cross validation.
#' @param nrepeats The number of repeats in k-fold cross validation.
#' @param nrounds The maximum number of trees in the XGBoost model.
#' @param min_score The minimum mean r-squared value for a model to be considered for the next stage (making predictions and calculating SHAP values).
#' @param skip_eval Default = FALSE. If TRUE, k-fold CV will not be conducted and instead all models will be pushed to the next stage.
#' @param use_gpu Default = TRUE. Set to FALSE if using CPU.
#' @keywords model
#' @import xgboost purrr fastshap
#' @export
#' @examples
#' fit_depmap_models(my_data, c("ko_ctnnb1","ko_myod1"))
fit_depmap_models <- function(depmap_data, models_to_make,
                              response_cutoff = 0.5, decreasing = FALSE,
                              weight_cap = 0,
                              nfolds = 3, nrepeats = 1, nrounds = 200, min_score = 0.5,
                              max_depth = 3,
                              f_subsample = 1,
                              skip_eval = FALSE, shuffle = FALSE,
                              n_threads = 4,
                              xgb_params = NULL,
                              cor_data = NULL, cor_n_features = 1000,
                              use_gpu = TRUE, gpu_id = 0){
  
  my_models <- map2(
    models_to_make, seq_along(models_to_make), make_xgb_model,  
    total = length(models_to_make),
    dataset = depmap_data,
    response_cutoff = response_cutoff, decreasing = decreasing,
    weight_cap = weight_cap,
    nfolds = nfolds, nrepeats = nrepeats, nrounds = nrounds, min_score = min_score,
    max_depth = max_depth,
    f_subsample = f_subsample,
    skip_eval = skip_eval, shuffle = shuffle, 
    xgb_params = xgb_params,
    n_threads = n_threads, 
    cor_data = cor_data, cor_n_features = cor_n_features,
    use_gpu = use_gpu, gpu_id = gpu_id)
  
  names(my_models) <- models_to_make
  
  return(my_models)
  
}
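
# Example sketch (hypothetical `my_data`, as above): fit a small batch of models and
# inspect their cross-validated scores:
#
#   models <- fit_depmap_models(my_data, c("ko_ctnnb1", "ko_myod1"), use_gpu = FALSE)
#   sapply(models, function(m) mean(m$scores^2))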




#' Make a list of models and save as file
#'
#' This function creates an XGBoost model for each perturbation given, saves the list of models, and returns a message.
#' @param perturbs A vector of perturbations.
#' @param chunk_indx Integer chunk index, used to name the output file and for progress reporting.
#' @param model_dataset A dataframe with the perturbation in a column and all other predictors. Sample names are row names.
#' @param response_cutoff The value above which the sample is considered sensitive.
#' @param weight_cap The maximum weight of each minority case when resampling. Set to 0 if no resampling needed.
#' @param nfolds The number of folds in k-fold cross validation.
#' @param nrepeats The number of repeats in k-fold cross validation.
#' @param nrounds The maximum number of trees in the XGBoost model.
#' @param min_score The minimum mean r-squared value for a model to be considered for the next stage (making predictions and calculating SHAP values).
#' @param skip_eval Default = FALSE. If TRUE, k-fold CV will not be conducted and instead all models will be pushed to the next stage.
#' @param use_gpu Default = TRUE. Set to FALSE if using CPU.
#' @param seed Random seed
#' @param path Folder path (e.g. "/home/test/models") to save models in.
#' @keywords model
#' @import xgboost purrr fastshap tidyverse glue lubridate tidymodels rsample
#' @export
#' @examples
#' fit_models_and_save(c("ko_ctnnb1","ko_myod1"), 1, my_data)
fit_models_and_save <- function(perturbs, chunk_indx, 
                                model_dataset, response_cutoff = 0.5, decreasing = FALSE,
                                weight_cap = 0,
                                nfolds = 3, nrepeats = 1, nrounds = 200, min_score = 0.5,
                                max_depth = 3,
                                f_subsample = 1,
                                skip_eval = FALSE, shuffle = FALSE,
                                xgb_params = NULL,
                                n_threads = 4,
                                cor_data = NULL, cor_n_features = 1000,
                                use_gpu = TRUE, gpu_id = 0, seed = 123, path = NULL){
  
  library(tidyverse)
  library(glue)
  library(purrr)
  library(lubridate)
  library(tidymodels)
  library(rsample)
  library(xgboost)
  library(fastshap)
  library(mixmap)
  
  if (is.null(path)) path = "."
  
  if (!file.exists(glue::glue("{path}/models_chunk_{chunk_indx}.rds"))){
    
    set.seed(seed)
    
    my_models <- fit_depmap_models(depmap_data = model_dataset, 
                                   models_to_make = perturbs, 
                                   response_cutoff = response_cutoff, decreasing = decreasing,
                                   weight_cap = weight_cap,
                                   nfolds = nfolds, nrepeats = nrepeats, nrounds = nrounds,
                                   min_score = min_score,
                                   max_depth = max_depth,
                                   f_subsample = f_subsample,
                                   skip_eval = skip_eval, shuffle = shuffle,
                                   xgb_params = xgb_params,
                                   cor_data = cor_data,
                                   cor_n_features = cor_n_features,
                                   n_threads = n_threads,
                                   use_gpu = use_gpu, gpu_id = gpu_id)
    
    
    
    saveRDS(my_models,glue::glue("{path}/models_chunk_{chunk_indx}.rds"))
    
    # Clean up
    rm(my_models)
    gc()
    
    return(glue::glue("Done chunk {chunk_indx}"))
    
  } else {
    
    
    return(glue::glue("Chunk already done {chunk_indx}"))
    
    
  }
  
  
}




#' Make a list of models and save as file (parallel)
#'
#' This function creates an XGBoost model for each perturbation given, saves the list of models, and returns a message.
#' @param perturbs A vector of perturbations.
#' @param chunk_size Number of perturbations per chunk; each chunk is fitted and saved as a separate file.
#' @param model_dataset A dataframe with the perturbation in a column and all other predictors. Sample names are row names.
#' @param response_cutoff The value above which the sample is considered sensitive.
#' @param weight_cap The maximum weight of each minority case when resampling. Set to 0 if no resampling needed.
#' @param nfolds The number of folds in k-fold cross validation.
#' @param nrepeats The number of repeats in k-fold cross validation.
#' @param nrounds The maximum number of trees in the XGBoost model.
#' @param min_score The minimum mean r-squared value for a model to be considered for the next stage (making predictions and calculating SHAP values).
#' @param skip_eval Default = FALSE. If TRUE, k-fold CV will not be conducted and instead all models will be pushed to the next stage.
#' @param use_gpu Default = TRUE. Set to FALSE if using CPU.
#' @param seed Random seed
#' @param path Folder path (e.g. "/home/test/models") to save models in.
#' @keywords model
#' @import xgboost purrr furrr future fastshap tidyverse glue lubridate tidymodels rsample
#' @export
#' @examples
#' fit_models_in_parallel(c("ko_ctnnb1","ko_myod1"), chunk_size = 1, model_dataset = my_data)
fit_models_in_parallel <- function(perturbs, chunk_size = 20, 
                                   model_dataset, response_cutoff = 0.5, decreasing = FALSE,
                                   weight_cap = 0,
                                   nfolds = 3, nrepeats = 1, nrounds = 200, min_score = 0.5,
                                   max_depth = 3,
                                   f_subsample = 1,
                                   skip_eval = FALSE, shuffle = FALSE, 
                                   xgb_params = NULL,
                                   cor_data = NULL, cor_n_features = 1000,
                                   n_threads = 4,
                                   use_gpu = TRUE, gpu_id = c(0), seed = 123, path = NULL){

  perturb_splits <- split(perturbs, ceiling(seq_along(perturbs)/chunk_size))
  
  # Generate a list of inputs
  inputs <- list()
  inputs$perturbs <- perturb_splits
  inputs$chunk_indx <- seq_along(perturb_splits)
  inputs$gpu_id <- rep(gpu_id,length.out=length(perturb_splits))
  
  furrr::future_pmap(inputs,fit_models_and_save,
                     model_dataset = model_dataset, response_cutoff = response_cutoff, decreasing = decreasing,
                     weight_cap = weight_cap,
                     nfolds = nfolds, nrepeats = nrepeats, nrounds = nrounds, 
                     min_score = min_score,
                     max_depth = max_depth,
                     f_subsample = f_subsample,
                     skip_eval = skip_eval, shuffle = shuffle,
                     xgb_params = xgb_params,
                     cor_data = cor_data, cor_n_features = cor_n_features,
                     n_threads = n_threads,
                     use_gpu = use_gpu, seed = seed, path = path,  
                     .options = furrr_options(seed = TRUE))
  
  return("Done")
  
}
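
# fit_models_in_parallel() relies on whatever future backend is active when it is called;
# without one, furrr::future_pmap() runs sequentially. A minimal sketch (hypothetical data
# and output path):
#
#   future::plan(future::multisession, workers = 2)
#   fit_models_in_parallel(perturbs = c("ko_ctnnb1", "ko_myod1"),
#                          chunk_size = 1,
#                          model_dataset = my_data,
#                          use_gpu = FALSE,
#                          path = "models")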