measure_cat: Measure Performance for Multi-Class Classification Models

View source: R/measure.R

measure_cat {bnns}    R Documentation

Measure Performance for Multi-Class Classification Models

Description

Evaluates the performance of a multi-class classification model using log loss and multiclass AUC.

Usage

measure_cat(obs, pred)

Arguments

obs

A factor vector of observed class labels. Each level represents a unique class.

pred

A numeric matrix of predicted probabilities, where each row corresponds to an observation, and each column corresponds to a class. The number of columns must match the number of levels in obs.
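The Arguments above imply a few shape constraints that can be checked before calling measure_cat(). The block below is a minimal sketch of such a check. `check_measure_cat_inputs` is a hypothetical helper, not part of bnns. The row-sum warning is an assumption, because the documentation does not say that rows of `pred` must sum to 1.

```r
# Hypothetical helper, not part of bnns: checks `obs` and `pred` against the
# shapes described under Arguments before calling measure_cat().
check_measure_cat_inputs <- function(obs, pred, tol = 1e-8) {
  if (!is.factor(obs)) stop("`obs` must be a factor.")
  if (!is.matrix(pred) || !is.numeric(pred)) stop("`pred` must be a numeric matrix.")
  # One row per observation, one column per class level.
  if (nrow(pred) != length(obs)) stop("`pred` needs one row per element of `obs`.")
  if (ncol(pred) != nlevels(obs)) stop("`pred` needs one column per level of `obs`.")
  # Entries are probabilities.
  if (any(pred < 0 | pred > 1)) stop("`pred` must contain values in [0, 1].")
  # Assumption (not stated in the docs): each row is a full distribution over classes.
  if (any(abs(rowSums(pred) - 1) > tol)) warning("Some rows of `pred` do not sum to 1.")
  invisible(TRUE)
}

obs <- factor(c("A", "B", "C"), levels = LETTERS[1:3])
pred <- matrix(c(0.8, 0.1, 0.1,
                 0.2, 0.6, 0.2,
                 0.7, 0.2, 0.1), nrow = 3, byrow = TRUE)
check_measure_cat_inputs(obs, pred)             # passes silently
try(check_measure_cat_inputs(obs, pred[, 1:2])) # error: one column per level required
```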

Details

The log loss is calculated as:

-\frac{1}{N} \sum_{i=1}^N \sum_{c=1}^C y_{ic} \log(p_{ic})

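where N is the number of observations, C is the number of classes, y_{ic} is 1 if observation i belongs to class c and 0 otherwise, and p_{ic} is the predicted probability that observation i belongs to class c.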
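Because y_{ic} is zero for every class except the observed one, the inner sum reduces to a single log-probability per observation. The block below is a minimal base-R sketch of the formula, not the bnns source. `log_loss_sketch` is a hypothetical name. The code assumes that column c of `pred` corresponds to the c-th level of `obs`, which the Arguments section implies but does not state outright.

```r
# Sketch of the log-loss formula (not the bnns source).
# Assumes column c of `pred` corresponds to levels(obs)[c].
log_loss_sketch <- function(obs, pred) {
  # y_ic is 1 only for the observed class, so the inner sum over c keeps just
  # log(p_ic) at that class; pick those probabilities by (row, column) index.
  p_true <- pred[cbind(seq_along(obs), as.integer(obs))]
  # Average the negative log-probabilities over the N observations.
  -mean(log(p_true))
}

obs <- factor(c("A", "B", "C"), levels = LETTERS[1:3])
pred <- matrix(c(0.8, 0.1, 0.1,
                 0.2, 0.6, 0.2,
                 0.7, 0.2, 0.1), nrow = 3, byrow = TRUE)
round(log_loss_sketch(obs, pred), 6)  # 1.012185
```

On the Examples data this reproduces the documented value, -(log 0.8 + log 0.6 + log 0.1) / 3. Indexing directly also avoids computing 0 * log(0) for classes that were not observed. A predicted probability of exactly 0 for the observed class still makes the log loss infinite.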

The AUC is computed with pROC::multiclass.roc, which combines pairwise (one-vs-one) class comparisons into a single multiclass AUC following Hand and Till (2001).
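To illustrate what a multiclass AUC summarizes, the block below is a base-R sketch of the Hand and Till (2001) pairwise definition, which the pROC documentation cites for multiclass.roc. For each pair of classes i and j, it averages two quantities. A(i|j) is the AUC for separating class i from class j using column i. A(j|i) is the same AUC computed using column j. The result is the mean of these pairwise averages over all class pairs. The function names are hypothetical and this is not the pROC implementation, which may handle details such as curve direction differently.

```r
# Mann-Whitney estimate of P(score of a `pos` obs > score of a `neg` obs),
# counting ties as 1/2.
pairwise_auc <- function(pos_scores, neg_scores) {
  diffs <- outer(pos_scores, neg_scores, "-")
  mean((diffs > 0) + 0.5 * (diffs == 0))
}

# Hand-Till style multiclass AUC; assumes column c of `pred` matches levels(obs)[c].
hand_till_auc <- function(obs, pred) {
  lv <- levels(obs)
  pairs <- combn(seq_along(lv), 2)  # each column is one class pair (i, j)
  aucs <- apply(pairs, 2, function(ij) {
    i <- ij[1]
    j <- ij[2]
    in_i <- obs == lv[i]
    in_j <- obs == lv[j]
    # A(i|j): rank class-i vs class-j observations by the class-i probability.
    a_ij <- pairwise_auc(pred[in_i, i], pred[in_j, i])
    # A(j|i): rank class-j vs class-i observations by the class-j probability.
    a_ji <- pairwise_auc(pred[in_j, j], pred[in_i, j])
    (a_ij + a_ji) / 2
  })
  mean(aucs)  # average over all C(C-1)/2 pairs
}

obs <- factor(c("A", "B", "C"), levels = LETTERS[1:3])
pred <- matrix(c(0.8, 0.1, 0.1,
                 0.2, 0.6, 0.2,
                 0.7, 0.2, 0.1), nrow = 3, byrow = TRUE)
hand_till_auc(obs, pred)  # 0.75
```

On the Examples data the pairwise averages are 1 (A, B), 0.75 (A, C) and 0.5 (B, C), so the sketch returns 0.75, the same AUC shown in the Examples.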

Value

A list containing:

log_loss

The negative log-likelihood averaged across observations.

ROC

The ROC object generated by pROC::roc.

AUC

The multiclass Area Under the Curve (AUC) as computed by pROC::multiclass.roc.

Examples

library(pROC)
obs <- factor(c("A", "B", "C"), levels = LETTERS[1:3])
pred <- matrix(
  c(
    0.8, 0.1, 0.1,
    0.2, 0.6, 0.2,
    0.7, 0.2, 0.1
  ),
  nrow = 3, byrow = TRUE
)
measure_cat(obs, pred)
# Returns: list(log_loss = 1.012185, ROC = <ROC>, AUC = 0.75)


bnns documentation built on April 3, 2025, 6:12 p.m.