R/train.R

#' Train models with forester
#'
#' The `train()` function is the core function of this package.
#' The only obligatory arguments are `data` and a target: either `y`
#' (classification and regression) or both `time` and `status` (survival analysis).
#' Setting and changing other arguments will affect model
#' validation strategy, tested model families, and so on.
#'
#' @param data A `data.frame` or `matrix` with the data used to build the models.
#' By default the models are trained on all columns of `data`.
#' @param y A character name of the column in `data` that contains the target
#' variable for classification and regression tasks. By default set to NULL.
#' If you use `y`, do not use `time` and `status`, which are
#' reserved for survival analysis.
#' @param time A target variable, being a character name of variable in the `data`
#' that describes the `time` column for survival analysis task. By default set to NULL.
#' You have to use both `time`, and `status` together. If you use it, you cannot use `y`
#' as it is reserved for classification and regression tasks.
#' @param status A target variable, being a character name of variable in the `data`
#' that describes the `status` for survival analysis task. By default set to NULL.
#' You have to use both `time`, and `status` together. If you use it, you cannot use `y`
#' as it is reserved for classification and regression tasks.
#' @param type A character, one of `binary_clf`/`regression`/`survival`/`auto`/`multiclass` that
#' sets the type of the task. If `auto` (the default option) then
#' forester will figure out `type` based on the number of unique values
#' in the `y` variable, or the presence of `time`/`status` columns.
#' @param engine A vector of tree-based models that shall be tested.
#' Possible values are: `ranger`, `xgboost`, `decision_tree`, `lightgbm`, `catboost`.
#' All models from this vector will be trained and the best one will be returned.
#' By default all of them except `catboost` are used; the `catboost` package has
#' to be installed separately. This argument is ignored for survival analysis.
#' @param verbose A logical value, if set to TRUE, provides all information about
#' training process, if FALSE gives none.
#' @param check_correlation A logical value, if set to TRUE, provides information about
#' the correlations between numeric, and categorical pairs of variables as a part
#' of data check. Available only when verbose is set to TRUE. Default value is TRUE.
#' @param train_test_split A 3-value, numeric vector, describing the proportions of train,
#' test, validation subsets to original data set. Default values are: c(0.6, 0.2, 0.2).
#' @param split_seed An integer value describing the seed for the split into
#' train, test, and validation datasets. By default no seed is set and the split
#' is performed randomly. Default value is NULL.
#' @param bayes_iter An integer value describing number of optimization rounds
#' used by the Bayesian optimization. If set to 0 it turns off this method.
#' @param bayes_info A list with two values, determining the verbosity of the Bayesian
#' Optimization process. The first value is `verbose` with 3 levels: 0 - no output;
#' 1 - describes what is happening, and if we can reach local optimum; 2 - additionally
#' provides information about recent, and the best scores. The second value is
#' `plotProgress`, which is a logical value indicating if the progress of the Bayesian
#' Optimization should be plotted. WARNING: it will create a plot after each step, thus
#' it might be computationally expensive. Both arguments come from the
#' `ParBayesianOptimization` package. It only matters if you set global verbose to TRUE.
#' Default values are: list(verbose = 0, plotProgress = FALSE).
#' @param random_evals An integer value describing number of trained models
#' with different parameters by random search. If set to 0 it turns off this method.
#' @param parallel A logical value indicating if the parallel method for random search
#' and Bayesian Optimizations should be used. Unfortunately it works properly
#' for ranger and xgboost models only. By default it is set to TRUE.
#' @param metrics A vector of metric names. With the default `auto`, the most important metrics are returned.
#' With `all`, all metrics are returned. With `NULL`, no metrics are returned, but the models are still sorted by `sort_by`.
#' @param sort_by A string with the name of the metric to sort by.
#' With `auto`, models are sorted by `mse` for regression and by `f1` for classification.
#' @param metric_function A user-defined metric function.
#' It should have the form name(predictions, observed) and return a numeric value.
#' If `metrics` is set to a value other than `auto` or `all`, it must include the value `metric_function`
#' for the custom metric to appear in the report. If `sort_by` is equal to `auto`, models are sorted by `metric_function`.
#' @param metric_function_name The name of the column with values of `metric_function` parameter.
#' By default `metric_function_name` is `metric_function`.
#' @param metric_function_decreasing A logical value indicating how metric_function
#' should be sorted. `TRUE` by default.
#' @param best_model_number The number of best models to return in the `best_models_on_valid` element.
#' All trained models are returned separately in the `models_list` element.
#' @param custom_preprocessing An object returned by the `custom_preprocessing()`
#' function. By default it is set to NULL, which means that only the basic preprocessing
#' inside `train()` is executed. The basic preprocessing performs only the steps
#' necessary for `train()` to work properly.
#'
#' @return A list of all necessary objects for other functions. It contains:
#' \itemize{
#' \item \code{`data`} The original data.
#' \item \code{`y`} The original target column name.
#' \item \code{`time`} The original column name describing time for survival analysis task.
#' \item \code{`status`} The original column name describing status for survival analysis task.
#' \item \code{`type`} The type of the ML task. If the user did not specify a type
#' (`type = 'auto'`), the type guessed by the algorithm is used and returned.
#' It could be `binary_clf`, `regression`, `survival`, or `multiclass`.
#'
#' \item \code{`deleted_columns`} Column names from the original data frame that have been
#' removed in the data preprocessing process, e.g. due to too high correlation
#' with other columns.
#' \item \code{`preprocessed_data`} The data frame after the preprocessing process - that
#' means: removing columns with one value for all rows, binarizing the target
#' column, managing missing values and in advanced preprocessing: deleting
#' correlated values, deleting columns that are ID-like columns and performing
#' Boruta algorithm for selecting most important features.
#' \item \code{`bin_labels`} Labels of binarized target value - 1 or 2 for binary
#' classification and NULL for regression.
#' \item \code{`deleted_rows`} The indexes of rows deleted during the preprocessing,
#' if none were removed the value is NULL.
#' \item \code{`models_list`} The list of all trained models.
#' \item \code{`check_report`} Data check report held as a list of strings. It is used
#' by the `report()` function.
#' \item \code{`outliers`} The vector of possible outliers detected by the `check_data()`.
#'
#' \item \code{`best_models_on_valid`} The object containing the best performing models
#' on the validation dataset.
#' \item \code{`engine`} The list of names of all types of trained models. Possible
#' values: 'ranger', 'xgboost', 'decision_tree', 'lightgbm', 'catboost'.
#' \item \code{`raw_train`} Another form of the training dataset (useful for creating
#' VS plot and predicting on training dataset for catboost and lightgbm models).
#'
#' \item \code{`train_data`} The training dataset - the part of the source dataset after
#' preprocessing, balancing and splitting into the training, test and validation
#' datasets.
#' \item \code{`test_data`} The test dataset - the part of the source dataset after
#' preprocessing, balancing and splitting into the training, test and
#' validation datasets.
#' \item \code{`valid_data`} The validation dataset - the part of the source dataset after
#' preprocessing, balancing and splitting into the training, test and validation
#' datasets.
#'
#' \item \code{`train_inds`} The vector of integers describing the observation indexes from
#' the original data frame that went to the training set.
#' \item \code{`test_inds`} The vector of integers describing the observation indexes from
#' the original data frame that went to the testing set.
#' \item \code{`valid_inds`} The vector of integers describing the observation indexes from
#' the original data frame that went to the validation set.
#'
#' \item \code{`predictions_train`} Predictions for all trained models on a train dataset.
#' \item \code{`predictions_test`} Predictions for all trained models on a test dataset.
#' \item \code{`predictions_valid`} Predictions for all trained models on a validation dataset.
#'
#' \item \code{`predictions_train_labels`} Predictions for all trained models on a
#' train dataset with human readable labels (for classification tasks only).
#' \item \code{`predictions_test_labels`} Predictions for all trained models on a
#' test dataset with human readable labels (for classification tasks only).
#' \item \code{`predictions_valid_labels`} Predictions for all trained models on a
#' validation dataset with human readable labels (for classification tasks only).
#'
#' \item \code{`predictions_best_train`} Predictions for best trained models on a train dataset.
#' \item \code{`predictions_best_test`} Predictions for best trained models on a test dataset.
#' \item \code{`predictions_best_valid`} Predictions for best trained models on a validation dataset.
#'
#' \item \code{`predictions_best_train_labels`} Predictions for best trained models on a
#' train dataset with human readable labels (for classification tasks only).
#' \item \code{`predictions_best_test_labels`} Predictions for best trained models on a
#' test dataset with human readable labels (for classification tasks only).
#' \item \code{`predictions_best_valid_labels`} Predictions for best trained models on a
#' validation dataset with human readable labels (for classification tasks only).
#'
#' \item \code{`score_train`} The list of metrics for all trained models calculated on a train
#' dataset.
#' \item \code{`score_test`} The list of metrics for all trained models calculated on a test
#' dataset.
#' \item \code{`score_valid`} The list of metrics for all trained models calculated on a validation
#' dataset.
#'
#' \item \code{`test_observed`} Values of y column from the test dataset.
#' \item \code{`train_observed`} Values of y column from the training dataset.
#' \item \code{`valid_observed`} Values of y column from the validation dataset.
#'
#' \item \code{`test_observed_labels`} Values of y column from the test dataset as text labels
#' (for classification tasks only).
#' \item \code{`train_observed_labels`} Values of y column from the training dataset as text
#' labels (for classification task only).
#' \item \code{`valid_observed_labels`} Values of y column from the validation dataset as text
#' labels (for classification task only).
#' }
#' @export
#'
#' @examples
#' \dontrun{
#' # Regression task example.
#' library(forester)
#' data('lisbon')
#' train_output <- train(lisbon, 'Price')
#' train_output$score_valid
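#'
#' # A sketch of a faster, reproducible run using only the arguments documented
#' # above: two engines, a fixed split seed, and Bayesian optimization turned
#' # off. The particular values are illustrative, not recommendations.
#' train_output <- train(lisbon, 'Price',
#'                       engine           = c('ranger', 'xgboost'),
#'                       train_test_split = c(0.7, 0.15, 0.15),
#'                       split_seed       = 123,
#'                       bayes_iter       = 0,
#'                       random_evals     = 5,
#'                       verbose          = FALSE)
#' train_output$best_models_on_valid$engine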
#'
#' # Survival analysis example
#' data('peakVO2')
#' train_output <- train(peakVO2, time = 'ttodead', status = 'died')
#' train_output$score_valid
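#'
#' # Custom metric sketch. `mean_abs_err` is a hypothetical helper defined
#' # here for illustration only; it follows the name(predictions, observed)
#' # form described for `metric_function`. Setting `metric_function_decreasing`
#' # to FALSE assumes that lower values of this metric should rank first.
#' mean_abs_err <- function(predictions, observed) {
#'   mean(abs(predictions - observed))
#' }
#' train_output <- train(lisbon, 'Price',
#'                       metric_function            = mean_abs_err,
#'                       metric_function_name       = 'mean_abs_err',
#'                       metric_function_decreasing = FALSE)
#' train_output$score_valid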
#' }
train <- function(data,
                  y                 = NULL,
                  time              = NULL,
                  status            = NULL,
                  type              = 'auto',
                  engine            = c('ranger', 'xgboost', 'decision_tree', 'lightgbm'),
                  verbose           = TRUE,
                  check_correlation = TRUE,
                  train_test_split  = c(0.6, 0.2, 0.2),
                  split_seed        = NULL,
                  bayes_iter        = 10,
                  bayes_info        = list(verbose = 0, plotProgress = FALSE),
                  random_evals      = 10,
                  parallel          = TRUE,
                  metrics           = 'auto',
                  sort_by           = 'auto',
                  metric_function   = NULL,
                  metric_function_name       = NULL,
                  metric_function_decreasing = TRUE,
                  best_model_number          = 5,
                  custom_preprocessing       = NULL) {
  t0 <- as.numeric(Sys.time())
  if (is.null(y)) {
    if (is.null(time) | is.null(status)) {
      verbose_cat(crayon::red('\u2716'), 'Lack of target variables. Please specify',
                  'either y (for classification or regression tasks), or time and',
                  'status (for survival analysis). \n\n', verbose = verbose)
      stop('Lack of target variables. Please specify either y (for classification ',
           'or regression tasks), or time and status (for survival analysis).')
    }
  } else {
    if (!is.null(time) | !is.null(status)) {
      verbose_cat(crayon::red('\u2716'), 'Provided too many targets. Please specify',
                  'either y (for classification or regression tasks), or time and',
                  'status (for survival analysis). \n\n', verbose = verbose)
      stop('Provided too many targets. Please specify either y (for classification ',
           'or regression tasks), or time and status (for survival analysis).')
    }
  }

  tryCatch({
    if ('catboost' %in% engine) {
      find.package('catboost')
    }
  },
  error = function(cond) {
    verbose_cat(crayon::red('\u2716'), 'Package not found: catboost, to use it please ',
                'follow guides for installation from GitHub repository README.',
                'Otherwise remove it from the engine. \n\n', verbose = verbose)
    stop('Package not found: catboost, to use it please follow guides for installation ',
         'from GitHub repository README. Otherwise remove it from the engine.')
  })

  if ('tbl' %in% class(data) || 'list' %in% class(data) || 'matrix' %in% class(data)) {
    data <- as.data.frame(data)
    verbose_cat(crayon::red('\u2716'), 'Provided dataset is a tibble, list or matrix and not a',
                'data.frame. Casting the dataset to data.frame format. \n\n',
                verbose = verbose)
  }

  if (type == 'auto') {
    type <- guess_type(data, y)
    if (type == 'regression') {
      data[[y]] <- as.numeric(data[[y]])
    }
    verbose_cat(crayon::green('\u2714'), 'Type guessed as:', type, '\n', verbose = verbose)
  } else if (!type %in% c('regression', 'binary_clf', 'survival', 'multiclass')) {
    verbose_cat(crayon::red('\u2716'), 'Invalid value. Correct task types are: `binary_clf`, `regression`, `survival`, `multiclass`, and `auto` for automatic task identification \n', verbose = verbose)
    stop('Invalid value. Correct task types are: `binary_clf`, `regression`, `survival`, `multiclass`, and `auto` for automatic task identification')
  } else {
    verbose_cat(crayon::green('\u2714'), 'Type provided as: ', type, '\n', verbose = verbose)
  }

  # Check that the target columns exist; survival tasks use time and status instead of y.
  if (type == 'survival') {
    if (!status %in% colnames(data) || !time %in% colnames(data)) {
      verbose_cat(crayon::red('\u2716'), 'Provided target column name for time or status parameters',
                  status, time, 'is not present in the dataset. \n', verbose = verbose)
      stop('Provided target column name for time or status parameter is not present in the dataset.')
    }
  } else if (!y %in% colnames(data)) {
    verbose_cat(crayon::red('\u2716'), 'Provided target column name for y parameter', y,
                'is not present in the dataset. \n', verbose = verbose)
    stop('Provided target column name for y parameter is not present in the dataset.')
  }

  if (parallel) {
    cores <- parallel::detectCores()
    cl    <- parallel::makeCluster(cores - 1)
    doParallel::registerDoParallel(cl)
    verbose_cat(crayon::green('\u2714'), 'Parallel processing is turned on. Registered', cores - 1, 'cores. \n', verbose = verbose)
  }

  if (is.null(custom_preprocessing)) {
    check_report              <- check_data(data, y, time, status, type, verbose, check_correlation = check_correlation)
    preprocessed_data         <- preprocessing(data, y, time, status, type)
    preprocessed_data$rm_rows <- NULL
    verbose_cat(crayon::green('\u2714'), 'Data preprocessed with basic preprocessing. \n', verbose = verbose)
  } else {
    check_report      <- check_data(custom_preprocessing$data, y, time, status, type, verbose)
    preprocessed_data <- custom_preprocessing
    verbose_cat(crayon::green('\u2714'), 'Imported preprocessed data from custom_preprocessing(). \n', verbose = verbose)
  }

  # Data splitting and recording observed variables in each dataset with distinction
  # between survival analysis and other tasks.
  if (!is.null(y)) {
    target <- y
  } else {
    target <- status
  }
  split_data <- train_test_balance(preprocessed_data$data, target, balance = TRUE,
                                   fractions = train_test_split, seed = split_seed)

  train_observed <- split_data$train[[target]]
  test_observed  <- split_data$test[[target]]
  valid_observed <- split_data$valid[[target]]

  verbose_cat(crayon::green('\u2714'), 'Data split and balanced. \n', verbose = verbose)

  train_data <- prepare_data(split_data$train, type, y, time, status, engine)
  test_data  <- prepare_data(split_data$test, type,  y, time, status, engine,
                             predict = TRUE, split_data$train)
  valid_data <- prepare_data(split_data$valid, type, y, time, status, engine,
                             predict = TRUE, split_data$train)

  # For creating VS plot and predicting on train (catboost, lgbm).
  raw_train  <- prepare_data(split_data$train, type, y, time, status, engine,
                             predict = TRUE, split_data$train)

  verbose_cat(crayon::green('\u2714'), 'Correct formats prepared. \n', verbose = verbose)

  b_t0 <- as.numeric(Sys.time())
  model_basic    <- train_models(train_data, y, time, status, engine, type)
  b_t1 <- as.numeric(Sys.time())
  verbose_cat('\n', crayon::green('\u2714'), ' Models with default parameters successfully trained. \n', verbose = verbose, sep = '')
  verbose_cat('   ', crayon::green('\u2714'), 'Default: It took', round(b_t1 - b_t0, 2), 'seconds. \n', verbose = verbose)

  if (random_evals > 0) {
    rs_t0 <- as.numeric(Sys.time())
    verbose_cat('\n', crayon::green('\u2714'), ' Starting Random Search training process. \n', verbose = verbose, sep = '')
  }

  model_random   <- random_search(train_data,
                                  y         = y,
                                  time      = time,
                                  status    = status,
                                  engine    = engine,
                                  type      = type,
                                  max_evals = random_evals,
                                  parallel  = parallel,
                                  verbose   = verbose)
  if (random_evals > 0) {
    rs_t1 <- as.numeric(Sys.time())
    verbose_cat('\n', crayon::green('\u2714'), ' Models optimized with Random Search successfully trained. \n', verbose = verbose, sep = '')
    verbose_cat('   ', crayon::green('\u2714'), 'Random Search: It took', round(rs_t1 - rs_t0, 2), 'seconds. \n', verbose = verbose)
  }

  if (bayes_iter > 0) {
    bo_t0 <- as.numeric(Sys.time())
    verbose_cat('\n', crayon::green('\u2714'), ' Starting Bayesian Optimization training process. \n', verbose = verbose, sep = '')
  }

  model_bayes <- train_models_bayesopt(train_data,
                                       y          = y,
                                       time       = time,
                                       status     = status,
                                       test_data  = test_data,
                                       engine     = engine,
                                       type       = type,
                                       parallel   = parallel,
                                       iters.n    = bayes_iter,
                                       bayes_info = bayes_info,
                                       verbose    = verbose)
  if (bayes_iter > 0) {
    bo_t1 <- as.numeric(Sys.time())
    verbose_cat('\n', crayon::green('\u2714'), ' Models optimized with Bayesian Optimization successfully trained. \n', verbose = verbose, sep = '')
    verbose_cat('   ', crayon::green('\u2714'), 'Bayesian Optimization: It took', round(bo_t1 - bo_t0, 2), 'seconds. \n', verbose = verbose)
  }


  models_all <- c(model_basic, model_random$models, model_bayes)
  engine_all <- c(engine, model_random$engine, engine)

  if (type != 'survival') {
    tuning <- c(rep('basic', length(engine)),
                rep('random_search', length(model_random$engine)),
                rep('bayes_opt', length(engine)))
  } else {
    tuning <- c('basic',
                rep('random_search', length(model_random$engine)),
                'bayes_opt')
  }

  predict_train    <- predict_models_all(models_all, raw_train,  y, type = type)
  predict_test     <- predict_models_all(models_all, test_data,  y, type = type)
  predict_valid    <- predict_models_all(models_all, valid_data, y, type = type)

  verbose_cat('\n', crayon::green('\u2714'), ' Created the predictions for all models. \n', verbose = verbose, sep = '')

  score_train <- score_models(models_all,
                              predict_train,
                              train_data$ranger_data[[y]],
                              train_data,
                              type,
                              time,
                              status,
                              metrics = metrics,
                              sort_by = sort_by,
                              metric_function = metric_function,
                              metric_function_name = metric_function_name,
                              metric_function_decreasing = metric_function_decreasing,
                              engine = engine_all,
                              tuning = tuning)

  score_test  <- score_models(models_all,
                              predict_test,
                              test_data$ranger_data[[y]],
                              test_data,
                              type,
                              time,
                              status,
                              metrics = metrics,
                              sort_by = sort_by,
                              metric_function = metric_function,
                              metric_function_name = metric_function_name,
                              metric_function_decreasing = metric_function_decreasing,
                              engine = engine_all,
                              tuning = tuning)

  score_valid <- score_models(models_all,
                              predict_valid,
                              valid_data$ranger_data[[y]],
                              valid_data,
                              type,
                              time,
                              status,
                              metrics = metrics,
                              sort_by = sort_by,
                              metric_function = metric_function,
                              metric_function_name = metric_function_name,
                              metric_function_decreasing = metric_function_decreasing,
                              engine = engine_all,
                              tuning = tuning)

  verbose_cat(crayon::green('\u2714'), 'Created the score boards for all models. \n', verbose = verbose)

  choose_best_models <- function(models, engine, score, number) {
    number <- min(number, length(models))
    return(list(
      models = models[score[1:number, 'name']],
      engine = score[1:number, 'engine']))
  }

  best_models_on_valid   <- choose_best_models(models_all, engine_all, score_valid, best_model_number)
  predictions_best_train <- predict_models_all(best_models_on_valid$models, raw_train,  y, type = type)
  predictions_best_test  <- predict_models_all(best_models_on_valid$models, test_data,  y, type = type)
  predictions_best_valid <- predict_models_all(best_models_on_valid$models, valid_data, y, type = type)

  verbose_cat(crayon::green('\u2714'), 'Created the predictions for the best models. \n', verbose = verbose)

  # Providing the original labels to the target.
  if (type == 'binary_clf') {
    test_observed  <- as.numeric(test_observed)  - 1
    train_observed <- as.numeric(train_observed) - 1
    valid_observed <- as.numeric(valid_observed) - 1

    test_observed_labels           <- test_observed
    train_observed_labels          <- train_observed
    valid_observed_labels          <- valid_observed

    predict_train_labels           <- predict_train
    predict_test_labels            <- predict_test
    predict_valid_labels           <- predict_valid

    predictions_best_train_labels  <- predictions_best_train
    predictions_best_test_labels   <- predictions_best_test
    predictions_best_valid_labels  <- predictions_best_valid

    labels <- preprocessed_data$bin_labels
    # Human-readable observed values with text labels.
    # For the observed values.
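    # Illustration only (not executed here): assuming `labels` holds the two
    # class labels in order, each element-wise loop below is equivalent, for
    # non-empty input, to a vectorized mapping such as
    #   train_observed_labels <- ifelse(train_observed < 0.5, labels[1], labels[2])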
    for (i in 1:length(train_observed)) {
      if (train_observed[i] < 0.5) {
        train_observed_labels[i] <- labels[1]
      } else {
        train_observed_labels[i] <- labels[2]
      }
    }
    for (i in 1:length(test_observed)) {
      if (test_observed[i] < 0.5) {
        test_observed_labels[i] <- labels[1]
      } else {
        test_observed_labels[i] <- labels[2]
      }
    }
    for (i in 1:length(valid_observed)) {
      if (valid_observed[i] < 0.5) {
        valid_observed_labels[i] <- labels[1]
      } else {
        valid_observed_labels[i] <- labels[2]
      }
    }
    # For the all models predictions.
    for (j in 1:length(predict_train)){
      for (i in 1:length(predict_train[[j]])) {
        if (predict_train[[j]][i] < 0.5) {
          predict_train_labels[[j]][i] <- labels[1]
        } else {
          predict_train_labels[[j]][i] <- labels[2]
        }
      }
      for (i in 1:length(predict_test[[j]])) {
        if (predict_test[[j]][i] < 0.5) {
          predict_test_labels[[j]][i] <- labels[1]
        } else {
          predict_test_labels[[j]][i] <- labels[2]
        }
      }
      for (i in 1:length(predict_valid[[j]])) {
        if (predict_valid[[j]][i] < 0.5) {
          predict_valid_labels[[j]][i] <- labels[1]
        } else {
          predict_valid_labels[[j]][i] <- labels[2]
        }
      }
    }
    # For the best models predictions.
    for (j in 1:length(predictions_best_train)){
      for (i in 1:length(predictions_best_train[[j]])) {
        if (predictions_best_train[[j]][i] < 0.5) {
          predictions_best_train_labels[[j]][i] <- labels[1]
        } else {
          predictions_best_train_labels[[j]][i] <- labels[2]
        }
      }
      for (i in 1:length(predictions_best_test[[j]])) {
        if (predictions_best_test[[j]][i] < 0.5) {
          predictions_best_test_labels[[j]][i] <- labels[1]
        } else {
          predictions_best_test_labels[[j]][i] <- labels[2]
        }
      }
      for (i in 1:length(predictions_best_valid[[j]])) {
        if (predictions_best_valid[[j]][i] < 0.5) {
          predictions_best_valid_labels[[j]][i] <- labels[1]
        } else {
          predictions_best_valid_labels[[j]][i] <- labels[2]
        }
      }
    }
  }

  if (type == 'multiclass') {
    test_observed  <- as.numeric(test_observed)
    train_observed <- as.numeric(train_observed)
    valid_observed <- as.numeric(valid_observed)

    test_observed_labels           <- test_observed
    train_observed_labels          <- train_observed
    valid_observed_labels          <- valid_observed

    predict_train_labels           <- predict_train
    predict_test_labels            <- predict_test
    predict_valid_labels           <- predict_valid

    predictions_best_train_labels  <- predictions_best_train
    predictions_best_test_labels   <- predictions_best_test
    predictions_best_valid_labels  <- predictions_best_valid

    labels <- preprocessed_data$bin_labels
    # Human-readable observed values with text labels.
    # For the observed values.
    for (i in 1:length(train_observed)) {
      train_observed_labels[i] <- labels[train_observed[i]]
    }
    for (i in 1:length(test_observed)) {
      test_observed_labels[i]  <- labels[test_observed[i]]
    }
    for (i in 1:length(valid_observed)) {
      valid_observed_labels[i] <- labels[valid_observed[i]]
    }
    # For the all models predictions.
    for (j in 1:length(predict_train)){
      for (i in 1:length(predict_train[[j]])) {
        predict_train_labels[[j]][i] <- labels[predict_train[[j]][i]]
      }
      for (i in 1:length(predict_test[[j]])) {
        predict_test_labels[[j]][i]  <- labels[predict_test[[j]][i]]
      }
      for (i in 1:length(predict_valid[[j]])) {
        predict_valid_labels[[j]][i] <- labels[predict_valid[[j]][i]]
      }
    }
    # For the best models predictions.
    for (j in 1:length(predictions_best_train)){
      for (i in 1:length(predictions_best_train[[j]])) {
        predictions_best_train_labels[[j]][i] <- labels[predictions_best_train[[j]][i]]
      }
      for (i in 1:length(predictions_best_test[[j]])) {
        predictions_best_test_labels[[j]][i]  <- labels[predictions_best_test[[j]][i]]
      }
      for (i in 1:length(predictions_best_valid[[j]])) {
        predictions_best_valid_labels[[j]][i] <- labels[predictions_best_valid[[j]][i]]
      }
    }
  }

  verbose_cat(crayon::green('\u2714'), 'Created human-readable labels for observables and predictions. \n', verbose = verbose)
  t1 <- as.numeric(Sys.time())
  verbose_cat(crayon::green('\u2714'), 'The train() run took', round(t1 - t0, 2), 'seconds. \n', verbose = verbose)
  if (type %in% c('binary_clf', 'multiclass')) {
    clf_models <- list(
        data                    = data,
        y                       = y,
        time                    = time,
        status                  = status,
        type                    = type,

        deleted_columns         = preprocessed_data$rm_colnames,
        preprocessed_data       = preprocessed_data$data,
        bin_labels              = preprocessed_data$bin_labels,
        deleted_rows            = preprocessed_data$rm_rows,

        models_list             = models_all,
        check_report            = check_report$str,
        outliers                = check_report$outliers,

        best_models_on_valid    = best_models_on_valid,
        engine                  = engine,
        raw_train               = raw_train,

        train_data              = train_data,
        test_data               = test_data,
        valid_data              = valid_data,

        train_inds              = split_data$train_inds,
        test_inds               = split_data$test_inds,
        valid_inds              = split_data$valid_inds,

        predictions_train              = predict_train,
        predictions_test               = predict_test,
        predictions_valid              = predict_valid,

        predictions_train_labels       = predict_train_labels,
        predictions_test_labels        = predict_test_labels,
        predictions_valid_labels       = predict_valid_labels,

        predictions_best_train         = predictions_best_train,
        predictions_best_test          = predictions_best_test,
        predictions_best_valid         = predictions_best_valid,

        predictions_best_train_labels  = predictions_best_train_labels,
        predictions_best_test_labels   = predictions_best_test_labels,
        predictions_best_valid_labels  = predictions_best_valid_labels,

        score_test              = score_test,
        score_train             = score_train,
        score_valid             = score_valid,

        test_observed           = test_observed,
        train_observed          = train_observed,
        valid_observed          = valid_observed,

        test_observed_labels    = test_observed_labels,
        train_observed_labels   = train_observed_labels,
        valid_observed_labels   = valid_observed_labels
      )
    class(clf_models) <- c(type, 'list')
    return(clf_models)
  } else {
    other_models <- list(
      type                    = type,
      deleted_columns         = preprocessed_data$rm_colnames,
      preprocessed_data       = preprocessed_data$data,
      bin_labels              = preprocessed_data$bin_labels,
      deleted_rows            = preprocessed_data$rm_rows,

      models_list             = models_all,
      data                    = data,
      y                       = y,
      time                    = time,
      status                  = status,

      raw_train               = raw_train,
      check_report            = check_report$str,
      outliers                = check_report$outliers,

      best_models_on_valid    = best_models_on_valid,
      engine                  = engine,

      train_data              = train_data,
      test_data               = test_data,
      valid_data              = valid_data,

      train_inds              = split_data$train_inds,
      test_inds               = split_data$test_inds,
      valid_inds              = split_data$valid_inds,

      predictions_train       = predict_train,
      predictions_test        = predict_test,
      predictions_valid       = predict_valid,

      predictions_best_train  = predictions_best_train,
      predictions_best_test   = predictions_best_test,
      predictions_best_valid  = predictions_best_valid,

      score_test              = score_test,
      score_train             = score_train,
      score_valid             = score_valid,

      test_observed           = test_observed,
      train_observed          = train_observed,
      valid_observed          = valid_observed
    )
    if (type == 'regression') {
      class(other_models) <- c('regression', 'list')
    } else if (type == 'survival') {
      class(other_models) <- c('survival', 'list')
    }
    return(other_models)
  }
}
ModelOriented/forester documentation built on June 6, 2024, 7:29 a.m.