#' Train models with forester
#'
#' The `train()` function is the core function of this package.
#' The only obligatory arguments are `data` and a target: either `y`
#' (classification and regression tasks) or both `time` and `status`
#' (survival analysis). Setting and changing other arguments will affect
#' the model validation strategy, the tested model families, and so on.
#'
#' @param data A `data.frame` or `matrix` with the data that will be
#' used to build the models. By default the models are trained
#' on all columns in `data`.
#' @param y A character name of the column in `data`
#' that contains the target variable for classification and regression tasks.
#' By default set to NULL. If you use `y`, do not use `time` and `status`, which are
#' reserved for survival analysis.
#' @param time A character name of the column in `data`
#' that describes the time for the survival analysis task. By default set to NULL.
#' You have to use `time` and `status` together. If you use them, you cannot use `y`,
#' as it is reserved for classification and regression tasks.
#' @param status A character name of the column in `data`
#' that describes the status for the survival analysis task. By default set to NULL.
#' You have to use `time` and `status` together. If you use them, you cannot use `y`,
#' as it is reserved for classification and regression tasks.
#' @param type A character, one of `binary_clf`/`regression`/`survival`/`auto`/`multiclass` that
#' sets the type of the task. If `auto` (the default option) then
#' forester will figure out `type` based on the number of unique values
#' in the `y` variable, or the presence of `time`/`status` columns.
#' @param engine A vector of tree-based models that shall be tested.
#' Possible values are: `ranger`, `xgboost`, `decision_tree`, `lightgbm`, `catboost`.
#' All models from this vector will be trained and the best one will be returned.
#' This argument has no effect for survival analysis.
#' @param verbose A logical value. If set to TRUE, prints all information about
#' the training process; if FALSE, prints none.
#' @param check_correlation A logical value, if set to TRUE, provides information about
#' the correlations between numeric, and categorical pairs of variables as a part
#' of data check. Available only when verbose is set to TRUE. Default value is TRUE.
#' @param train_test_split A 3-value, numeric vector, describing the proportions of train,
#' test, validation subsets to original data set. Default values are: c(0.6, 0.2, 0.2).
#' @param split_seed An integer value describing the seed for the split into
#' train, test, and validation datasets. By default no seed is set and the split
#' is performed randomly. Default value is NULL.
#' @param bayes_iter An integer value describing number of optimization rounds
#' used by the Bayesian optimization. If set to 0 it turns off this method.
#' @param bayes_info A list with two values determining the verbosity of the Bayesian
#' Optimization process. The first value is `verbose` with 3 levels: 0 - no output;
#' 1 - describes what is happening, and if we can reach local optimum; 2 - additionally
#' provides information about recent and the best scores. The second value is
#' `plotProgress`, a logical value indicating if the progress of the Bayesian
#' Optimization should be plotted. WARNING: it will create a plot after each step, thus
#' it might be computationally expensive. Both arguments come from the
#' `ParBayesianOptimization` package. They only matter if the global `verbose` is set to TRUE.
#' Default values are: list(verbose = 0, plotProgress = FALSE).
#' @param random_evals An integer value describing number of trained models
#' with different parameters by random search. If set to 0 it turns off this method.
#' @param parallel A logical value indicating if the parallel method for random search
#' and Bayesian Optimization should be used. It currently works properly
#' for ranger and xgboost models only. By default it is set to TRUE.
#' @param metrics A vector of metric names. If set to `auto` (the default), the most important metrics are returned.
#' If set to `all`, all metrics are returned. If set to `NULL`, no metrics are returned, but the models are still sorted by `sort_by`.
#' @param sort_by A string with the name of the metric to sort by.
#' For `auto`, models are sorted by `mse` for regression and `f1` for classification.
#' @param metric_function A user-defined metric function.
#' It should have the form name(predictions, observed) and return a numeric value.
#' If the `metrics` parameter is set to a value other than `auto` or `all`, it has to include the value `metric_function`
#' for the custom metric to appear in the report. If `sort_by` is equal to `auto`, models are sorted by `metric_function`.
#' @param metric_function_name The name of the column with the values of the `metric_function` parameter.
#' By default the column is named `metric_function`.
#' @param metric_function_decreasing A logical value indicating whether the values of
#' `metric_function` should be sorted in decreasing order. `TRUE` by default.
#' @param best_model_number The number of best models returned as a separate element of the output.
#' All trained models are returned as a different element of the output.
#' @param custom_preprocessing An object returned by the `custom_preprocessing()`
#' function. By default it is set to NULL, which means that the basic preprocessing
#' inside `train()` will be executed. The basic preprocessing only performs the steps
#' necessary for `train()` to work properly.
#'
#' @return A list of all necessary objects for other functions. It contains:
#' \itemize{
#' \item \code{`data`} The original data.
#' \item \code{`y`} The original target column name.
#' \item \code{`time`} The original column name describing time for survival analysis task.
#' \item \code{`status`} The original column name describing status for survival analysis task.
#' \item \code{`type`} The type of the ML task. If the user did not specify a type in the
#' input parameters, the algorithm detects the type, uses it, and returns it.
#' It could be `binary_clf`, `regression`, `survival`, or `multiclass`.
#'
#' \item \code{`deleted_columns`} Column names from the original data frame that have been
#' removed in the data preprocessing process, e.g. due to too high correlation
#' with other columns.
#' \item \code{`preprocessed_data`} The data frame after the preprocessing process - that
#' means: removing columns with one value for all rows, binarizing the target
#' column, managing missing values and in advanced preprocessing: deleting
#' correlated values, deleting columns that are ID-like columns and performing
#' Boruta algorithm for selecting most important features.
#' \item \code{`bin_labels`} Labels of binarized target value - 1 or 2 for binary
#' classification and NULL for regression.
#' \item \code{`deleted_rows`} The indexes of rows deleted during the preprocessing,
#' if none were removed the value is NULL.
#' \item \code{`models_list`} The list of all trained models.
#' \item \code{`check_report`} Data check report held as a list of strings. It is used
#' by the `report()` function.
#' \item \code{`outliers`} The vector of possible outliers detected by the `check_data()`.
#'
#' \item \code{`best_models_on_valid`} The object containing the best performing models
#' on the validation dataset.
#' \item \code{`engine`} The list of names of all types of trained models. Possible
#' values: 'ranger', 'xgboost', 'decision_tree', 'lightgbm', 'catboost'.
#' \item \code{`raw_train`} Another form of the training dataset (useful for creating
#' VS plot and predicting on training dataset for catboost and lightgbm models).
#'
#' \item \code{`train_data`} The training dataset - the part of the source dataset after
#' preprocessing, balancing and splitting into the training, test and validation
#' datasets.
#' \item \code{`test_data`} The test dataset - the part of the source dataset after
#' preprocessing, balancing and splitting into the training, test and
#' validation datasets.
#' \item \code{`valid_data`} The validation dataset - the part of the source dataset after
#' preprocessing, balancing and splitting into the training, test and validation
#' datasets.
#'
#' \item \code{`train_inds`} The vector of integers describing the observation indexes from
#' the original data frame that went to the training set.
#' \item \code{`test_inds`} The vector of integers describing the observation indexes from
#' the original data frame that went to the testing set.
#' \item \code{`valid_inds`} The vector of integers describing the observation indexes from
#' the original data frame that went to the validation set.
#'
#' \item \code{`predictions_train`} Predictions for all trained models on a train dataset.
#' \item \code{`predictions_test`} Predictions for all trained models on a test dataset.
#' \item \code{`predictions_valid`} Predictions for all trained models on a validation dataset.
#'
#' \item \code{`predictions_train_labels`} Predictions for all trained models on a
#' train dataset with human readable labels (for classification tasks only).
#' \item \code{`predictions_test_labels`} Predictions for all trained models on a
#' test dataset with human readable labels (for classification tasks only).
#' \item \code{`predictions_valid_labels`} Predictions for all trained models on a
#' validation dataset with human readable labels (for classification tasks only).
#'
#' \item \code{`predictions_best_train`} Predictions for best trained models on a train dataset.
#' \item \code{`predictions_best_test`} Predictions for best trained models on a test dataset.
#' \item \code{`predictions_best_valid`} Predictions for best trained models on a validation dataset.
#'
#' \item \code{`predictions_best_train_labels`} Predictions for best trained models on a
#' train dataset with human readable labels (for classification tasks only).
#' \item \code{`predictions_best_test_labels`} Predictions for best trained models on a
#' test dataset with human readable labels (for classification tasks only).
#' \item \code{`predictions_best_valid_labels`} Predictions for best trained models on a
#' validation dataset with human readable labels (for classification tasks only).
#'
#' \item \code{`score_train`} The list of metrics for all trained models calculated on a train
#' dataset.
#' \item \code{`score_test`} The list of metrics for all trained models calculated on a test
#' dataset.
#' \item \code{`score_valid`} The list of metrics for all trained models calculated on a validation
#' dataset.
#'
#' \item \code{`test_observed`} Values of y column from the test dataset.
#' \item \code{`train_observed`} Values of y column from the training dataset.
#' \item \code{`valid_observed`} Values of y column from the validation dataset.
#'
#' \item \code{`test_observed_labels`} Values of y column from the test dataset as text labels
#' (for classification tasks only).
#' \item \code{`train_observed_labels`} Values of y column from the training dataset as text
#' labels (for classification tasks only).
#' \item \code{`valid_observed_labels`} Values of y column from the validation dataset as text
#' labels (for classification tasks only).
#' }
#' @export
#'
#' @examples
#' \dontrun{
#' # Regression task example.
#' library(forester)
#' data('lisbon')
#' train_output <- train(lisbon, 'Price')
#' train_output$score_valid
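#'
#' # A sketch of a more controlled regression run. The argument values below
#' # (engines, split proportions, seed, number of random search rounds) are
#' # illustrative assumptions, not recommended settings.
#' train_output_seeded <- train(
#'   data             = lisbon,
#'   y                = 'Price',
#'   engine           = c('ranger', 'xgboost'),
#'   train_test_split = c(0.7, 0.15, 0.15),
#'   split_seed       = 123,   # fixes the train/test/validation split
#'   bayes_iter       = 0,     # turns off Bayesian optimization
#'   random_evals     = 5,
#'   verbose          = FALSE
#' )
#' train_output_seeded$score_valid
#' train_output_seeded$best_models_on_valid$engine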
#'
#' # Survival analysis example
#' data('peakVO2')
#' train_output <- train(peakVO2, time = 'ttodead', status = 'died')
#' train_output$score_valid
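#'
#' # A sketch of a custom metric, assuming the documented form
#' # name(predictions, observed). The function name `mae` and the column
#' # name 'mae' are hypothetical choices made for this example.
#' mae <- function(predictions, observed) {
#'   mean(abs(predictions - observed))
#' }
#' # Quick check of the metric on toy vectors before passing it to train().
#' mae(c(1, 2, 3), c(1, 2, 5))
#' # metric_function_decreasing = FALSE is assumed here to put the models
#' # with the lowest error first.
#' train_output_mae <- train(lisbon, 'Price',
#'                           metric_function = mae,
#'                           metric_function_name = 'mae',
#'                           metric_function_decreasing = FALSE)
#' train_output_mae$score_valid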
#' }
train <- function(data,
y = NULL,
time = NULL,
status = NULL,
type = 'auto',
engine = c('ranger', 'xgboost', 'decision_tree', 'lightgbm'),
verbose = TRUE,
check_correlation = TRUE,
train_test_split = c(0.6, 0.2, 0.2),
split_seed = NULL,
bayes_iter = 10,
bayes_info = list(verbose = 0, plotProgress = FALSE),
random_evals = 10,
parallel = TRUE,
metrics = 'auto',
sort_by = 'auto',
metric_function = NULL,
metric_function_name = NULL,
metric_function_decreasing = TRUE,
best_model_number = 5,
custom_preprocessing = NULL) {
t0 <- as.numeric(Sys.time())
if (is.null(y)) {
if (is.null(time) || is.null(status)) {
verbose_cat(crayon::red('\u2716'), 'Lack of target variables. Please specify',
'either y (for classification or regression tasks), or time and',
'status (for survival analysis). \n\n', verbose = verbose)
# The message is split into arguments so that it contains no embedded newline.
stop('Lack of target variables. Please specify either y (for classification ',
'or regression tasks), or time and status (for survival analysis).')
}
} else {
if (!is.null(time) || !is.null(status)) {
verbose_cat(crayon::red('\u2716'), 'Provided too many targets. Please specify',
'either y (for classification or regression tasks), or time and',
'status (for survival analysis). \n\n', verbose = verbose)
# The message is split into arguments so that it contains no embedded newline.
stop('Provided too many targets. Please specify either y (for classification ',
'or regression tasks), or time and status (for survival analysis).')
}
}
tryCatch({
if ('catboost' %in% engine) {
find.package('catboost')
}
},
error = function(cond) {
verbose_cat(crayon::red('\u2716'), 'Package not found: catboost, to use it please',
'follow guides for installation from GitHub repository README.',
'Otherwise remove it from the engine. \n\n', verbose = verbose)
stop('Package not found: catboost, to use it please follow guides for installation ',
'from GitHub repository README. Otherwise remove it from the engine.')
})
if ('tbl' %in% class(data) || 'list' %in% class(data) || 'matrix' %in% class(data)) {
data <- as.data.frame(data)
verbose_cat(crayon::red('\u2716'), 'Provided dataset is a tibble, list or matrix and not a',
'data.frame. Casting the dataset to data.frame format. \n\n',
verbose = verbose)
}
if (type == 'auto') {
type <- guess_type(data, y)
if (type == 'regression') {
data[[y]] <- as.numeric(data[[y]])
}
verbose_cat(crayon::green('\u2714'), 'Type guessed as:', type, '\n', verbose = verbose)
} else if (!type %in% c('regression', 'binary_clf', 'survival', 'multiclass')) {
verbose_cat(crayon::red('\u2716'), 'Invalid value. Correct task types are: `binary_clf`, `regression`, `survival`, `multiclass`, and `auto` for automatic task identification \n', verbose = verbose)
stop('Invalid value. Correct task types are: `binary_clf`, `regression`, `survival`, `multiclass`, and `auto` for automatic task identification')
} else {
verbose_cat(crayon::green('\u2714'), 'Type provided as: ', type, '\n', verbose = verbose)
}
# Check that the target columns exist: time and status for survival, y otherwise.
if (type == 'survival') {
if (!status %in% colnames(data) || !time %in% colnames(data)) {
verbose_cat(crayon::red('\u2716'), 'Provided target column name for time or status parameters',
status, time, 'is not present in the dataset. \n', verbose = verbose)
stop('Provided target column name for time or status parameter is not present in the dataset.')
}
} else if (!y %in% colnames(data)) {
verbose_cat(crayon::red('\u2716'), 'Provided target column name for y parameter', y,
'is not present in the dataset. \n', verbose = verbose)
stop('Provided target column name for y parameter is not present in the dataset.')
}
if (parallel) {
cores <- parallel::detectCores()
cl <- parallel::makeCluster(cores - 1)
doParallel::registerDoParallel(cl)
verbose_cat(crayon::green('\u2714'), 'Parallel processing is turned on. Registered', cores - 1, 'cores. \n', verbose = verbose)
}
if (is.null(custom_preprocessing)) {
check_report <- check_data(data, y, time, status, type, verbose, check_correlation = check_correlation)
preprocessed_data <- preprocessing(data, y, time, status, type)
preprocessed_data$rm_rows <- NULL
verbose_cat(crayon::green('\u2714'), 'Data preprocessed with basic preprocessing. \n', verbose = verbose)
} else {
check_report <- check_data(custom_preprocessing$data, y, time, status, type, verbose)
preprocessed_data <- custom_preprocessing
verbose_cat(crayon::green('\u2714'), 'Imported preprocessed data from custom_preprocessing(). \n', verbose = verbose)
}
# Data splitting and recording observed variables in each dataset with distinction
# between survival analysis and other tasks.
if (!is.null(y)) {
target <- y
} else {
target <- status
}
split_data <- train_test_balance(preprocessed_data$data, target, balance = TRUE,
fractions = train_test_split, seed = split_seed)
train_observed <- split_data$train[[target]]
test_observed <- split_data$test[[target]]
valid_observed <- split_data$valid[[target]]
verbose_cat(crayon::green('\u2714'), 'Data split and balanced. \n', verbose = verbose)
train_data <- prepare_data(split_data$train, type, y, time, status, engine)
test_data <- prepare_data(split_data$test, type, y, time, status, engine,
predict = TRUE, split_data$train)
valid_data <- prepare_data(split_data$valid, type, y, time, status, engine,
predict = TRUE, split_data$train)
# For creating VS plot and predicting on train (catboost, lgbm).
raw_train <- prepare_data(split_data$train, type, y, time, status, engine,
predict = TRUE, split_data$train)
verbose_cat(crayon::green('\u2714'), 'Correct formats prepared. \n', verbose = verbose)
b_t0 <- as.numeric(Sys.time())
model_basic <- train_models(train_data, y, time, status, engine, type)
b_t1 <- as.numeric(Sys.time())
verbose_cat('\n', crayon::green('\u2714'), ' Models with default parameters successfully trained. \n', verbose = verbose, sep = '')
verbose_cat(' ', crayon::green('\u2714'), 'Default: It took', round(b_t1 - b_t0, 2), 'seconds. \n', verbose = verbose)
if (random_evals > 0) {
rs_t0 <- as.numeric(Sys.time())
verbose_cat('\n', crayon::green('\u2714'), ' Starting Random Search training process. \n', verbose = verbose, sep = '')
}
model_random <- random_search(train_data,
y = y,
time = time,
status = status,
engine = engine,
type = type,
max_evals = random_evals,
parallel = parallel,
verbose = verbose)
if (random_evals > 0) {
rs_t1 <- as.numeric(Sys.time())
verbose_cat('\n', crayon::green('\u2714'), ' Models optimized with Random Search successfully trained. \n', verbose = verbose, sep = '')
verbose_cat(' ', crayon::green('\u2714'), 'Random Search: It took', round(rs_t1 - rs_t0, 2), 'seconds. \n', verbose = verbose)
}
if (bayes_iter > 0) {
bo_t0 <- as.numeric(Sys.time())
verbose_cat('\n', crayon::green('\u2714'), ' Starting Bayesian Optimization training process. \n', verbose = verbose, sep = '')
}
model_bayes <- train_models_bayesopt(train_data,
y = y,
time = time,
status = status,
test_data = test_data,
engine = engine,
type = type,
parallel = parallel,
iters.n = bayes_iter,
bayes_info = bayes_info,
verbose = verbose)
if (bayes_iter > 0) {
bo_t1 <- as.numeric(Sys.time())
verbose_cat('\n', crayon::green('\u2714'), ' Models optimized with Bayesian Optimization successfully trained. \n', verbose = verbose, sep = '')
verbose_cat(' ', crayon::green('\u2714'), 'Bayesian Optimization: It took', round(bo_t1 - bo_t0, 2), 'seconds. \n', verbose = verbose)
}
models_all <- c(model_basic, model_random$models, model_bayes)
engine_all <- c(engine, model_random$engine, engine)
if (type != 'survival') {
tuning <- c(rep('basic', length(engine)),
rep('random_search', length(model_random$engine)),
rep('bayes_opt', length(engine)))
} else {
tuning <- c('basic',
rep('random_search', length(model_random$engine)),
'bayes_opt')
}
predict_train <- predict_models_all(models_all, raw_train, y, type = type)
predict_test <- predict_models_all(models_all, test_data, y, type = type)
predict_valid <- predict_models_all(models_all, valid_data, y, type = type)
verbose_cat('\n', crayon::green('\u2714'), ' Created the predictions for all models. \n', verbose = verbose, sep = '')
score_train <- score_models(models_all,
predict_train,
train_data$ranger_data[[y]],
train_data,
type,
time,
status,
metrics = metrics,
sort_by = sort_by,
metric_function = metric_function,
metric_function_name = metric_function_name,
metric_function_decreasing = metric_function_decreasing,
engine = engine_all,
tuning = tuning)
score_test <- score_models(models_all,
predict_test,
test_data$ranger_data[[y]],
test_data,
type,
time,
status,
metrics = metrics,
sort_by = sort_by,
metric_function = metric_function,
metric_function_name = metric_function_name,
metric_function_decreasing = metric_function_decreasing,
engine = engine_all,
tuning = tuning)
score_valid <- score_models(models_all,
predict_valid,
valid_data$ranger_data[[y]],
valid_data,
type,
time,
status,
metrics = metrics,
sort_by = sort_by,
metric_function = metric_function,
metric_function_name = metric_function_name,
metric_function_decreasing = metric_function_decreasing,
engine = engine_all,
tuning = tuning)
verbose_cat(crayon::green('\u2714'), 'Created the score boards for all models. \n', verbose = verbose)
choose_best_models <- function(models, engine, score, number) {
number <- min(number, length(models))
return(list(
models = models[score[1:number, 'name']],
engine = score[1:number, 'engine']))
}
best_models_on_valid <- choose_best_models(models_all, engine_all, score_valid, best_model_number)
predictions_best_train <- predict_models_all(best_models_on_valid$models, raw_train, y, type = type)
predictions_best_test <- predict_models_all(best_models_on_valid$models, test_data, y, type = type)
predictions_best_valid <- predict_models_all(best_models_on_valid$models, valid_data, y, type = type)
verbose_cat(crayon::green('\u2714'), 'Created the predictions for the best models. \n', verbose = verbose)
# Providing the original labels to the target.
if (type == 'binary_clf') {
test_observed <- as.numeric(test_observed) - 1
train_observed <- as.numeric(train_observed) - 1
valid_observed <- as.numeric(valid_observed) - 1
test_observed_labels <- test_observed
train_observed_labels <- train_observed
valid_observed_labels <- valid_observed
predict_train_labels <- predict_train
predict_test_labels <- predict_test
predict_valid_labels <- predict_valid
predictions_best_train_labels <- predictions_best_train
predictions_best_test_labels <- predictions_best_test
predictions_best_valid_labels <- predictions_best_valid
labels <- preprocessed_data$bin_labels
# Human-readable observed values with text labels.
# For the observed values.
for (i in 1:length(train_observed)) {
if (train_observed[i] < 0.5) {
train_observed_labels[i] <- labels[1]
} else {
train_observed_labels[i] <- labels[2]
}
}
for (i in 1:length(test_observed)) {
if (test_observed[i] < 0.5) {
test_observed_labels[i] <- labels[1]
} else {
test_observed_labels[i] <- labels[2]
}
}
for (i in 1:length(valid_observed)) {
if (valid_observed[i] < 0.5) {
valid_observed_labels[i] <- labels[1]
} else {
valid_observed_labels[i] <- labels[2]
}
}
# For the all models predictions.
for (j in 1:length(predict_train)){
for (i in 1:length(predict_train[[j]])) {
if (predict_train[[j]][i] < 0.5) {
predict_train_labels[[j]][i] <- labels[1]
} else {
predict_train_labels[[j]][i] <- labels[2]
}
}
for (i in 1:length(predict_test[[j]])) {
if (predict_test[[j]][i] < 0.5) {
predict_test_labels[[j]][i] <- labels[1]
} else {
predict_test_labels[[j]][i] <- labels[2]
}
}
for (i in 1:length(predict_valid[[j]])) {
if (predict_valid[[j]][i] < 0.5) {
predict_valid_labels[[j]][i] <- labels[1]
} else {
predict_valid_labels[[j]][i] <- labels[2]
}
}
}
# For the best models predictions.
for (j in 1:length(predictions_best_train)){
for (i in 1:length(predictions_best_train[[j]])) {
if (predictions_best_train[[j]][i] < 0.5) {
predictions_best_train_labels[[j]][i] <- labels[1]
} else {
predictions_best_train_labels[[j]][i] <- labels[2]
}
}
for (i in 1:length(predictions_best_test[[j]])) {
if (predictions_best_test[[j]][i] < 0.5) {
predictions_best_test_labels[[j]][i] <- labels[1]
} else {
predictions_best_test_labels[[j]][i] <- labels[2]
}
}
for (i in 1:length(predictions_best_valid[[j]])) {
if (predictions_best_valid[[j]][i] < 0.5) {
predictions_best_valid_labels[[j]][i] <- labels[1]
} else {
predictions_best_valid_labels[[j]][i] <- labels[2]
}
}
}
}
if (type == 'multiclass') {
test_observed <- as.numeric(test_observed)
train_observed <- as.numeric(train_observed)
valid_observed <- as.numeric(valid_observed)
test_observed_labels <- test_observed
train_observed_labels <- train_observed
valid_observed_labels <- valid_observed
predict_train_labels <- predict_train
predict_test_labels <- predict_test
predict_valid_labels <- predict_valid
predictions_best_train_labels <- predictions_best_train
predictions_best_test_labels <- predictions_best_test
predictions_best_valid_labels <- predictions_best_valid
labels <- preprocessed_data$bin_labels
# Human-readable observed values with text labels.
# For the observed values.
for (i in 1:length(train_observed)) {
train_observed_labels[i] <- labels[train_observed[i]]
}
for (i in 1:length(test_observed)) {
test_observed_labels[i] <- labels[test_observed[i]]
}
for (i in 1:length(valid_observed)) {
valid_observed_labels[i] <- labels[valid_observed[i]]
}
# For the all models predictions.
for (j in 1:length(predict_train)){
for (i in 1:length(predict_train[[j]])) {
predict_train_labels[[j]][i] <- labels[predict_train[[j]][i]]
}
for (i in 1:length(predict_test[[j]])) {
predict_test_labels[[j]][i] <- labels[predict_test[[j]][i]]
}
for (i in 1:length(predict_valid[[j]])) {
predict_valid_labels[[j]][i] <- labels[predict_valid[[j]][i]]
}
}
# For the best models predictions.
for (j in 1:length(predictions_best_train)){
for (i in 1:length(predictions_best_train[[j]])) {
predictions_best_train_labels[[j]][i] <- labels[predictions_best_train[[j]][i]]
}
for (i in 1:length(predictions_best_test[[j]])) {
predictions_best_test_labels[[j]][i] <- labels[predictions_best_test[[j]][i]]
}
for (i in 1:length(predictions_best_valid[[j]])) {
predictions_best_valid_labels[[j]][i] <- labels[predictions_best_valid[[j]][i]]
}
}
}
verbose_cat(crayon::green('\u2714'), 'Created human-readable labels for observables and predictions. \n', verbose = verbose)
t1 <- as.numeric(Sys.time())
verbose_cat(crayon::green('\u2714'), 'The train() run took', round(t1 - t0, 2), 'seconds. \n', verbose = verbose)
if (type %in% c('binary_clf', 'multiclass')) {
clf_models <- list(
data = data,
y = y,
time = time,
status = status,
type = type,
deleted_columns = preprocessed_data$rm_colnames,
preprocessed_data = preprocessed_data$data,
bin_labels = preprocessed_data$bin_labels,
deleted_rows = preprocessed_data$rm_rows,
models_list = models_all,
check_report = check_report$str,
outliers = check_report$outliers,
best_models_on_valid = best_models_on_valid,
engine = engine,
raw_train = raw_train,
train_data = train_data,
test_data = test_data,
valid_data = valid_data,
train_inds = split_data$train_inds,
test_inds = split_data$test_inds,
valid_inds = split_data$valid_inds,
predictions_train = predict_train,
predictions_test = predict_test,
predictions_valid = predict_valid,
predictions_train_labels = predict_train_labels,
predictions_test_labels = predict_test_labels,
predictions_valid_labels = predict_valid_labels,
predictions_best_train = predictions_best_train,
predictions_best_test = predictions_best_test,
predictions_best_valid = predictions_best_valid,
predictions_best_train_labels = predictions_best_train_labels,
predictions_best_test_labels = predictions_best_test_labels,
predictions_best_valid_labels = predictions_best_valid_labels,
score_test = score_test,
score_train = score_train,
score_valid = score_valid,
test_observed = test_observed,
train_observed = train_observed,
valid_observed = valid_observed,
test_observed_labels = test_observed_labels,
train_observed_labels = train_observed_labels,
valid_observed_labels = valid_observed_labels
)
class(clf_models) <- c(type, 'list')
return(clf_models)
} else {
other_models <- list(
type = type,
deleted_columns = preprocessed_data$rm_colnames,
preprocessed_data = preprocessed_data$data,
bin_labels = preprocessed_data$bin_labels,
deleted_rows = preprocessed_data$rm_rows,
models_list = models_all,
data = data,
y = y,
time = time,
status = status,
raw_train = raw_train,
check_report = check_report$str,
outliers = check_report$outliers,
best_models_on_valid = best_models_on_valid,
engine = engine,
train_data = train_data,
test_data = test_data,
valid_data = valid_data,
train_inds = split_data$train_inds,
test_inds = split_data$test_inds,
valid_inds = split_data$valid_inds,
predictions_train = predict_train,
predictions_test = predict_test,
predictions_valid = predict_valid,
predictions_best_train = predictions_best_train,
predictions_best_test = predictions_best_test,
predictions_best_valid = predictions_best_valid,
score_test = score_test,
score_train = score_train,
score_valid = score_valid,
test_observed = test_observed,
train_observed = train_observed,
valid_observed = valid_observed
)
if (type == 'regression') {
class(other_models) <- c('regression', 'list')
} else if (type == 'survival') {
class(other_models) <- c('survival', 'list')
}
return(other_models)
}
}