train_model: Train model using 'caret::train()'.

View source: R/train_model.R

Description

Train model using caret::train().

Usage

train_model(
  model_formula,
  train_data,
  method,
  cv,
  perf_metric_name,
  tune_grid,
  ntree
)

Arguments

model_formula

Model formula, typically created with stats::as.formula().
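Here is a minimal base-R sketch of building the formula used in the example. It assumes the outcome column is named `dx`; `toy_data` is a hypothetical data frame. The `.` stands for every column other than the outcome, and `stats::terms()` expands it when given data. `stats::reformulate()` builds the same formula without pasting strings.

```r
# Outcome column name (matches the example; any column name works)
outcome_colname <- "dx"

# Two equivalent ways to build "dx ~ ."
f_paste <- stats::as.formula(paste(outcome_colname, "~ ."))
f_reform <- stats::reformulate(".", response = outcome_colname)

# Hypothetical toy data: an outcome plus two feature columns
toy_data <- data.frame(
  dx = c("cancer", "normal", "cancer"),
  otu1 = c(0.1, 0.5, 0.2),
  otu2 = c(3, 1, 2)
)

# terms() expands "." into all non-outcome columns of the data
feature_terms <- attr(stats::terms(f_paste, data = toy_data), "term.labels")
print(feature_terms)
```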

train_data

Training data. Expected to be a subset of the full dataset.
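The training subset is carved out of the full dataset before train_model() is called. The sketch below shows a plain random split in base R. The names `full_data` and `training_frac` are hypothetical and not arguments of train_model(), and mikropml may partition data differently internally.

```r
set.seed(2019)

# Hypothetical full dataset: 20 samples, binary outcome, two features
full_data <- data.frame(
  dx = factor(rep(c("cancer", "normal"), each = 10)),
  otu1 = rnorm(20),
  otu2 = rnorm(20)
)

# Hypothetical fraction of rows kept for training
training_frac <- 0.8
n_train <- floor(training_frac * nrow(full_data))

# Randomly choose training rows; the remaining rows form the held-out test set
train_idx <- sample(seq_len(nrow(full_data)), size = n_train)
train_data <- full_data[train_idx, ]
test_data <- full_data[-train_idx, ]
```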

method

ML method. Options: c("glmnet", "rf", "rpart2", "svmRadial", "xgbTree").

  • glmnet: linear, logistic, or multiclass regression

  • rf: random forest

  • rpart2: decision tree

  • svmRadial: support vector machine

  • xgbTree: xgboost
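The method is passed as a plain string. The helper below is a hypothetical guard (check_method() is not part of mikropml) that checks such a string against the options listed above. It uses exact matching, so a partial name such as "svm" is rejected rather than silently expanded.

```r
# Supported values of `method`, with the descriptions listed above
method_descriptions <- c(
  glmnet = "linear, logistic, or multiclass regression",
  rf = "random forest",
  rpart2 = "decision tree",
  svmRadial = "support vector machine",
  xgbTree = "xgboost"
)

# Hypothetical helper: stop early on an unsupported method name.
# %in% performs exact matching, unlike match.arg(), which allows partial matches.
check_method <- function(method) {
  if (length(method) != 1 || !(method %in% names(method_descriptions))) {
    stop("method must be one of: ",
         paste(names(method_descriptions), collapse = ", "))
  }
  method
}

check_method("rf")
```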

cv

Caret cross-validation scheme, as returned by define_cv().
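The base-R sketch below shows what a cross-validation scheme encodes, using repeated k-fold resampling. The helper make_repeated_folds() and its `kfold`, `cv_times`, and `seed` defaults are hypothetical and are not define_cv()'s implementation. Each list element holds the training-row indices for one resample, which is the form accepted by the `index` argument of caret::trainControl().

```r
# Hypothetical helper: repeated k-fold CV as a list of training-row indices.
# Note: set.seed() here changes the global random number stream.
make_repeated_folds <- function(n, kfold = 5, cv_times = 2, seed = 1) {
  set.seed(seed)
  folds <- list()
  for (r in seq_len(cv_times)) {
    # Assign each row to one of kfold folds, shuffled anew for each repeat
    fold_id <- sample(rep(seq_len(kfold), length.out = n))
    for (k in seq_len(kfold)) {
      # Train on every row outside fold k; fold k is held out
      folds[[sprintf("Fold%d.Rep%d", k, r)]] <- which(fold_id != k)
    }
  }
  folds
}

folds <- make_repeated_folds(n = 20, kfold = 5, cv_times = 2)
```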

perf_metric_name

The performance metric to use, given as a column name from the output of the function passed as perf_metric_function (for example, to define_cv()). Defaults: binary classification = "ROC", multi-class classification = "logLoss", regression = "RMSE".
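The name must match a column of the summary function's output. With caret::multiClassSummary, as in the example, these columns include "AUC" and "prAUC". The defaults listed above can be sketched as a hypothetical helper; default_perf_metric() is not a mikropml function, and mikropml's own outcome-type detection may differ.

```r
# Hypothetical helper mirroring the documented defaults
default_perf_metric <- function(outcome) {
  if (is.numeric(outcome)) {
    return("RMSE")                 # regression
  }
  n_classes <- length(unique(outcome))
  if (n_classes == 2) {
    "ROC"                          # binary classification
  } else if (n_classes > 2) {
    "logLoss"                      # multi-class classification
  } else {
    stop("outcome must have at least two classes")
  }
}

default_perf_metric(c("cancer", "normal", "cancer"))
```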

tune_grid

Tuning grid from get_tuning_grid().
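caret::train() expects a tuning grid as a data frame with one column per tuning parameter, each named after its parameter. Assuming get_tuning_grid() returns that shape, base R's expand.grid() builds equivalent grids. The values below are hypothetical, not mikropml's defaults.

```r
# caret's "rf" method tunes only mtry; ntree is supplied separately
rf_grid <- expand.grid(mtry = c(2, 4, 8))

# caret's "glmnet" method tunes alpha and lambda; expand.grid() forms every combination
glmnet_grid <- expand.grid(
  alpha = c(0, 1),
  lambda = c(1e-4, 1e-3, 1e-2)
)

print(glmnet_grid)
```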

ntree

For random forest, the number of trees to grow (default: 1000). Note that caret does not allow this parameter to be tuned.

Value

Trained model from caret::train().

Author(s)

Zena Lapp, zenalapp@umich.edu

Examples

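## Not run: 
library(mikropml)
library(dplyr) # provides the %>% pipe

# Recover a training set from the bundled glmnet results; caret stores the
# outcome as ".outcome", so rename it back to "dx"
training_data <- otu_mini_bin_results_glmnet$trained_model$trainingData %>%
  dplyr::rename(dx = .outcome)
method <- "rf"

# Hyperparameters for the chosen method, then the CV scheme and tuning grid
hyperparameters <- get_hyperparams_list(otu_mini_bin, method)
cross_val <- define_cv(training_data,
  "dx",
  hyperparameters,
  perf_metric_function = caret::multiClassSummary,
  class_probs = TRUE,
  cv_times = 2
)
tune_grid <- get_tuning_grid(hyperparameters, method)

# Train a 1000-tree random forest, using "AUC" as the performance metric
rf_model <- train_model(
  stats::as.formula(paste("dx", "~ .")),
  training_data,
  method,
  cross_val,
  "AUC",
  tune_grid,
  1000
)

# Cross-validated performance for each mtry value
rf_model$results %>% dplyr::select(mtry, AUC, prAUC)

## End(Not run)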

SchlossLab/mikropml documentation built on Nov. 25, 2021, 1:13 p.m.