View source: R/patient-level_modeling.R
tof_train_model (R Documentation)
This function uses a training set/test set paradigm to tune and fit an elastic net model according to a variety of user-specified details. Tuning can be performed using a simple training/test set split, k-fold cross-validation, or bootstrap resampling, and several preprocessing options are available.
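To make the tuning paradigm concrete, here is a minimal base-R sketch of k-fold cross-validation over a penalty grid. It is not tidytof's implementation: to stay dependency-free it swaps in closed-form ridge regression (an elastic net with mixture alpha = 0) for glmnet, and the helpers `fit_ridge()` and `predict_ridge()`, the toy data, and the penalty grid are all hypothetical. The penalty here is also on a different scale from glmnet's lambda.

```r
# Sketch of k-fold tuning in base R. NOT tidytof's implementation:
# closed-form ridge regression (elastic net with alpha = 0) stands in for glmnet.
set.seed(1)

# Toy data shaped like the feature_tibble used in the examples
n <- 100
x <- cbind(cd45 = runif(n), pstat5 = runif(n), cd34 = runif(n))
y <- 3 * x[, "cd45"] + 4 * x[, "pstat5"] + rnorm(n)

# Hypothetical helper: ridge fit on predictors standardized with
# training-fold statistics; the intercept is not penalized
fit_ridge <- function(x, y, lambda) {
  x_center <- colMeans(x)
  x_sd <- apply(x, 2, sd)
  xs <- scale(x, center = x_center, scale = x_sd)
  beta <- solve(crossprod(xs) + lambda * diag(ncol(xs)), crossprod(xs, y - mean(y)))
  list(center = x_center, sd = x_sd, intercept = mean(y), beta = beta)
}

# Hypothetical helper: predict by reusing the training-fold standardization
predict_ridge <- function(fit, x) {
  xs <- scale(x, center = fit$center, scale = fit$sd)
  drop(fit$intercept + xs %*% fit$beta)
}

k <- 5
fold_id <- sample(rep(seq_len(k), length.out = n))  # random fold assignment
lambda_grid <- c(0.01, 0.1, 1, 10, 100)

# Validation-set RMSE for every (fold, lambda) pair: a k x 5 matrix
cv_rmse <- sapply(lambda_grid, function(lambda) {
  vapply(seq_len(k), function(fold) {
    in_train <- fold_id != fold
    fit <- fit_ridge(x[in_train, , drop = FALSE], y[in_train], lambda)
    pred <- predict_ridge(fit, x[!in_train, , drop = FALSE])
    sqrt(mean((y[!in_train] - pred)^2))
  }, numeric(1))
})

# Average across folds, pick the best penalty, then refit on all data
mean_rmse <- colMeans(cv_rmse)
best_lambda <- lambda_grid[which.min(mean_rmse)]
final_fit <- fit_ridge(x, y, best_lambda)
```

The selection rule follows the one described for the returned model: the hyperparameter value with the best validation metric, averaged across resamples, is refit on the full training data.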
tof_train_model(
split_data,
unsplit_data,
predictor_cols,
response_col = NULL,
time_col = NULL,
event_col = NULL,
model_type = c("linear", "two-class", "multiclass", "survival"),
hyperparameter_grid = tof_create_grid(),
standardize_predictors = TRUE,
remove_zv_predictors = FALSE,
impute_missing_predictors = FALSE,
optimization_metric = "tidytof_default",
best_model_type = c("best", "best with sparsity"),
num_cores = 1
)
split_data
An 'rsplit' or 'rset' object from the
unsplit_data
A tibble containing sample-level data to use for modeling without resampling. Although using a resampling method is advised, this argument provides an interface for fitting a model without cross-validation or bootstrap resampling. Ignored if 'split_data' is provided.
predictor_cols
Unquoted column names indicating which columns in the data contained in 'split_data' should be used as predictors in the elastic net model. Supports tidyselect helpers.
response_col
Unquoted column name indicating which column in the data contained in 'split_data' should be used as the outcome in a "two-class", "multiclass", or "linear" elastic net model. Must be a factor for "two-class" and "multiclass" models and must be numeric for "linear" models. Ignored if 'model_type' is "survival".
time_col
Unquoted column name indicating which column in the data contained in 'split_data' represents the time-to-event outcome in a "survival" elastic net model. Must be numeric. Ignored if 'model_type' is "two-class", "multiclass", or "linear".
event_col
Unquoted column name indicating which column in the data contained in 'split_data' records whether the event occurred in a "survival" elastic net model. Must be a binary column: all values should be either 0 or 1 (with 1 indicating the adverse event) or FALSE and TRUE (with TRUE indicating the adverse event). Ignored if 'model_type' is "two-class", "multiclass", or "linear".
model_type
A string indicating which kind of elastic net model to build. Use "linear" (linear regression) to predict a continuous response, "two-class" (logistic regression) to predict a categorical response with exactly 2 classes, "multiclass" (multinomial regression) to predict a categorical response with more than 2 classes, and "survival" (Cox regression) to predict a time-to-event outcome.
hyperparameter_grid
A hyperparameter grid indicating which values of the elastic net penalty (lambda) and the elastic net mixture (alpha) hyperparameters should be used during model tuning. Generate this grid using
standardize_predictors
A logical value indicating whether numeric predictor columns should be standardized (centered and scaled) before model fitting, as is standard practice for elastic net regularization. Defaults to TRUE.
remove_zv_predictors
A logical value indicating whether predictor columns with near-zero variance should be removed before model fitting using
impute_missing_predictors
A logical value indicating whether missing values in predictor columns should be imputed using k-nearest neighbors before model fitting (see
optimization_metric
A string indicating which metric should be optimized when selecting hyperparameters during model tuning. Valid values depend on 'model_type'.
best_model_type
Currently unused.
num_cores
An integer indicating how many cores should be used for parallel processing when fitting multiple models. Defaults to 1. The overhead of distributing models across cores can be high, so a significant speedup is unlikely unless many large models are being fit.
A 'tof_model', an S3 object containing the elastic net model with the best performance (assessed via cross-validation, bootstrapping, or simple splitting, depending on 'split_data') across all tested hyperparameter combinations. A 'tof_model' stores the following information:
The final elastic net ("glmnet") model, chosen by selecting the elastic net hyperparameters with the best 'optimization_metric' performance, averaged across the validation sets of all resamples used to train the model
The recipe used for data preprocessing
The optimal mixture hyperparameter (alpha) for the glmnet model
The optimal penalty hyperparameter (lambda) for the glmnet model
A string indicating which type of glmnet model was fit
A character vector of the names of the columns in the training data that were modeled as outcome variables
A tibble containing the (not preprocessed) data used to train the model
A tibble containing the validation set performance metrics (and model predictions) for each resample fold during model tuning
For survival models only, a tibble containing information about the relative-risk thresholds that can be used to split the training data into 2 risk groups (low- and high-risk) based on the final model's predictions. For each relative-risk threshold, the log-rank test p-value and an indicator of which threshold gives the most significant separation are provided.
For survival models only, a numeric value representing the relative-risk threshold that yields the most significant log-rank test when separating the training data into low- and high-risk groups.
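The relative-risk threshold search described above can be sketched with the survival package (a recommended package shipped with standard R installations). This is an illustration under stated assumptions, not tidytof's code: the relative-risk scores and survival data are toy values standing in for the final model's predictions, `logrank_p()` is a hypothetical helper, and using the 10th to 90th percentiles as candidate thresholds is an assumption, since this page does not say how candidates are chosen.

```r
library(survival)

set.seed(2)
n <- 100
# Toy relative-risk scores (stand-ins for model predictions) and survival data
relative_risk <- exp(rnorm(n))
time <- rexp(n, rate = 0.1 * relative_risk)
event <- rbinom(n, size = 1, prob = 0.7)  # 1 = adverse event observed

# Hypothetical helper: log-rank p-value for a low/high split at one threshold
logrank_p <- function(threshold) {
  group <- factor(ifelse(relative_risk > threshold, "high", "low"))
  test <- survdiff(Surv(time, event) ~ group)
  # The log-rank statistic is chi-squared with (number of groups - 1) df
  pchisq(test$chisq, df = length(test$n) - 1, lower.tail = FALSE)
}

# Candidate thresholds: assumed here to be the 10th to 90th percentiles
candidates <- quantile(relative_risk, probs = seq(0.1, 0.9, by = 0.1))
p_values <- vapply(candidates, logrank_p, numeric(1))

# One row per threshold, flagging the most significant separation
threshold_tbl <- data.frame(
  threshold = unname(candidates),
  log_rank_p = unname(p_values),
  is_best = p_values == min(p_values)
)
best_threshold <- threshold_tbl$threshold[which.min(threshold_tbl$log_rank_p)]
```

Restricting candidates to interior percentiles keeps both risk groups non-empty, so every log-rank test has two groups to compare.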
Other modeling functions: tof_assess_model(), tof_create_grid(), tof_predict(), tof_split_data()
# simulate sample-level features, a continuous outcome, class labels, and survival data
feature_tibble <-
dplyr::tibble(
sample = as.character(1:100),
cd45 = runif(n = 100),
pstat5 = runif(n = 100),
cd34 = runif(n = 100),
outcome = (3 * cd45) + (4 * pstat5) + rnorm(100),
class =
as.factor(
dplyr::if_else(outcome > median(outcome), "class1", "class2")
),
multiclass =
as.factor(
c(rep("class1", 30), rep("class2", 30), rep("class3", 40))
),
event = c(rep(0, times = 30), rep(1, times = 70)),
time_to_event = rnorm(n = 100, mean = 10, sd = 2)
)
# split the samples into a single training set and test set
split_data <- tof_split_data(feature_tibble, split_method = "simple")
# train a regression model
tof_train_model(
split_data = split_data,
predictor_cols = c(cd45, pstat5, cd34),
response_col = outcome,
model_type = "linear"
)
# train a logistic regression classifier
tof_train_model(
split_data = split_data,
predictor_cols = c(cd45, pstat5, cd34),
response_col = class,
model_type = "two-class"
)
# train a Cox regression survival model
tof_train_model(
split_data = split_data,
predictor_cols = c(cd45, pstat5, cd34),
time_col = time_to_event,
event_col = event,
model_type = "survival"
)