```r
# Vignette setup: chunks are not re-evaluated when pre-computed results are used
use_saved_results <- TRUE

knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  echo = TRUE,
  eval = !use_saved_results,
  message = FALSE,
  warning = FALSE
)

if (use_saved_results) {
  results <- readRDS("vignette_mc.rds")
  pred <- results$pred
}
```
```r
library(dplyr); library(tidyr); library(purrr)  # Data wrangling
library(ggplot2); library(stringr)              # Plotting
library(tidyfit)                                # Auto-ML modeling
```
Multinomial classification is possible in `tidyfit` using the methods powered by `glmnet`, `e1071` and `randomForest` (LASSO, Ridge, ElasticNet, AdaLASSO, SVM and Random Forest). Currently, none of the other methods support multinomial classification.^[Feature selection methods such as `relief` or `chisq` can be used with multinomial response variables (see the sketch below). I may also add support for multinomial classification with `mboost` in the future.] When the response variable contains more than two classes, `classify` automatically uses a multinomial response for the above-mentioned methods.
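As a side note to the footnote above, here is a minimal sketch of how a feature selection method could be applied to a multi-class response. This is not part of the original vignette: it assumes that `m("relief")` and `m("chisq")` can be used inside `classify` with a factor response of more than two classes, and that `coef()` returns their feature importance scores.

```r
# Hedged sketch (not from the vignette): feature selection with a multi-class response.
# Assumes m("relief") and m("chisq") work inside classify() for factor responses.
fs_fit <- iris |>
  classify(Species ~ .,
           Relief = m("relief"),
           `Chi-squared` = m("chisq"))

# coef() is assumed to return feature importance scores here
# rather than regression coefficients
coef(fs_fit)
```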
Here's an example using the built-in `iris` dataset:
data("iris") # For reproducibility set.seed(42) ix_tst <- sample(1:nrow(iris), round(nrow(iris)*0.2)) data_trn <- iris[-ix_tst,] data_tst <- iris[ix_tst,] as_tibble(iris)
The `Species` column is the response variable and contains three classes (setosa, versicolor and virginica).
The code chunk below fits the above-mentioned algorithms on the training split, using 10-fold cross-validation to select optimal penalties. We then obtain out-of-sample predictions using `predict`. Unlike in binomial classification, the `fit` and `pred` objects contain a `class` column with separate coefficients and predictions for each class. The predictions sum to one across classes:
```r
fit <- data_trn |>
  classify(Species ~ .,
           LASSO = m("lasso"),
           Ridge = m("ridge"),
           ElasticNet = m("enet"),
           AdaLASSO = m("adalasso"),
           SVM = m("svm"),
           `Random Forest` = m("rf"),
           `Least Squares` = m("ridge", lambda = 1e-5),
           .cv = "vfold_cv")

pred <- fit |>
  predict(data_tst)
```
Note that we can add unregularized least squares estimates by setting `lambda = 0` (or, as done here with `lambda = 1e-5`, a value very close to zero).
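To verify the claim above that the predicted class probabilities sum to one for each observation, a quick check along the following lines could be run. This is a hedged sketch rather than part of the vignette; it assumes `pred` has the long format described above, with `model`, `class` and `prediction` columns.

```r
# Hedged sketch (not from the vignette): check that class probabilities
# sum to one for every observation and model. Assumes the long prediction
# format with 'model', 'class' and 'prediction' columns.
pred |>
  group_by(model, class) |>
  mutate(obs = row_number()) |>        # index observations within each class
  group_by(model, obs) |>
  summarise(total = sum(prediction), .groups = "drop_last") |>
  summarise(sums_to_one = all(abs(total - 1) < 1e-6))
```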
Next, we can use `yardstick` to calculate the log loss metric and compare the performance of the different models:
```r
metrics <- pred |>
  group_by(model, class) |>
  mutate(row_n = row_number()) |>
  spread(class, prediction) |>
  group_by(model) |>
  yardstick::mn_log_loss(truth, setosa:virginica)

metrics |>
  mutate(model = str_wrap(model, 11)) |>
  ggplot(aes(model, .estimate)) +
  geom_col(fill = "darkblue") +
  theme_bw() +
  theme(axis.title.x = element_blank())
```
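Log loss is not the only way to compare the models. As an aside (not part of the original vignette), a hard-class accuracy could be computed along the following lines, assuming the same `pred` columns as above and that `truth` is stored as a factor (as required by `mn_log_loss` in the chunk above).

```r
# Hedged sketch (not from the vignette): accuracy based on the highest-probability class.
# Assumes 'pred' contains 'model', 'class', 'prediction' and a factor 'truth' column.
pred |>
  group_by(model, class) |>
  mutate(row_n = row_number()) |>
  group_by(model, row_n) |>
  summarise(truth = first(truth),
            estimate = class[which.max(prediction)],
            .groups = "drop") |>
  mutate(estimate = factor(estimate, levels = levels(truth))) |>
  group_by(model) |>
  yardstick::accuracy(truth, estimate)
```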
The least squares estimate performs worst, while the random forest (nonlinear) and the support vector machine (SVM) achieve the best results. The SVM is estimated with a linear kernel by default (use `kernel = <chosen_kernel>` to select a different kernel).
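For example, a radial-basis kernel could be requested as follows. This is a sketch rather than part of the vignette; it assumes that additional arguments passed to `m()` are forwarded to the underlying `e1071::svm()` call.

```r
# Hedged sketch (not from the vignette): refit the SVM with a radial kernel.
# Assumes extra arguments to m() are passed through to e1071::svm().
fit_rbf <- data_trn |>
  classify(Species ~ .,
           `SVM (radial)` = m("svm", kernel = "radial"),
           .cv = "vfold_cv")

pred_rbf <- fit_rbf |>
  predict(data_tst)
```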