# Global knitr chunk options for this vignette.
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  rmarkdown.html_vignette.check_title = FALSE
)
# Attach every package used below. Each call sits on its own line so that a
# trailing comment cannot swallow the library() calls that follow it.
library(autotextclassifier) # Auto text classifier
library(parallel) # Parallel processing
library(doParallel) # Parallel processing
library(here) # Creating reproducible file paths
library(patchwork) # Putting ggplots together
library(recipes) # Preprocessing
library(zeallot) # Multiple assignments
library(yardstick) # Metrics
# Load the sample data shipped under inst/extdata; here() builds the file
# path. The following steps expect this to provide `sample_data`.
load(
  file = here("inst/extdata/sample_data.rda")
)
Make sure that the outcome variable is a factor.
# Give the four columns the names used throughout the rest of the vignette.
names(sample_data) <- c("category", "org_name", "ein", "text")

# The outcome variable must be a factor for classification.
sample_data$category <- as.factor(sample_data$category)
The `rec` object provides the following information. The function also checks whether the `text` column has missing values or includes extremely short documents (fewer than five words).
There are two basic options for text preprocessing.
Without word embedding
Tokenization for text [trained]
Term frequency-inverse document frequency with text [trained]
With word embedding
Tokenization for text [trained]
# Without word embedding
rec <- apply_basic_recipe(sample_data, category ~ text, text)

# With word embedding
rec_alt <- apply_basic_recipe(
  sample_data,
  category ~ text,
  text,
  add_embedding = TRUE
)
# Fix the RNG seed so that the split is reproducible.
set.seed(1234)

# Unpack the four values returned by split_using_srs() (training/test
# predictors and outcomes) with zeallot's multiple-assignment operator.
c(train_x_class, test_x_class, train_y_class, test_y_class) %<-%
  split_using_srs(
    input_data = sample_data,
    category = category,
    rec = rec
  )
# Unpack the three values returned by create_tunes() into one specification
# object per model.
c(lasso_spec, rand_spec, xg_spec) %<-%
  create_tunes()
# Build one search grid per model specification from the training predictors.
c(lasso_grid, rand_grid, xg_grid) %<-%
  create_search_spaces(
    train_x_class,
    category,
    lasso_spec,
    rand_spec,
    xg_spec
  )
# Build one workflow per model specification.
c(lasso_wf, rand_wf, xg_wf) %<-%
  create_workflows(
    lasso_spec,
    rand_spec,
    xg_spec,
    category
  )
# Fix the RNG seed so that the cross-validation folds are reproducible.
set.seed(1234)

# Create the cross-validation folds from the training data.
class_folds <- create_cv_folds(train_x_class, train_y_class, category)
Consider using parallel processing to speed up the computation.
# Leave one physical core free for the rest of the system, but always request
# at least one worker: detectCores() can return NA, and on a single-core
# machine `all_cores - 1` would be 0, which makeCluster() rejects.
all_cores <- parallel::detectCores(logical = FALSE)
n_workers <- max(1L, all_cores[1] - 1L, na.rm = TRUE)

cl <- makeCluster(n_workers)
registerDoParallel(cl)

# Call parallel::stopCluster(cl) once tuning and fitting are finished to
# release the worker processes.
The default metric for model evaluation is accuracy. The other options are balanced accuracy (`bal_accuracy`), F-score (`f_meas`), and area under the ROC curve (`roc_auc`).
# Pass each workflow, the CV folds and each search grid to find_best_model();
# `metric_choice` names the metric used for model evaluation.
c(best_lasso, best_rand, best_xg) %<-%
  find_best_model(
    lasso_wf,
    rand_wf,
    xg_wf,
    class_folds,
    lasso_grid,
    rand_grid,
    xg_grid,
    metric_choice = "accuracy"
  )
This step takes the longest to run. We recommend saving its output to disk, for example as an RData file.
# Pass each workflow together with its best_* result and the training data to
# fit_best_model(), unpacking one fitted model per algorithm.
c(lasso_fit, rand_fit, xg_fit) %<-%
  fit_best_model(
    lasso_wf, best_lasso,
    rand_wf, best_rand,
    xg_wf, best_xg,
    train_x_class,
    train_y_class,
    category
  )
# Based on the class-based metrics
# (`+` comes from patchwork and places the three plots side by side.)
viz_class_fit(lasso_fit, "Lasso", test_x_class, test_y_class, "class") +
  viz_class_fit(
    rand_fit, "Random forest", test_x_class, test_y_class, "class"
  ) +
  viz_class_fit(xg_fit, "XGBoost", test_x_class, test_y_class, "class")

# Based on the probability-based metrics
viz_class_fit(
  lasso_fit, "Lasso", test_x_class, test_y_class, "probability"
) +
  viz_class_fit(
    rand_fit, "Random forest", test_x_class, test_y_class, "probability"
  ) +
  viz_class_fit(
    xg_fit, "XGBoost", test_x_class, test_y_class, "probability"
  )
This is just a hypothetical example. Let's denote the new data as `new_sample_data`. (It has the same data structure as `sample_data`.) The tidymodels package provides the following functions.
# Preprocess the new data using the recipe created earlier
new_preprocessed <- bake(rec, new_sample_data)

# Make predictions using the best lasso model
new_sample_data$category <- predict(lasso_fit, new_preprocessed)$.pred_class

# If you want to extract probabilities instead, use one of the following,
# depending on whether you want the predicted probability of FALSE or TRUE.
# new_sample_data$category <-
#   predict(lasso_fit, new_preprocessed, type = "prob")$.pred_FALSE
# new_sample_data$category <-
#   predict(lasso_fit, new_preprocessed, type = "prob")$.pred_TRUE
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.