View source: R/feature_importance.R
get_feature_importance  R Documentation 
Calculates feature importance using a trained model and test data. Requires the future.apply package.
get_feature_importance(
  trained_model,
  test_data,
  outcome_colname,
  perf_metric_function,
  perf_metric_name,
  class_probs,
  method,
  seed = NA,
  corr_thresh = 1,
  groups = NULL,
  nperms = 100,
  corr_method = "spearman"
)
trained_model 
Trained model from 
test_data 
Held out test data: dataframe of outcome and features. 
outcome_colname 
Column name as a string of the outcome variable
(default 
perf_metric_function 
Function to calculate the performance metric to
be used for cross-validation and test performance. Some functions are
provided by caret (see 
perf_metric_name 
The column name from the output of the function
provided to perf_metric_function that is to be used as the performance metric.
Defaults: binary classification = 
class_probs 
Whether to use class probabilities (TRUE for categorical outcomes, FALSE for numeric outcomes). 
method 
ML method.
Options:

seed 
Random seed (default: NA).
corr_thresh 
For feature importance, group together features whose correlations are
above or equal to corr_thresh (default: 1).
groups 
Vector of feature names to group together during permutation.
Each element should be a string with feature names separated by a pipe
character (|).
nperms 
Number of permutations to perform (default: 100).
corr_method 
Correlation method. Options are the same as those supported
by stats::cor (default: "spearman").
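To make corr_thresh, corr_method, and the pipe-separated groups format more concrete, here is a minimal base-R sketch. It assumes that grouping by corr_thresh means joining any features connected by a correlation at or above the threshold (single-linkage clustering on 1 minus the correlation), and that all features in a group are shuffled with one shared row order. The helpers group_features() and permute_group() are hypothetical names for illustration, not functions provided by this package.

```r
# Sketch only, not mikropml's internal code. Assumptions: features are grouped
# when connected by correlations >= corr_thresh (single-linkage clustering on
# 1 - correlation), and a group is permuted with one shared row shuffle.
# group_features() and permute_group() are hypothetical helpers.

group_features <- function(features, corr_thresh = 1, corr_method = "spearman") {
  corr <- stats::cor(features, method = corr_method)
  tree <- stats::hclust(stats::as.dist(1 - corr), method = "single")
  # merges at height <= 1 - corr_thresh join features with correlation >= corr_thresh
  cluster <- stats::cutree(tree, h = 1 - corr_thresh)
  # one pipe-separated string per group, in the same format as `groups`
  unname(vapply(split(names(cluster), cluster), paste, character(1), collapse = "|"))
}

permute_group <- function(data, group) {
  feats <- strsplit(group, "|", fixed = TRUE)[[1]]
  idx <- sample(nrow(data))
  # the same shuffled row order is applied to every feature in the group,
  # so the relationships among the grouped features are kept intact
  data[feats] <- data[idx, feats, drop = FALSE]
  data
}

set.seed(2019)
features <- data.frame(x1 = rnorm(100))
features$x2 <- features$x1 + rnorm(100, sd = 0.01) # nearly a copy of x1
features$x3 <- rnorm(100)                           # independent of x1 and x2
groups <- group_features(features, corr_thresh = 0.9)
print(groups)
grouped <- groups[grepl("|", groups, fixed = TRUE)]
permuted <- permute_group(features, grouped)
```

With corr_thresh = 0.9, x1 and x2 end up in one group and x3 stays on its own. Shuffling the x1|x2 group moves both columns together, so their correlation with each other is unchanged while their rows no longer line up with the rest of the data.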
For permutation tests, the p-value is the number of permutation statistics that are greater than the test statistic, divided by the number of permutations. In our case, the permutation statistic is the model performance (e.g. AUROC) after randomizing the order of observations for one feature, and the test statistic is the actual performance on the test data. By default we perform 100 permutations per feature; increasing this increases the precision of the estimated null distribution, but also increases runtime. The p-value estimates the probability of obtaining a performance greater than the actual test performance if the null hypothesis is true, where the null hypothesis is that the feature is not important for model performance.
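The permutation p-value described above can be sketched in base R for a single feature at a time. This is an illustration under stated assumptions, not the package's implementation: it uses a linear model from lm() on simulated data, held-out R-squared as a higher-is-better performance metric in place of AUROC, and summarizes the permuted performances by their mean. The helpers r_squared() and permutation_pvalue() are hypothetical; the output columns only mirror the names documented for the return value.

```r
# Sketch of the permutation p-value for one feature (not mikropml's code).
# Assumptions: lm() on simulated data, held-out R-squared as the metric
# (higher is better), and the mean of the permuted performances as the
# summary. r_squared() and permutation_pvalue() are hypothetical helpers.
set.seed(2019)

n <- 200
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
dat$y <- 3 * dat$x1 + rnorm(n) # y depends on x1 only
train <- dat[1:100, ]
test <- dat[101:200, ]
model <- lm(y ~ x1 + x2, data = train)

# Held-out R-squared: higher values mean better performance
r_squared <- function(model, data) {
  pred <- predict(model, newdata = data)
  1 - sum((data$y - pred)^2) / sum((data$y - mean(data$y))^2)
}

permutation_pvalue <- function(model, test, feature, nperms = 100) {
  test_stat <- r_squared(model, test) # actual performance on the test data
  perm_stats <- vapply(seq_len(nperms), function(i) {
    permuted <- test
    permuted[[feature]] <- sample(permuted[[feature]]) # shuffle one feature
    r_squared(model, permuted)
  }, numeric(1))
  data.frame(
    feat = feature,
    perf_metric = mean(perm_stats),
    perf_metric_diff = test_stat - mean(perm_stats), # test minus permuted
    # share of permutation statistics greater than the test statistic
    pvalue = sum(perm_stats > test_stat) / nperms
  )
}

feat_imp <- rbind(
  permutation_pvalue(model, test, "x1"),
  permutation_pvalue(model, test, "x2")
)
print(feat_imp[order(-feat_imp$perf_metric_diff), ])
```

Because y depends only on x1 in the simulated data, shuffling x1 should push the permuted R-squared far below the actual one (small p-value, large perf_metric_diff), while shuffling x2 should leave it roughly unchanged.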
We strongly recommend providing multiple cores to speed up computation time. See our vignette on parallel processing for more details.
Data frame with performance metrics for when each feature (or group of correlated features; feat) is permuted (perf_metric), the differences between the actual test performance metric and the permuted performance metric (perf_metric_diff; test minus permuted performance), and the p-value (pvalue: the probability of obtaining the actual performance value under the null hypothesis). Features with a larger perf_metric_diff are more important. The performance metric name (perf_metric_name) and seed (seed) are also returned.
Begüm Topçuoğlu, topcuoglu.begum@gmail.com
Zena Lapp, zenalapp@umich.edu
Kelly Sovacool, sovacool@umich.edu
## Not run:
# If you called `run_ml()` with `feature_importance = FALSE` (the default),
# you can use `get_feature_importance()` later as long as you have the
# trained model and test data.
library(mikropml)
library(caret) # provides multiClassSummary()
results <- run_ml(otu_small, "glmnet", kfold = 2, cv_times = 2)
names(results$trained_model$trainingData)[1] <- "dx"
feat_imp <- get_feature_importance(results$trained_model,
  results$test_data,
  "dx",
  multiClassSummary,
  "AUC",
  class_probs = TRUE,
  method = "glmnet"
)
# We strongly recommend providing multiple cores to speed up computation time.
# Do this before calling `get_feature_importance()`.
doFuture::registerDoFuture()
future::plan(future::multicore, workers = 2)
# Optionally, you can group features together with a custom grouping
feat_imp <- get_feature_importance(results$trained_model,
  results$test_data,
  "dx",
  multiClassSummary,
  "AUC",
  class_probs = TRUE,
  method = "glmnet",
  groups = c(
    "Otu00007", "Otu00008", "Otu00009", "Otu00011", "Otu00012",
    "Otu00015", "Otu00016", "Otu00018", "Otu00019", "Otu00020", "Otu00022",
    "Otu00023", "Otu00025", "Otu00028", "Otu00029", "Otu00030", "Otu00035",
    "Otu00036", "Otu00037", "Otu00038", "Otu00039", "Otu00040", "Otu00047",
    "Otu00050", "Otu00052", "Otu00054", "Otu00055", "Otu00056", "Otu00060",
    "Otu00003|Otu00002|Otu00005|Otu00024|Otu00032|Otu00041|Otu00053",
    "Otu00014|Otu00021|Otu00017|Otu00031|Otu00057",
    "Otu00013|Otu00006", "Otu00026|Otu00001|Otu00034|Otu00048",
    "Otu00033|Otu00010",
    "Otu00042|Otu00004", "Otu00043|Otu00027|Otu00049", "Otu00051|Otu00045",
    "Otu00058|Otu00044", "Otu00059|Otu00046"
  )
)
# The function can show a progress bar if you have the `progressr` package installed.
## optionally, specify the progress bar format:
progressr::handlers(progressr::handler_progress(
format = ":message :bar :percent  elapsed: :elapsed  eta: :eta",
clear = FALSE,
show_after = 0
))
## tell progressr to always report progress
progressr::handlers(global = TRUE)
## run the function and watch the live progress updates
feat_imp <- get_feature_importance(results$trained_model,
  results$test_data,
  "dx",
  multiClassSummary,
  "AUC",
  class_probs = TRUE,
  method = "glmnet"
)
# You can specify any correlation method supported by `stats::cor`:
feat_imp <- get_feature_importance(results$trained_model,
  results$test_data,
  "dx",
  multiClassSummary,
  "AUC",
  class_probs = TRUE,
  method = "glmnet",
  corr_method = "pearson"
)
## End(Not run)