#' Predict models depending on the engine
#'
#' As all Machine Learning models have different predicting pipelines, we have
#' to provide a helpful tool for normalization of making predictions.
#'
#' @param models A list of models trained by `train_models()` function. The
#' i-th model is used with the i-th entry of `engine`.
#' @param data A test data for models created by `prepare_data()` function.
#' Engine-specific inputs are read from its `ranger_data`, `xgboost_data`,
#' `decision_tree_data`, `lightgbm_data` and `catboost_data` elements.
#' @param y A string that indicates a target column name. Currently not used
#' by this function.
#' @param engine A vector of tree-based models that shall be created. Possible
#' values are: `ranger`, `xgboost`, `decision_tree`, `lightgbm`, `catboost`.
#' @param type A string that determines if Machine Learning task is the
#' `binary_clf` or `regression` task. Any other value raises an error.
#' @param probability A logical value that determines whether the output for
#' classification task should be class labels or probabilities. When `FALSE`,
#' probabilities are mapped to labels `1` (below 0.5) and `2` (0.5 or above).
#'
#' @return A named list with elements `ranger_preds`, `xgboost_preds`,
#' `decision_tree_preds`, `lightgbm_preds` and `catboost_preds`. Elements for
#' engines not listed in `engine` are `NULL`.
#' @export
#' @importFrom stats as.formula predict
predict_models <- function(models, data, y, engine, type, probability = FALSE) {
  if (length(type) != 1 || !type %in% c('regression', 'binary_clf')) {
    stop("`type` must be either 'regression' or 'binary_clf'.", call. = FALSE)
  }
  is_clf <- type == 'binary_clf'

  ranger_preds <- NULL
  xgboost_preds <- NULL
  decision_tree_preds <- NULL
  lightgbm_preds <- NULL
  catboost_preds <- NULL

  # seq_along() performs zero iterations for an empty `engine`; 1:length()
  # would iterate over c(1, 0) and fail on `NA == 'ranger'`.
  for (i in seq_along(engine)) {
    model <- models[[i]]
    if (engine[i] == 'ranger') {
      ranger_out <- predict(model, data$ranger_data)
      # For classification, column 2 is taken as the probability of the second
      # class (assumes the forest was grown with `probability = TRUE` - verify
      # against train_models()).
      ranger_preds <- if (is_clf) {
        ranger::predictions(ranger_out)[, 2]
      } else {
        ranger_out$predictions
      }
    } else if (engine[i] == 'xgboost') {
      xgboost_preds <- predict(model, data$xgboost_data)
    } else if (engine[i] == 'decision_tree') {
      decision_tree_preds <- if (is_clf) {
        unname(predict(model, data$decision_tree_data, type = 'prob')[, 2])
      } else {
        unname(predict(model, data$decision_tree_data))
      }
    } else if (engine[i] == 'lightgbm') {
      lightgbm_preds <- predict(model, data$lightgbm_data)
    } else if (engine[i] == 'catboost') {
      # 'RawFormulaVal' is the raw model score, which for binary classification
      # is not a probability; request 'Probability' instead so that the output
      # is comparable with the other engines and with the 0.5 threshold below
      # (assumes a probabilistic loss such as Logloss was used in training).
      catboost_preds <- catboost::catboost.predict(
        model,
        data$catboost_data,
        prediction_type = if (is_clf) 'Probability' else 'RawFormulaVal'
      )
    }
  }

  if (is_clf && !probability) {
    threshold <- 0.5
    # Labels are 1 (below threshold) and 2 (at or above). Engines that were
    # not requested stay NULL instead of turning into numeric(0).
    to_label <- function(p) if (is.null(p)) NULL else (p >= threshold) + 1
    ranger_preds <- to_label(ranger_preds)
    xgboost_preds <- to_label(xgboost_preds)
    decision_tree_preds <- to_label(decision_tree_preds)
    lightgbm_preds <- to_label(lightgbm_preds)
    catboost_preds <- to_label(catboost_preds)
  }

  list(
    ranger_preds = ranger_preds,
    xgboost_preds = xgboost_preds,
    decision_tree_preds = decision_tree_preds,
    lightgbm_preds = lightgbm_preds,
    catboost_preds = catboost_preds
  )
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.