library(tidymodels)
library(tidyverse)
library(bestglm)
library(scales)
library(textclassificationexamples)

knitr::opts_chunk$set(
  collapse = TRUE,
  comment = NA
)
The textclassificationexamples package contains the data set headlines, which consists of three variables:

ids: an integer indicating the row ID
title: a character variable containing the article headline
clickbait: a logical variable that takes the value TRUE if the article is clickbait, and FALSE if it is not

This data has already been split into training (headlines_train) and test (headlines_test) data, with the training data consisting of about 80% of observations, and the test data the remaining 20%. Throughout this vignette, we'll use these data sets along with the helper functions from the textclassificationexamples package to classify articles as clickbait or not.
First, we'll load the data so it's available to work with.
# LOAD DATA ------------------------------------------------------------------------------
data(headlines_train)
data(headlines_test)
Now both headlines_train and headlines_test are available as objects in our environment.
In order to classify the headlines, we'll have to come up with certain features or explanatory variables that might contribute to whether or not a headline is clickbait.
A few functions available in the textclassificationexamples package help facilitate the creation of such predetermined features. These include:

has_common_phrase(): takes a character string and returns a logical - TRUE if the string contains a common phrase, and FALSE if it does not.
has_exaggerated_phrase(): takes a character string and returns a logical - TRUE if the string contains an exaggerated phrase, and FALSE if it does not.
num_contractions(): takes a character string and returns an integer - the number of contractions contained in the string.
num_words(): takes a character string and returns an integer - the number of words contained in the string.
num_stop_words(): takes a character string and returns an integer - the number of stop words contained in the string.
num_pronouns(): takes a character string and returns an integer - the number of pronouns contained in the string.
starts_with_num(): takes a character string and returns a logical - TRUE if the string begins with a number, and FALSE if it does not.
has_question_word(): takes a character string and returns a logical - TRUE if the string contains a question word, and FALSE if it does not.
positivity(): takes a character string and returns the sum of the AFINN positivity scores of the words in the string.

It's easy to see how these functions might help us determine whether or not an article is clickbait. For example, many clickbait articles appear to begin with numbers, in which case starts_with_num() might be useful.
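To get a feel for these helpers, we can try a few of them on a single made-up headline. This is just a sketch: the headline is invented, and the exact return values depend on the package's internal word and phrase lists.

example_title <- "10 Things You Won't Believe About Your Cat" # hypothetical headline
starts_with_num(example_title)   # presumably TRUE, since the title begins with "10"
num_contractions(example_title)  # counts contractions such as "Won't"
positivity(example_title)        # sum of AFINN scores for the title's words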
Using mutate(), we can apply each of these functions to the titles in our data sets, and as such create our features. Note that in the case of positivity, as well as num_stop_words, we have to use rowwise() to ensure that we evaluate each title individually.
# CREATE FEATURES ------------------------------------------------------------------------
clickbait_train <- headlines_train |>
  na.omit() |>
  mutate(
    clickbait = as.factor(clickbait),
    common = as.factor(has_common_phrase(title)),
    exaggerated = as.factor(has_exaggerated_phrase(title)),
    num_contractions = num_contractions(title),
    num_words = num_words(title),
    num_pronouns = num_pronouns(title),
    starts_num = as.factor(starts_with_num(title)),
    question = as.factor(has_question_word(title))
  ) |>
  rowwise() |>
  mutate(
    num_stop_words = num_stop_words(title),
    positivity = positivity(title)
  ) |>
  select(-c(ids, title)) # remove non-feature variables

clickbait_test <- headlines_test |>
  na.omit() |>
  mutate(
    clickbait = as.factor(clickbait),
    common = as.factor(has_common_phrase(title)),
    exaggerated = as.factor(has_exaggerated_phrase(title)),
    num_contractions = num_contractions(title),
    num_words = num_words(title),
    num_pronouns = num_pronouns(title),
    starts_num = as.factor(starts_with_num(title)),
    question = as.factor(has_question_word(title))
  ) |>
  rowwise() |>
  mutate(
    num_stop_words = num_stop_words(title),
    positivity = positivity(title)
  ) |>
  select(-c(ids, title)) # remove non-feature variables
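One quick sanity check (not part of the original workflow) is to glimpse() the resulting training set and confirm that the features were created with the expected types:

glimpse(clickbait_train)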
The clickbait data sets with features can also be directly loaded from the textclassificationexamples package, as below:
data(clickbait_train)
data(clickbait_test)

clickbait_train <- clickbait_train |> select(-c(ids, title))
clickbait_test <- clickbait_test |> select(-c(ids, title))
Now that we have our features, we're ready to start modeling. First, we'll look at the accuracy of the null model (which classifies every observation into the most predominant class).
# NULL MODEL -----------------------------------------------------------------------------
table(clickbait_train$clickbait)
mean(clickbait_train$clickbait == "FALSE")
mean(clickbait_test$clickbait == "FALSE")
We can see that the most predominant class appears to be FALSE (not clickbait), and both the training and test accuracy of the model that classifies every observation as not clickbait are around 0.655.
We'll consider a simple, additive logistic regression as our first model. In order to decide which variables to include, we use stepAIC() from the MASS package to perform backward selection.
# MODEL SELECTION ------------------------------------------------------------------------
full_model <- glm(
  clickbait ~ .,
  data = clickbait_train,
  family = binomial
)

MASS::stepAIC(full_model, trace = 0)
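MASS::stepAIC() returns the final fitted model, so if we want to inspect the retained formula or its AIC rather than just print the model, we can store the result. A small optional addition:

selected_model <- MASS::stepAIC(full_model, trace = 0)
formula(selected_model) # the formula retained by backward selection
AIC(selected_model)     # AIC of the selected model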
The automated selection procedure selects the full model including all features, which we fit below. Notice that we relevel the factor clickbait so that the reference level is "FALSE", meaning the model estimates the log-odds that an article is clickbait (TRUE).
# MODEL FIT ------------------------------------------------------------------------------
clickbait_train$clickbait <- relevel(clickbait_train$clickbait, ref = "FALSE")
clickbait_test$clickbait <- relevel(clickbait_test$clickbait, ref = "FALSE")

set.seed(495)
simple_logistic_model <- logistic_reg() |>
  set_engine("glm") |>
  set_mode("classification") |>
  fit(clickbait ~ ., data = clickbait_train)

tidy(simple_logistic_model)
We can filter the above table to include only the significant terms, as below:
tidy(simple_logistic_model) |> filter(p.value < 0.05)
From the above, we can see that all of the features appear significant at the 0.05 level, with the exception of common and question.
Now that we've fitted the model, it remains for us to interpret it, and then utilize it for prediction. As it currently stands, the coefficients in the model are in log-odds form. For easier interpretation, we can look at odds ratios by exponentiating these.
tidy(simple_logistic_model, exponentiate = TRUE)
Each exponentiated coefficient is an odds ratio: the multiplicative change in the odds that an article is clickbait associated with a one-unit increase in that feature (or, for the factor features, with moving from FALSE to TRUE), holding the other features fixed. Odds ratios above 1 correspond to features associated with clickbait headlines, while odds ratios below 1 correspond to features associated with non-clickbait headlines. For instance, if num_pronouns had an odds ratio of 1.5 (a hypothetical value), each additional pronoun in a headline would multiply the odds that the article is clickbait by 1.5.
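To attach uncertainty to these odds ratios, we can also request confidence intervals; tidy() on a parsnip fit forwards these arguments to the underlying glm tidier:

tidy(simple_logistic_model, exponentiate = TRUE, conf.int = TRUE)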
We can now use our trained model to predict clickbait on our test data set. We look at both the predicted class, and the predicted probabilities that an article is clickbait or not.
# MODEL PREDICTION -----------------------------------------------------------------------
set.seed(495)

pred_class_logistic <- predict(
  simple_logistic_model,
  new_data = clickbait_test,
  type = "class"
)

pred_prob_logistic <- predict(
  simple_logistic_model,
  new_data = clickbait_test,
  type = "prob"
)

clickbait_results_logistic <- clickbait_test |>
  select(clickbait) |>
  bind_cols(pred_class_logistic, pred_prob_logistic)
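Before computing metrics, it can be worth peeking at the combined results to check that the truth and prediction columns line up as expected:

head(clickbait_results_logistic)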
Using our results, we can assess the testing accuracy of our model as below.
# ASSESSMENT METRICS ---------------------------------------------------------------------
confusion_matrix_logistic <- yardstick::conf_mat(
  clickbait_results_logistic,
  truth = clickbait,
  estimate = .pred_class
)
confusion_matrix_logistic

accuracy_logistic <- yardstick::accuracy_vec(
  truth = clickbait_results_logistic$clickbait,
  estimate = clickbait_results_logistic$.pred_class
)
accuracy_logistic
Our logistic regression model has an accuracy of around 0.878, which is a clear improvement over the null model.
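Accuracy alone can mask class-specific behavior. Since we already have a yardstick confusion matrix, one convenient follow-up is its summary() method, which reports a broader set of metrics (sensitivity, specificity, and others) in one call:

summary(confusion_matrix_logistic)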
We'll use the same variables as in our logistic regression model to fit a decision tree. In the plot below, the boxes at each node display three values, the first of which is the classification decision (TRUE if an observation is classified as clickbait, and FALSE if it is not). The second value gives the predicted probability of that classification, and the third gives the percentage of observations that fall into that node.
# MODEL FIT ------------------------------------------------------------------------------
dec_tree <- decision_tree(tree_depth = 10) |>
  set_engine("rpart") |>
  set_mode("classification") |>
  fit(clickbait ~ ., data = clickbait_train)

rpart.plot::rpart.plot(dec_tree$fit)
We can assess the tree by considering the training and test accuracy, which we see are around 0.89 - once more an improvement on the null model, and a slight improvement on the logistic regression model we fit earlier.
# ASSESSMENT METRICS ---------------------------------------------------------------------
# training accuracy
augment(dec_tree, new_data = clickbait_train) |>
  accuracy(truth = clickbait, estimate = .pred_class)

# testing accuracy
augment(dec_tree, new_data = clickbait_test) |>
  accuracy(truth = clickbait, estimate = .pred_class)
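As a final check on the tree, rpart stores variable importance scores on the fitted object whenever the tree makes at least one split; these can hint at which features drive the splits:

dec_tree$fit$variable.importance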