```r
library(tidymodels)
library(tidyverse)
library(bestglm)
library(textclassificationexamples)

knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```
The textclassificationexamples package contains the data set `emails`, which consists of three variables:

- `ids`: an integer indicating the row ID
- `subjectline`: a character variable containing the email subject line
- `type`: a variable indicating whether the email is spam; it takes the value `"spam"` for spam emails and one of two non-spam levels otherwise

These data have already been split into training (`emails_train`) and test (`emails_test`) sets, with the training set containing about 80% of the observations and the test set the remaining 20%.
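The split ships with the package, but for reference, an equivalent 80/20 split could be produced from the full `emails` data set with rsample (a sketch, not the package's actual code):

```r
# hypothetical: recreate an 80/20 train/test split from the full data set
set.seed(123)
email_split <- rsample::initial_split(emails, prop = 0.8)
emails_train <- rsample::training(email_split)
emails_test <- rsample::testing(email_split)
```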
Throughout this vignette, we'll use these data sets along with the helper functions from the textclassificationexamples package to classify emails as spam or not.
First, we'll load the data so it's available to work with.
```r
# LOAD DATA --------------------------------------------------------------------
data(emails_train)
data(emails_test)
```
To classify the emails, we need to come up with features, or explanatory variables, that might indicate whether a subject line belongs to a spam email. Several functions in the textclassificationexamples package facilitate the creation of such predetermined features. Each takes a character string and returns a logical:

- `all_caps()`: TRUE if the string is in all caps, FALSE otherwise
- `has_dollar_sign()`: TRUE if the string contains a dollar sign, FALSE otherwise
- `has_dear()`: TRUE if the string contains the word "dear", FALSE otherwise
- `has_mister()`: TRUE if the string contains the word "Mister" or "Mr", FALSE otherwise
- `multiple_punctuation()`: TRUE if the string contains multiple punctuation marks, FALSE otherwise
- `has_religious()`: TRUE if the string contains religious wording, FALSE otherwise

Using `dplyr::mutate()`, we can apply each of these functions to the subject lines in our data sets to create our features.
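For intuition, we'd expect the helpers to behave roughly as follows on individual strings (hypothetical inputs and outputs; the exact matching rules live in the package):

```r
all_caps("FREE MONEY NOW")          # presumably TRUE
has_dollar_sign("Win $1000 today")  # presumably TRUE
has_dear("dear valued customer")    # presumably TRUE
multiple_punctuation("Act now!!!")  # presumably TRUE
```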
```r
# CREATE FEATURES --------------------------------------------------------------
spam_train <- emails_train |>
  na.omit() |>
  mutate(
    # combine the two non-spam levels into one
    type = as.factor(ifelse(type %in% c("spam"), "spam", "not_spam")),
    caps = as.factor(all_caps(subjectline)),
    dollar_sign = as.factor(has_dollar_sign(subjectline)),
    dear = as.factor(has_dear(subjectline)),
    punctuation = as.factor(multiple_punctuation(subjectline)),
    religious = as.factor(has_religious(subjectline))
  ) |>
  select(-c(ids, subjectline))

spam_test <- emails_test |>
  na.omit() |>
  mutate(
    # combine the two non-spam levels into one
    type = as.factor(ifelse(type %in% c("spam"), "spam", "not_spam")),
    caps = as.factor(all_caps(subjectline)),
    dollar_sign = as.factor(has_dollar_sign(subjectline)),
    dear = as.factor(has_dear(subjectline)),
    punctuation = as.factor(multiple_punctuation(subjectline)),
    religious = as.factor(has_religious(subjectline))
  ) |>
  select(-c(ids, subjectline))
```
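Since the identical transformation is applied to both data sets, the feature engineering could be consolidated into a small helper to avoid the duplication (a sketch; `make_spam_features()` is our own name, not a package function):

```r
# hypothetical helper wrapping the feature-engineering pipeline above
make_spam_features <- function(data) {
  data |>
    na.omit() |>
    mutate(
      # combine the two non-spam levels into one
      type = as.factor(ifelse(type %in% c("spam"), "spam", "not_spam")),
      caps = as.factor(all_caps(subjectline)),
      dollar_sign = as.factor(has_dollar_sign(subjectline)),
      dear = as.factor(has_dear(subjectline)),
      punctuation = as.factor(multiple_punctuation(subjectline)),
      religious = as.factor(has_religious(subjectline))
    ) |>
    select(-c(ids, subjectline))
}

spam_train <- make_spam_features(emails_train)
spam_test <- make_spam_features(emails_test)
```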
Now that we have our features, we're ready to start modeling. First, we'll look at the accuracy of the null model (which classifies every observation into the most predominant class).
```r
# NULL MODEL -------------------------------------------------------------------
table(spam_train$type)
mean(spam_train$type == "not_spam")
```
We can see that the most predominant class is not spam, and the default model that classifies all observations as not spam has an accuracy of around 87%.
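The same baseline can also be expressed within tidymodels via parsnip's `null_model()`, which always predicts the majority class (a sketch, not part of the original workflow):

```r
# a null classifier fit with the same interface as the models below
null_fit <- null_model() |>
  set_engine("parsnip") |>
  set_mode("classification") |>
  fit(type ~ ., data = spam_train)

augment(null_fit, new_data = spam_train) |>
  accuracy(truth = type, estimate = .pred_class)
```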
We'll consider a simple additive logistic regression as our first model. To decide which variables to include, we use `stepAIC()` from the MASS package to perform backward selection.
```r
# MODEL SELECTION --------------------------------------------------------------
full_model <- glm(
  type ~ .,
  data = spam_train,
  family = binomial
)

MASS::stepAIC(full_model, direction = "backward")
```
The automated selection process suggests the four-predictor model using `caps`, `dollar_sign`, `punctuation`, and `religious` to predict `type`.
```r
# MODEL FIT --------------------------------------------------------------------
simple_logistic_model <- logistic_reg() |>
  set_engine("glm") |>
  set_mode("classification") |>
  fit(type ~ caps + dollar_sign + punctuation + religious, data = spam_train)

tidy(simple_logistic_model)
```
We can filter the above table to include only the significant terms, as shown below.
```r
tidy(simple_logistic_model) |>
  filter(p.value < 0.05)
```
From the above, we can see that the only significant features at the 0.05 level are `caps` and `dollar_sign`.
Now that we've fitted the model, it remains for us to interpret it and then use it for prediction. As it stands, the model's coefficients are on the log-odds scale. For easier interpretation, we can exponentiate them to obtain odds ratios.
```r
tidy(simple_logistic_model, exponentiate = TRUE)
```
Each exponentiated coefficient is an odds ratio: the multiplicative change in the odds that an email is spam when the corresponding feature is TRUE, holding the other features fixed. For example, the exponentiated `caps` coefficient tells us how many times greater the odds of spam are for an all-caps subject line than for one that is not, all else being equal.
We can now use our trained model to predict spam on our test data set. We look at both the predicted class and the predicted probabilities that an email is spam or not.
```r
# MODEL PREDICTION -------------------------------------------------------------
set.seed(495)

pred_class_logistic <- predict(
  simple_logistic_model,
  new_data = spam_test,
  type = "class"
)

pred_prob_logistic <- predict(
  simple_logistic_model,
  new_data = spam_test,
  type = "prob"
)

spam_results_logistic <- spam_test |>
  select(type) |>
  bind_cols(pred_class_logistic, pred_prob_logistic)
```
Using these results, we can assess the test accuracy of our model as shown below.
```r
# ASSESSMENT METRICS -----------------------------------------------------------
confusion_matrix_logistic <- yardstick::conf_mat(
  spam_results_logistic,
  truth = type,
  estimate = .pred_class
)
confusion_matrix_logistic

accuracy_logistic <- yardstick::accuracy_vec(
  truth = spam_results_logistic$type,
  estimate = spam_results_logistic$.pred_class
)
accuracy_logistic
```
Our logistic regression model has an accuracy of around 89%, which is not much of an improvement over the null model.
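With classes this imbalanced, accuracy alone can be misleading. Since we kept the predicted probabilities, a threshold-free metric such as the ROC AUC is easy to add (a sketch; "spam" is the second factor level, so we set `event_level` accordingly):

```r
# area under the ROC curve, treating "spam" as the event of interest
spam_results_logistic |>
  roc_auc(truth = type, .pred_spam, event_level = "second")
```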
Next, we'll fit a decision tree on our engineered features. In the plot below, the boxes at each node display three values: the first is the classification decision (spam or not_spam), the second gives the fitted probability of spam at that node, and the third gives the percentage of observations that fall into that node.
```r
# MODEL FIT --------------------------------------------------------------------
dec_tree <- decision_tree(tree_depth = 10) |>
  set_engine("rpart") |>
  set_mode("classification") |>
  fit(type ~ ., data = spam_train)

rpart.plot::rpart.plot(dec_tree$fit)
```
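The underlying rpart fit also records variable importance scores, which can corroborate what the plot shows (a sketch; the slot is only populated when the tree actually makes splits):

```r
# importance scores from the underlying rpart object
dec_tree$fit$variable.importance
```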
We can assess the tree by considering the training and test accuracy, which we see are both around 89%, once more barely an improvement on the null model.
```r
# ASSESSMENT METRICS -----------------------------------------------------------
# training accuracy
augment(dec_tree, new_data = spam_train) |>
  accuracy(truth = type, estimate = .pred_class)

# testing accuracy
augment(dec_tree, new_data = spam_test) |>
  accuracy(truth = type, estimate = .pred_class)
```
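A confusion matrix for the tree's test-set predictions breaks that accuracy down by class (a sketch, mirroring the logistic assessment above):

```r
augment(dec_tree, new_data = spam_test) |>
  conf_mat(truth = type, estimate = .pred_class)
```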