HSPSModelR is an R package that reduces the time it takes to fit and compare many machine learning models to a cognostic dataset. It is designed to give the user a quick snapshot of which classification algorithms perform well, as well as extract variable importance. This package works only for binary classification problems.
install.packages(c("devtools", "caret", "tidyverse"))
devtools::install_github("HSPS-DataScience/HSPSModelR")
library(caret)
library(tidyverse)
library(HSPSModelR)
preprocess_data
The preprocess_data function's main purpose is to reduce unnecessary predictor columns using correlation and variance. It can also impute missing data and convert the target class to either a factor or a binary value.
data(dhfr)
dhfr_reduced <- preprocess_data(dhfr, target = "Y", reduce_cols = TRUE)
ncol(dhfr)
ncol(dhfr_reduced)
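To make the idea of reducing columns "using correlation and variance" concrete, here is a minimal base R sketch. It is not the package's implementation: the helper name drop_redundant_cols, its cutoff arguments, and the toy data are all hypothetical, and the real preprocess_data may use different rules and thresholds.

```r
# Hypothetical helper, for illustration only (not part of HSPSModelR).
# Assumes every column of `x` is numeric and contains no missing values.
drop_redundant_cols <- function(x, var_cutoff = 1e-8, cor_cutoff = 0.9) {
  # Step 1: drop columns with (near) zero variance
  vars <- vapply(x, stats::var, numeric(1))
  x <- x[, vars > var_cutoff, drop = FALSE]

  # Step 2: drop every column that is highly correlated with an earlier column
  cors <- abs(stats::cor(x))
  cors[lower.tri(cors, diag = TRUE)] <- 0   # compare each pair only once
  keep <- !apply(cors > cor_cutoff, 2, any)
  x[, keep, drop = FALSE]
}

# Toy data: `b` is almost a copy of `a`, and `d` is constant
set.seed(1)
a <- rnorm(50)
toy <- data.frame(a = a, b = 2 * a + 0.001 * rnorm(50), c = rnorm(50), d = 1)

reduced <- drop_redundant_cols(toy)
names(reduced)
```

In this toy case the constant column and the near-duplicate column are the ones removed, which mirrors the drop in ncol() seen after preprocess_data.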
This function does not scale the data or perform any other major data transformation. Such transformations should be specified while training so that they can be more easily reverted for interpretation purposes. For more options, see caret's pre-processing functions.
set.seed(1)
index <- caret::createDataPartition(dhfr_reduced$Y, p = .8, list = FALSE)
train_x <- dhfr_reduced[ index, 1:111]
test_x <- dhfr_reduced[-index, 1:111]
train_y <- dhfr_reduced[ index, "Y"]
test_y <- dhfr_reduced[-index, "Y"]
install_train_packages
In preparation for the run_models function, this function installs all the libraries required to train each of the 68 methods.
run_models
This function takes considerable time to run, as it trains 68 classification methods from the caret package. To reduce this time, make sure to run install_train_packages() first, which installs all the libraries required to iterate through each of the 68 methods.
Many preprocessing and training parameters can be specified, such as preprocess, folds, iterations, etc. Since caret includes a lot of information in each train object, this function implements a trimming method to retain just enough information to be able to predict.
install_train_packages()
models_list <- run_models(train_x = train_x,
                          train_y = train_y,
                          seed = 1,
                          num_folds = 2,
                          trim_models = TRUE,
                          light = TRUE,
                          prepro_methods = c("center", "scale", "YeoJohnson"))
get_performance
This function takes a list of models and outputs a dataframe where each row represents a specific performance metric for a trained model (labeled by the training method), calculated by caret::confusionMatrix.
The default output is in long format. If you would like it in wide format, set the format argument to "wide".
model_performance <- get_performance(models_list, test_x, test_y)
model_performance
get_performance(models_list, test_x, test_y, format = "wide")
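As a rough illustration of the difference between the two formats, the base R sketch below reshapes a small long-format table into a wide one. The column names method and score and all of the values are made up for the example (only the measure column name appears in this README), so the actual output of get_performance may be laid out differently.

```r
# Toy long-format table: one row per (method, measure) pair.
# Column names `method` and `score` and all values are invented for illustration.
long <- data.frame(
  method  = c("rf", "rf", "glm", "glm"),
  measure = c("Accuracy", "F1", "Accuracy", "F1"),
  score   = c(0.90, 0.88, 0.85, 0.80)
)

# Wide format: one row per method, one column per measure
wide <- reshape(long, idvar = "method", timevar = "measure", direction = "wide")
wide
```

Long format is convenient for filtering and plotting by measure, while wide format makes it easier to compare methods side by side.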
Further exploration can be performed to view a specific measurement such as the F1 measure. F1 is the harmonic mean of precision and recall, so it is high only when both are high.
model_performance %>% filter(measure == "F1")
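For readers who want to see how the F1 measure is built from precision and recall, here is a small base R sketch. The function f1_score and the toy labels are hypothetical and exist only for this example; in the package the value comes from caret::confusionMatrix.

```r
# Hypothetical helper, for illustration only: F1 for one class of interest.
f1_score <- function(truth, predicted, positive) {
  tp <- sum(predicted == positive & truth == positive)  # true positives
  fp <- sum(predicted == positive & truth != positive)  # false positives
  fn <- sum(predicted != positive & truth == positive)  # false negatives

  precision <- tp / (tp + fp)
  recall    <- tp / (tp + fn)

  # Harmonic mean of precision and recall
  2 * precision * recall / (precision + recall)
}

# Toy labels, invented for the example
truth     <- c("active", "active", "active", "inactive", "inactive", "inactive")
predicted <- c("active", "active", "inactive", "active", "inactive", "inactive")

f1 <- f1_score(truth, predicted, positive = "active")
f1
```

Here precision and recall are both 2/3, so F1 is also 2/3.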
get_common_predictions(models = models_list, test_x = test_x, threshold = 0.70)
The get_common_predictions function takes a list of models and outputs a vector of the rows for which a prespecified ratio of models (the threshold argument) agreed on the desired outcome. Downloading the HSPSModelR package and using run_models() and then get_common_predictions() is the fastest way to get from raw data to quality results.
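The agreement rule can be sketched in a few lines of base R. This is only an illustration under stated assumptions: the helper common_rows, the matrix layout (one column of predicted labels per model), and the reading of "desired outcome" as a single positive class are all hypothetical, and the package's own function works on the trained models directly.

```r
# Hypothetical helper, for illustration only (not part of HSPSModelR).
# `pred_matrix` holds one column of predicted labels per model.
common_rows <- function(pred_matrix, positive, threshold = 0.70) {
  # Share of models that predicted the positive class for each row
  agreement <- rowMeans(pred_matrix == positive)
  # Row indices where that share meets or exceeds the threshold
  which(agreement >= threshold)
}

# Toy predictions from four models on four test rows, invented for the example
preds <- cbind(
  model_a = c("active", "active",   "inactive", "active"),
  model_b = c("active", "inactive", "inactive", "active"),
  model_c = c("active", "active",   "inactive", "inactive"),
  model_d = c("active", "active",   "active",   "active")
)

agreed <- common_rows(preds, positive = "active", threshold = 0.70)
agreed
```

With a threshold of 0.70, rows 1, 2, and 4 are returned because at least three of the four models predicted "active" for them.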
var_imp_*()
The var_imp_* functions output tibbles where rows rank the individual contribution of features to final predictions. var_imp_overall ranks the most influential features from among the entire list of models, whereas var_imp_raw gives the most important variables in sequence from each model listed in the function.
suppressMessages( var_imp_overall(models_list) )
suppressMessages( var_imp_raw(models_list) )
gg_var_imp()
gg_var_imp() is available to quickly visualize these variables, allowing you to specify the number of top-ranked variables to show with the top_num argument.
vars <- var_imp_overall(models_list)
gg_var_imp(vars, top_num = 20)
The functions above are especially useful if one already has trained models. If several models are stored in the same directory, the make_list function can pull them into a single object in your R environment.
make_list("path/to/file")