factory_pipeline_r: Prepare and fit a text classification pipeline


View source: R/factory_pipeline_r.R

Description

Prepare and fit a text classification pipeline with Scikit-learn.

Usage

factory_pipeline_r(
  x,
  y,
  tknz = "spacy",
  ordinal = FALSE,
  metric = "class_balance_accuracy_score",
  cv = 5,
  n_iter = 2,
  n_jobs = 1,
  verbose = 3,
  learners = c("SGDClassifier", "RidgeClassifier", "Perceptron",
    "PassiveAggressiveClassifier", "BernoulliNB", "ComplementNB", "MultinomialNB",
    "RandomForestClassifier"),
  theme = NULL
)

Arguments

x

Data frame. The text feature.

y

Vector. The response variable.

tknz

String. Tokenizer to use ("spacy" or "wordnet").

ordinal

Logical. Whether to fit an ordinal classification model. The ordinal model is the implementation of Frank and Hall (2001), which can use any standard classification model that calculates probabilities.
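As a sketch of the Frank and Hall (2001) approach (in Python, since pxtextmineR wraps the Python library pxtextmining): a K-class ordinal problem is split into K-1 binary problems, each estimating P(y > k), and the per-class probabilities are recovered from the cumulative chain. The numbers below are illustrative, not package output:

```python
import numpy as np

def ordinal_probs(p_greater):
    """Recover P(y = k) from the K-1 cumulative estimates P(y > k)."""
    p_greater = np.asarray(p_greater, dtype=float)
    first = 1.0 - p_greater[0]               # P(y = 1) = 1 - P(y > 1)
    middle = p_greater[:-1] - p_greater[1:]  # P(y = k) = P(y > k-1) - P(y > k)
    last = p_greater[-1]                     # P(y = K) = P(y > K-1)
    return np.concatenate([[first], middle, [last]])

# Example: 4 ordered classes, so 3 binary models supply P(y > k).
probs = ordinal_probs([0.9, 0.6, 0.2])
print(probs)        # [0.1 0.3 0.4 0.2]
print(probs.sum())  # 1.0
```

The predicted class is then simply the argmax of the recovered probabilities, which is why any probabilistic classifier can slot into this scheme.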

metric

String. Scorer to use during pipeline tuning ("accuracy_score", "balanced_accuracy_score", "matthews_corrcoef", "class_balance_accuracy_score").
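For orientation on the less common option: "class_balance_accuracy_score" follows the Class Balance Accuracy usually attributed to Mosley (2013). The numpy re-implementation below is an assumption based on the standard formula (per class, the confusion-matrix diagonal divided by the larger of the row and column totals, i.e. min(recall, precision), then averaged), not the package's own code:

```python
import numpy as np

def class_balance_accuracy(conf):
    """Class Balance Accuracy from a square confusion matrix (rows = true)."""
    conf = np.asarray(conf, dtype=float)
    diag = np.diag(conf)
    # Larger of row total (class support) and column total (predictions made)
    denom = np.maximum(conf.sum(axis=1), conf.sum(axis=0))
    return float(np.mean(diag / denom))

# A perfect classifier scores 1.0; errors in either direction pull it down.
print(class_balance_accuracy([[5, 0], [0, 7]]))  # 1.0
print(class_balance_accuracy([[4, 1], [2, 5]]))
```

Because each class contributes equally regardless of its size, the metric penalises models that coast on the majority class, which is the point of offering it alongside plain accuracy.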

cv

Number of cross-validation folds.

n_iter

Number of parameter settings that are sampled (see sklearn.model_selection.RandomizedSearchCV).

n_jobs

Number of jobs to run in parallel (see sklearn.model_selection.RandomizedSearchCV). NOTE: If your machine does not have the number of cores specified in n_jobs, then an error will be returned.

verbose

Controls the verbosity (see sklearn.model_selection.RandomizedSearchCV).

learners

Vector. Scikit-learn names of the learners to tune. Must be one or more of "SGDClassifier", "RidgeClassifier", "Perceptron", "PassiveAggressiveClassifier", "BernoulliNB", "ComplementNB", "MultinomialNB", "KNeighborsClassifier", "NearestCentroid", "RandomForestClassifier". When a single model is used, it can be passed as a string.

theme

String. For internal use by Nottinghamshire Healthcare NHS Foundation Trust or other trusts that use theme labels ("Access", "Environment/facilities" etc.). The column name of the theme variable. Defaults to NULL. If supplied, the theme variable will be used as a predictor (along with the text predictor) in the model that is fitted with criticality as the response variable. The rationale is twofold: first, to help the model improve predictions on criticality when the theme labels are readily available; second, to force the criticality for "Couldn't be improved" to always be "3" in the training and test data, as well as in the predictions. This is the only criticality value that "Couldn't be improved" can take, so forcing it to always be "3" both improves model performance and corrects erroneous assignments of other values that are attributable to human error.
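The "Couldn't be improved" correction can be sketched in plain Python (the package applies this internally via the Python library pxtextmining; force_criticality below is a hypothetical helper for illustration, not part of the package's API):

```python
def force_criticality(themes, criticality, theme_label="Couldn't be improved"):
    """Force criticality to "3" wherever the theme is the given label,
    overriding any (erroneous) human-assigned value."""
    return ["3" if t == theme_label else c
            for t, c in zip(themes, criticality)]

themes = ["Access", "Couldn't be improved", "Couldn't be improved"]
crit   = ["2", "1", "3"]  # the "1" is an erroneous human assignment
print(force_criticality(themes, crit))  # ['2', '3', '3']
```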

Details

The pipeline's parameter grid switches between two approaches to text classification: Bag-of-Words and Embeddings. For the former, both TF-IDF and raw counts are tried out.
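To make the Bag-of-Words side concrete, here is a numpy sketch of the TF-IDF weighting that scikit-learn's TfidfVectorizer applies by default (smooth idf, l2 row normalisation) to a toy document-term count matrix; the raw-counts alternative simply skips this reweighting. This is an illustrative re-implementation, not the pipeline's code:

```python
import numpy as np

counts = np.array([[3, 0, 1],
                   [1, 2, 0]], dtype=float)  # 2 documents x 3 terms

n_docs = counts.shape[0]
df = (counts > 0).sum(axis=0)              # document frequency per term
idf = np.log((1 + n_docs) / (1 + df)) + 1  # smooth idf, sklearn-style
tfidf = counts * idf
tfidf /= np.linalg.norm(tfidf, axis=1, keepdims=True)  # l2-normalise rows

print(tfidf.round(3))
```

Terms appearing in every document get idf = 1 (no boost), while rarer terms are up-weighted, which is the usual motivation for trying TF-IDF alongside raw counts in the grid.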

The numeric values in the grid are currently lists/tuples (Python objects) of values that are defined either empirically or on the basis of the published literature (e.g. for Random Forest, see Probst et al. 2018). Values may be replaced by appropriate distributions in a future release.

Value

A fitted Scikit-learn pipeline containing a number of objects that can be accessed with the $ sign (see Examples). For a partial list, see "Attributes" in sklearn.model_selection.RandomizedSearchCV. Do not be surprised if the pipeline contains more objects than those in the aforementioned "Attributes" list: Python objects can contain several other objects, from numeric results (e.g. the pipeline's accuracy) to methods (i.e. functions, in R lingo) and classes. In Python these are normally accessed as object.<whatever>; in R the command is object$<whatever>. For instance, one can access the predict() method to make predictions on unseen data. See Examples.

Note

The pipeline uses the tokenizers of the Python library pxtextmining. Scikit-learn warnings such as "UserWarning: The parameter 'token_pattern' will not be used since 'tokenizer' is not None" can therefore be safely ignored.

Likewise, any warnings about over-sampling can be safely ignored; they are simply the result of an internal check in the over-sampler of imblearn.

References

Frank E. & Hall M. (2001). A Simple Approach to Ordinal Classification. Machine Learning: ECML 2001 145–156.

Pedregosa F., Varoquaux G., Gramfort A., Michel V., Thirion B., Grisel O., Blondel M., Prettenhofer P., Weiss R., Dubourg V., Vanderplas J., Passos A., Cournapeau D., Brucher M., Perrot M. & Duchesnay E. (2011). Scikit-learn: Machine Learning in Python. Journal of Machine Learning Research 12:2825–2830.

Probst P., Bischl B. & Boulesteix A-L (2018). Tunability: Importance of Hyperparameters of Machine Learning Algorithms. https://arxiv.org/abs/1802.09596

Examples

# Prepare training and test sets
data_splits <- pxtextmineR::factory_data_load_and_split_r(
  filename = pxtextmineR::text_data,
  target = "label",
  predictor = "feedback",
  test_size = 0.90) # Make a small training set for a faster run in this example

# Let's take a look at the returned list
str(data_splits)

# Fit the pipeline
pipe <- pxtextmineR::factory_pipeline_r(
  x = data_splits$x_train,
  y = data_splits$y_train,
  tknz = "spacy",
  ordinal = FALSE,
  metric = "accuracy_score",
  cv = 2, n_iter = 1, n_jobs = 1, verbose = 3,
  learners = "SGDClassifier"
)

# Mean cross-validated score of the best_estimator
pipe$best_score_

# Best parameters during tuning
pipe$best_params_

# Is the best model a linear SVM (loss = "hinge") or logistic regression (loss = "log")?
pipe$best_params_$clf__estimator__loss

# Make predictions
preds <- pipe$predict(data_splits$x_test)
head(preds)

# Performance on test set #
# Can be done using the pipe's score() method
pipe$score(data_splits$x_test, data_splits$y_test)

# Or using dplyr
data_splits$y_test %>%
  data.frame() %>%
  dplyr::rename(true = '.') %>%
  dplyr::mutate(
    pred = preds,
    check = true == preds,
    check = sum(check) / nrow(.)
  ) %>%
  dplyr::pull(check) %>%
  unique

# We can also use other metrics, such as the Class Balance Accuracy score
pxtextmineR::class_balance_accuracy_score_r(
  data_splits$y_test,
  preds
)

nhs-r-community/pxtextmineR documentation built on Dec. 22, 2021, 2:10 a.m.