calculate.rulelist: 'calculate' metrics for a rulelist

View source: R/rulelist.R

calculate.rulelist    R Documentation

calculate metrics for a rulelist

Description

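Computes metrics (chosen according to estimation_type) in the style of a cumulative window function: the value for the nth rule is computed over that rule and all preceding rules, in the order in which they appear in the rulelist. Keys are ignored.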

Usage

## S3 method for class 'rulelist'
calculate(x, metrics_to_exclude = NULL, ...)

Arguments

x

A rulelist

metrics_to_exclude

(character vector) Names of metrics to exclude

...

Named list of custom metrics. See 'details'.

Details

Default Metrics

These metrics are calculated by default:

  • cumulative_coverage: For the nth rule in the rulelist, the number of distinct row_nbrs (of new_data) covered by the nth rule and all preceding rules (in order). In the weighted case, the weights of those distinct row_nbrs are summed.

  • cumulative_overlap: Up to the nth rule in the rulelist, the number of distinct row_nbrs (of new_data) already covered by some preceding rule (in order). In the weighted case, the weights of those distinct row_nbrs are summed.
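The two definitions above can be illustrated with a minimal base R sketch. It does not call the package: the `covered` list and the weights `w` are hypothetical toy inputs, and the overlap calculation reflects one reading of the definition (rows that had already been covered by an earlier rule when a later rule, up to the nth, covered them again), not necessarily the package's exact implementation.

```r
# Hypothetical toy input: the row numbers each rule covers, in rulelist order
covered <- list(c(1, 2, 3), c(3, 4), c(2, 5, 6))
# Hypothetical per-row weights, indexed by row number
w <- c(1, 1, 2, 1, 1, 1)

# Rows covered by rules 1..n, with repeats kept
rows_upto <- function(n) unlist(covered[seq_len(n)])

# Distinct rows covered by the nth rule or any preceding rule
cumulative_coverage <- sapply(seq_along(covered), function(n) {
  length(unique(rows_upto(n)))
})

# Weighted case: sum the weights of those distinct rows
weighted_coverage <- sapply(seq_along(covered), function(n) {
  sum(w[unique(rows_upto(n))])
})

# Distinct rows that were covered more than once among rules 1..n
# (assumes each rule lists a row at most once)
cumulative_overlap <- sapply(seq_along(covered), function(n) {
  rows <- rows_upto(n)
  length(unique(rows[duplicated(rows)]))
})

print(cumulative_coverage)  # 3 4 6
print(weighted_coverage)    # 4 5 7
print(cumulative_overlap)   # 0 1 2
```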

For classification:

  • cumulative_accuracy: For the nth rule in the rulelist, the fraction of row_nbrs, among those predicted by the nth rule and all preceding rules (in order), whose RHS matches the y_name column (of new_data). In the weighted case, weighted accuracy is computed.

For regression:

  • cumulative_RMSE: For the nth rule in the rulelist, the weighted RMSE of all predictions (RHS) made by the nth rule and all preceding rules.
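As a rough illustration of the classification and regression metrics above, the following base R sketch computes them on hypothetical toy predictions. It assumes each row has already been assigned to a single rule and that rule_nbr matches the rule's position in the list; the data frames `pred` and `reg` are made up for this example and the code is not the package's implementation.

```r
# Hypothetical classification predictions: one rule per row
pred <- data.frame(
  rule_nbr  = c(1, 1, 2, 2, 3),
  predicted = c("Yes", "Yes", "No", "No", "Yes"),
  actual    = c("Yes", "No", "No", "No", "Yes"),
  stringsAsFactors = FALSE
)

# Accuracy over rows predicted by rules 1..n
cumulative_accuracy <- sapply(1:3, function(n) {
  upto <- pred[pred$rule_nbr <= n, ]
  mean(upto$predicted == upto$actual)
})

# Hypothetical regression predictions with unit weights
reg <- data.frame(
  rule_nbr  = c(1, 1, 2, 3),
  predicted = c(10, 10, 20, 30),
  actual    = c(12, 8, 20, 26),
  weight    = c(1, 1, 1, 1)
)

# Weighted RMSE over rows predicted by rules 1..n
cumulative_rmse <- sapply(1:3, function(n) {
  upto <- reg[reg$rule_nbr <= n, ]
  sqrt(weighted.mean((upto$predicted - upto$actual)^2, upto$weight))
})

print(cumulative_accuracy)  # 0.50 0.75 0.80
print(cumulative_rmse)      # 2, sqrt(8/3), sqrt(6)
```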

Custom metrics

Custom metrics should be passed as a named list of function(s) in .... Each custom metric function should take these arguments, in this order: rulelist, new_data, y_name, weight. It should return a numeric vector whose length equals the number of rows of the rulelist.
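The contract described above (argument order and return length) can be checked without the package. In this minimal sketch, `get_rule_position` is a hypothetical, deliberately trivial metric, and `toy_rulelist` is a plain data frame standing in for a real rulelist purely to exercise the function.

```r
# Hypothetical custom metric: returns each rule's position in the list.
# Only the argument order and the return length matter for the contract.
get_rule_position <- function(rulelist, new_data, y_name, weight) {
  as.numeric(seq_len(nrow(rulelist)))
}

# Stand-in for a rulelist (assumption: a real rulelist has one row per rule)
toy_rulelist <- data.frame(rule_nbr = c(3, 1, 2))

res <- get_rule_position(toy_rulelist, new_data = NULL, y_name = "y", weight = NULL)

# One numeric value per rule
stopifnot(is.numeric(res), length(res) == nrow(toy_rulelist))
print(res)  # 1 2 3
```

Following the calling convention used in the Examples section, such a function would then be supplied as a named list, for instance list("rule_position" = get_rule_position).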

Value

A dataframe of metrics with a rule_nbr column.

See Also

rulelist, tidy, augment, predict, calculate, prune, reorder

Examples

library("tidyrules")
library("magrittr")
model_c5  = C50::C5.0(Attrition ~., data = modeldata::attrition, rules = TRUE)
tidy_c5   = tidy(model_c5) %>%
            set_validation_data(modeldata::attrition, "Attrition") %>%
            set_keys(NULL)

# calculate default metrics (classification)
calculate(tidy_c5)

model_rpart = rpart::rpart(MonthlyIncome ~., data = modeldata::attrition)
tidy_rpart  =
  tidy(model_rpart) %>%
  set_validation_data(modeldata::attrition, "MonthlyIncome") %>%
  set_keys(NULL)

# calculate default metrics (regression)
calculate(tidy_rpart)

# calculate default metrics along with a custom metric
# custom function to get cumulative MAE
library("tidytable")
get_cumulative_MAE = function(rulelist, new_data, y_name, weight){

  priority_df =
    rulelist %>%
    select(rule_nbr) %>%
    mutate(priority = 1:nrow(rulelist)) %>%
    select(rule_nbr, priority)

  pred_df =
    predict(rulelist, new_data) %>%
    left_join(priority_df, by = "rule_nbr") %>%
    mutate(weight = local(weight)) %>%
    select(rule_nbr, row_nbr, weight, priority)

  new_data2 =
    new_data %>%
    mutate(row_nbr = 1:n()) %>%
    select(all_of(c("row_nbr", y_name)))

  # MAE over all rows predicted by rules with priority <= rn
  mae_till_rule = function(rn){

    if (is.character(rulelist$RHS)) {
      # RHS is an expression stored as a string: evaluate it per row
      inter_df =
        pred_df %>%
        tidytable::filter(priority <= rn) %>%
        left_join(mutate(new_data, row_nbr = 1:n()), by = "row_nbr") %>%
        left_join(select(rulelist, rule_nbr, RHS), by = "rule_nbr") %>%
        nest(.by = c("RHS", "rule_nbr", "row_nbr", "priority", "weight")) %>%
        mutate(RHS = purrr::map2_dbl(RHS,
                                     data,
                                     ~ eval(parse(text = .x), envir = .y)
                                     )
               ) %>%
        unnest(data)
    } else {

      # RHS is already numeric: join it to the actual values
      inter_df =
        pred_df %>%
        tidytable::filter(priority <= rn) %>%
        left_join(new_data2, by = "row_nbr") %>%
        left_join(select(rulelist, rule_nbr, RHS), by = "rule_nbr")
    }

    inter_df %>%
      summarise(mae = MetricsWeighted::mae(RHS,
                                            .data[[y_name]],
                                            weight,
                                            na.rm = TRUE
                                            )
                ) %>%
      `[[`("mae")
  }

  # one cumulative MAE value per rule, in rulelist order
  res = purrr::map_dbl(seq_len(nrow(rulelist)), mae_till_rule)
  return(res)
}

calculate(tidy_rpart,
          metrics_to_exclude = NULL,
          list("cumulative_mae" = get_cumulative_MAE)
          )


tidyrules documentation built on June 30, 2024, 1:07 a.m.