R/plots.R

Defines functions get_shap_values get_feature_importance get_partial_dependence get_accumulated_dependence get_break_down get_ceteris_paribus get_rec get_roc get_metrics get_funnel_measure get_subsets_performance get_fairness get_variable_distribution get_variable_against_another get_message_output get_dataset_plots get_global_plots get_local_plots

Documented in get_accumulated_dependence get_break_down get_ceteris_paribus get_dataset_plots get_fairness get_feature_importance get_funnel_measure get_global_plots get_local_plots get_message_output get_metrics get_partial_dependence get_rec get_roc get_shap_values get_subsets_performance get_variable_against_another get_variable_distribution

#' Internal function for calculating local plots for all observations
#'
#' Runs every observation-level plot generator (Break Down, Shapley Values
#' and Ceteris Paribus for each variable) for each row of \code{observations}.
#'
#' @param explainer Explainer created using \code{DALEX::explain}
#' @param observations Data frame of observations
#' @param params Params from arena object. When \code{params$cl} is not
#'   \code{NULL} it is used as a cluster for the \code{parallel} package.
#' @return list of generated plots' data (\code{NULL} results are removed)
get_local_plots <- function(explainer, observations, params) {
  # Columns of explainer's data identical to the target are not explained
  is_y <- vapply(
    explainer$data,
    function(v) identical(v, explainer$y),
    logical(1)
  )
  vars <- intersect(names(is_y)[!is_y], colnames(observations))

  # One single-row data frame per observation. drop = FALSE keeps the data
  # frame (and its row name) even when only one variable is left.
  obs_list <- lapply(
    seq_len(nrow(observations)),
    function(i) observations[i, vars, drop = FALSE]
  )

  get_bd <- function(obs) get_break_down(explainer, obs, params)
  get_shap <- function(obs) get_shap_values(explainer, obs, params)
  get_cp <- function(obs) lapply(vars, function(v) {
    get_ceteris_paribus(explainer, obs, v, params)
  })
  # vapply keeps the mask logical for empty input, where sapply gives list()
  drop_null <- function(plots) plots[!vapply(plots, is.null, logical(1))]

  if (is.null(params$cl)) { # single thread if cluster was not provided
    bd <- lapply(obs_list, get_bd)
    sp <- lapply(obs_list, get_shap)
    cp <- lapply(obs_list, get_cp)
  } else {
    # Export variables and functions to cluster
    to_export <- c(
      "explainer",
      "params",
      "vars",
      "get_break_down",
      "get_shap_values",
      "get_ceteris_paribus",
      "get_message_output"
    )
    parallel::clusterExport(params$cl, to_export, envir = environment())
    # Load model's library to access predict function
    parallel::clusterEvalQ(
      params$cl,
      library(explainer$model_info$package, character.only = TRUE)
    )
    bd <- parallel::parLapply(params$cl, obs_list, get_bd)
    sp <- parallel::parLapply(params$cl, obs_list, get_shap)
    cp <- parallel::parLapply(params$cl, obs_list, get_cp)
  }
  # Flatten list (per observation) of lists (per variable) into one list
  cp <- unlist(cp, recursive = FALSE)

  # Join results into one list
  c(drop_null(bd), drop_null(sp), drop_null(cp))
}

#' Internal function for calculating global plots
#'
#' Runs all model-level plot generators (performance plots, feature
#' importance) and all per-variable generators (partial dependence,
#' accumulated dependence, fairness) for given explainer.
#'
#' @param explainer Explainer created using \code{DALEX::explain}
#' @param params Params from arena object. When \code{params$cl} is not
#'   \code{NULL} it is used as a cluster for the \code{parallel} package.
#' @return list of generated plots' data (\code{NULL} results are removed)
get_global_plots <- function(explainer, params) {
  # Columns of explainer's data identical to the target are not explained
  is_y <- vapply(
    explainer$data,
    function(v) identical(v, explainer$y),
    logical(1)
  )
  vars <- names(is_y)[!is_y]

  global_plots_names <- c(
    "roc",
    "rec",
    "metrics",
    "subsets_performance",
    "funnel_measure",
    "feature_importance"
  )
  global_plots <- lapply(
    global_plots_names,
    function(n) get(paste0("get_", n))
  )

  get_global <- function(f) f(explainer, params)
  get_pd <- function(v) get_partial_dependence(explainer, v, params)
  get_ad <- function(v) get_accumulated_dependence(explainer, v, params)
  get_fr <- function(v) get_fairness(explainer, v, params)
  # vapply keeps the mask logical for empty input, where sapply gives list()
  drop_null <- function(plots) plots[!vapply(plots, is.null, logical(1))]

  if (is.null(params$cl)) { # single thread if cluster was not provided
    globals <- lapply(global_plots, get_global)
    pd <- lapply(vars, get_pd)
    ad <- lapply(vars, get_ad)
    fr <- lapply(vars, get_fr)
  } else {
    # Export variables and functions to cluster
    to_export <- c(
      "explainer",
      "params",
      "calculate_subsets_performance",
      "get_partial_dependence",
      "get_accumulated_dependence",
      "get_fairness",
      "get_message_output"
    )
    parallel::clusterExport(params$cl, to_export, envir = environment())
    # Load model's library to access predict function
    parallel::clusterEvalQ(
      params$cl,
      library(explainer$model_info$package, character.only = TRUE)
    )
    globals <- parallel::parLapply(params$cl, global_plots, get_global)
    pd <- parallel::parLapply(params$cl, vars, get_pd)
    ad <- parallel::parLapply(params$cl, vars, get_ad)
    fr <- parallel::parLapply(params$cl, vars, get_fr)
  }

  # filter out plots only for classificators
  if (explainer$model_info$type != 'classification') {
    fr <- list()
    # identical() copes with NULL entries, which are dropped below
    is_roc <- vapply(
      globals,
      function(p) identical(p$plotType, "ROC"),
      logical(1)
    )
    globals <- globals[!is_roc]
  }

  # Join results into one list
  c(drop_null(pd), drop_null(ad), drop_null(fr), drop_null(globals))
}

#' Internal function for calculating exploratory data analysis plots
#'
#' Runs all plot generating methods (variable distribution and variable
#' against another) for every variable of given dataset.
#'
#' @param dataset List with following elements
#' \itemize{
#'   \item{dataset}{ Data frame}
#'   \item{target}{ Name of one column from data frame that is used as target variable}
#'   \item{label}{ Label for dataset to be displayed in Arena}
#'   \item{variables}{ vector of column names from data frame without target}
#' }
#' @param params Params from arena object
#' @return list of generated plots' data (\code{NULL} results are removed)
get_dataset_plots <- function(dataset, params) {
  # Helper methods to reduce arguments length
  get_vd <- function(v) get_variable_distribution(dataset, v, params)
  get_vaa <- function(v) get_variable_against_another(dataset, v, params)
  # vapply keeps the mask logical when there are no variables at all,
  # where sapply would return list() and negating it would fail
  drop_null <- function(plots) plots[!vapply(plots, is.null, logical(1))]

  vd <- lapply(dataset$variables, get_vd)
  vaa <- lapply(dataset$variables, get_vaa)

  # Join results and filter out null
  c(drop_null(vd), drop_null(vaa))
}

#' Internal function for returning message as plot data
#'
#' Modifies existing plot data in Arena's format so that a message
#' is shown instead of a chart: the plot component is set to "Message"
#' and the data is replaced by the message text and its type.
#' @param output existing plot data to be overwritten
#' @param type type of message "info" or "error"
#' @param msg message to be displayed
#' @return Plot data in Arena's format
get_message_output <- function(output, type, msg) {
  message_data <- list(message = msg, type = type)
  output$plotComponent <- "Message"
  output$data <- message_data
  output
}

#' Internal function for variable against another plot
#'
#' Compares the primary variable with every other column of the dataset
#' (including target). Depending on the types of both columns the result
#' is a count table (two factors), a scatter plot (two numeric columns) or
#' box plots of the numeric column grouped by levels of the factor.
#' Logical columns are treated as factors. Columns identical to the primary
#' variable and pairs of unsupported types are skipped. Missing values are
#' ignored when box plot statistics are computed.
#'
#' @param dataset List with following elements
#' \itemize{
#'   \item{dataset}{ Data frame}
#'   \item{target}{ Name of one column from data frame that is used as target variable}
#'   \item{label}{ Label for dataset to be displayed in Arena}
#'   \item{variables}{ vector of column names from data frame without target}
#' }
#' @param variable Name of primary variable
#' @param params Params from arena object; \code{params$vaa_points_number}
#'   limits the number of points in scatter plots
#' @return Plot data in Arena's format
#' @importFrom stats quantile
get_variable_against_another <- function(dataset, variable, params) {
  output <- NULL
  tryCatch({
    output <- list(
      plotType = "VariableAgainstAnother",
      plotCategory = "EDA",
      name = "Variable Against Another",
      plotComponent = "VariableAgainstAnother",
      params = list(dataset = dataset$label, variable = variable)
    )
    frame <- dataset$dataset
    if (!(variable %in% colnames(frame))) {
      stop("Column '", variable, "' is missing in the dataset")
    }
    # primary variable; `[[` returns the column itself also for tibbles
    first <- frame[[variable]]
    if (is.logical(first)) first <- as.factor(first)

    # Box plot statistics of `values` for each level of factor `groups`
    box_stats <- function(values, groups) {
      lapply(levels(groups), function(lev) {
        # which() skips rows with missing group, then drop missing values;
        # quantile() fails when NA is present
        filtered <- values[which(groups == lev)]
        filtered <- filtered[!is.na(filtered)]
        quantiles <- quantile(filtered, probs = c(0.25, 0.5, 0.75))
        iqr <- quantiles[3] - quantiles[1]
        # lower, upper fences
        lf <- max(quantiles[1] - (1.5 * iqr), min(filtered))
        uf <- min(quantiles[3] + (1.5 * iqr), max(filtered))
        list(
          q1 = quantiles[1],
          q3 = quantiles[3],
          mean = mean(filtered),
          median = quantiles[2],
          lf = lf,
          uf = uf,
          outliers = filtered[filtered > uf | filtered < lf]
        )
      })
    }

    # for each other variable (including target) compute plot;
    # lapply always returns a list, independent of the results' shapes
    vars <- lapply(frame, function(secondary) {
      if (is.logical(secondary)) secondary <- as.factor(secondary)
      if (identical(first, secondary)) return(NULL)
      if (is.factor(secondary) && is.factor(first)) {
        # count table for two factors
        tab <- table(first, secondary)
        list(
          type = "table",
          counts = lapply(seq_len(nrow(tab)), function(i) unname(tab[i, ])),
          first = dimnames(tab)[[1]],
          secondary = dimnames(tab)[[2]]
        )
      } else if (is.numeric(secondary) && is.numeric(first)) {
        # scatter plot for two numeric vectors, on a random subset of points
        points_number <- min(params$vaa_points_number, length(first))
        points <- sample.int(length(first), size = points_number)
        list(
          type = "scatter",
          first = first[points],
          secondary = secondary[points]
        )
      } else if (is.factor(first) && is.numeric(secondary)) {
        # boxplots of secondary variable
        list(
          type = "boxplots",
          first = levels(first),
          secondary = box_stats(secondary, first),
          numerical = "secondary"
        )
      } else if (is.numeric(first) && is.factor(secondary)) {
        # boxplots of primary variable
        list(
          type = "boxplots",
          secondary = levels(secondary),
          first = box_stats(first, secondary),
          numerical = "first"
        )
      } else {
        NULL
      }
    })
    names(vars) <- colnames(frame)
    output$data <- as.list(vars[!vapply(vars, is.null, logical(1))])
  }, error = function(e) {
    stop("Failed to calculate variable against another\n", e)
  })
  output
}

#' Internal function for variable distribution
#'
#' For a factor column returns counts and densities of its levels. For
#' a numeric column returns one histogram for every number of bins listed
#' in \code{params$vd_bins}; missing and infinite values are ignored. For
#' other column types a message is returned instead of a plot.
#'
#' @param dataset List with following elements
#' \itemize{
#'   \item{dataset}{ Data frame}
#'   \item{target}{ Name of one column from data frame that is used as target variable}
#'   \item{label}{ Label for dataset to be displayed in Arena}
#'   \item{variables}{ vector of column names from data frame without target}
#' }
#' @param variable Name of variable
#' @param params Params from arena object
#' @return Plot data in Arena's format
get_variable_distribution <- function(dataset, variable, params) {
  output <- NULL
  tryCatch({
    output <- list(
      plotType = "VariableDistribution",
      plotCategory = "EDA",
      name = "Variable Distribution",
      params = list(dataset = dataset$label, variable = variable)
    )
    frame <- dataset$dataset
    if (!(variable %in% colnames(frame))) {
      stop("Column '", variable, "' is missing in the dataset")
    }
    # `[[` returns the column itself also for tibbles
    column <- frame[[variable]]
    # bars of count/density if variable is categorical
    if (is.factor(column)) {
      counts <- as.numeric(table(column))
      output$data <- list(
        names = levels(column),
        count = counts,
        density = counts / sum(counts)
      )
      output$plotComponent <- "DistributionCounts"
    # histogram if variable is numerical
    } else if (is.numeric(column)) {
      # NA/NaN/Inf would break min(), max() and seq() below
      finite <- column[is.finite(column)]
      if (length(finite) == 0) {
        return(get_message_output(
          output,
          "info",
          "Variable has no finite values to calculate distribution"
        ))
      }
      lower <- min(finite)
      upper <- max(finite)
      if (lower == upper) {
        # constant column: widen the range so that breaks are increasing
        lower <- lower - 0.5
        upper <- upper + 0.5
      }
      bins <- params$vd_bins
      # get histogram for different bins number
      output$data <- lapply(bins, function(nbins) {
        breaks <- seq(from = lower, to = upper, length.out = nbins + 1)
        hist_data <- graphics::hist(finite, plot = FALSE, breaks = breaks)
        list(
          breaks = breaks,
          mids = hist_data$mids,
          density = hist_data$density,
          counts = hist_data$counts
        )
      })
      names(output$data) <- bins
      output$plotComponent <- "DistributionHistogram"
    } else {
      return(get_message_output(output, "info", "Distribution is available only for numerical and categorical columns"))
    }
  }, error = function(e) {
    stop("Failed to calculate variable distribution\n", e)
  })
  output
}

#' Internal function for calculating fairness
#'
#' For every cutoff from \code{params$fairness_cutoffs} group fairness
#' metrics are calculated with the \code{fairmodels} package, using levels
#' of \code{variable} as subgroups. A message is returned instead of a plot
#' when the model is not a classifier or the variable is not a factor.
#'
#' @param explainer Explainer created using \code{DALEX::explain}
#' @param variable Name of variable
#' @param params Params from arena object
#' @return Plot data in Arena's format
get_fairness <- function(explainer, variable, params) {
  output <- NULL
  tryCatch({
    output <- list(
      plotType = "Fairness",
      plotCategory = "Dataset Level",
      plotComponent = "Fairness",
      name = "Fairness",
      params = list(model = explainer$label, variable = variable)
    )
    if (explainer$model_info$type != "classification") {
      return(get_message_output(output, "info", "Fairness plot is only available for classificators"))
    }
    protected <- explainer$data[, variable]
    if (!is.factor(protected)) {
      return(get_message_output(output, "info", "Select categorical variable to check fairness"))
    }
    subgroups <- levels(protected)

    # Long data frame (value, subgroup, metric, cutoff) for a single cutoff
    metrics_for_cutoff <- function(cutoff) {
      # the same cutoff for every subgroup, as a named list
      cutoff_list <- as.list(rep(cutoff, length(subgroups)))
      names(cutoff_list) <- subgroups
      stopifnot(is.logical(explainer$y) || is.numeric(explainer$y))
      # confusion matrices for each subgroup
      matrices <- fairmodels::group_matrices(
        protected = protected,
        probs = explainer$y_hat,
        preds = as.numeric(explainer$y),
        cutoff = cutoff_list
      )
      # group metric matrix: metrics in rows, subgroups in columns
      metric_matrix <- fairmodels::calculate_group_fairness_metrics(matrices)
      data.frame(
        value = as.vector(metric_matrix),
        subgroup = rep(colnames(metric_matrix), each = nrow(metric_matrix)),
        metric = rep(rownames(metric_matrix), ncol(metric_matrix)),
        cutoff = cutoff
      )
    }

    all_cutoffs <- do.call(
      "rbind",
      lapply(params$fairness_cutoffs, metrics_for_cutoff)
    )
    # subgroup -> cutoff -> named list of metric values
    by_subgroup <- split(all_cutoffs, all_cutoffs$subgroup)
    nested <- lapply(by_subgroup, function(subgroup_rows) {
      by_cutoff <- split(subgroup_rows, subgroup_rows$cutoff)
      lapply(by_cutoff, function(rows) {
        structure(as.list(rows$value), names = rows$metric)
      })
    })
    output$data <- list(subgroups = nested)
  }, error = function(e) {
    stop("Failed to calculate fairness\n", e)
  })
  output
}


#' Internal function for calculating subset performance
#'
#' Calculates a set of scores (chosen by the model type) on subsets of data
#' using \code{calculate_subsets_performance}. For every score the output
#' contains the values per variable and subset label, and the score of the
#' whole explainer as a base.
#'
#' @param explainer Explainer created using \code{DALEX::explain}
#' @param params Params from arena object
#' @return Plot data in Arena's format
get_subsets_performance <- function(explainer, params) {
  output <- NULL
  tryCatch({
    # score functions and their names depend on model type
    model_type <- explainer$model_info$type
    score_functions <- switch(
      model_type,
      regression = list(
        RMSE = auditor::score_rmse,
        MSE = auditor::score_mse,
        R2 = auditor::score_r2,
        MAE = auditor::score_mae
      ),
      classification = list(
        Accuracy = auditor::score_acc,
        Recall = auditor::score_recall,
        Precision = auditor::score_precision,
        Specificity = auditor::score_specificity,
        AUC = auditor::score_auc,
        F1 = auditor::score_f1
      ),
      stop(model_type, " is not recognized as task name")
    )
    score_names <- names(score_functions)
    scores <- calculate_subsets_performance(
      explainer,
      score_functions = score_functions,
      nbins = params$fm_nbins,
      cutoff = params$fm_cutoff,
      factor_conversion_threshold = params$fm_factor_threshold
    )
    by_variable <- split(scores, scores$Variable)

    # values of one score for all subsets plus score of the whole model
    summarise_score <- function(score_name) {
      values <- lapply(by_variable, function(rows) {
        structure(as.list(rows[, score_name]), names = rows$Label)
      })
      list(
        scoreValues = values,
        base = score_functions[[score_name]](explainer)$score
      )
    }
    output_data <- lapply(score_names, summarise_score)
    names(output_data) <- score_names

    output <- list(
      plotComponent = "SubsetsPerformance",
      plotType = "SubsetsPerformance",
      plotCategory = "Model Performance",
      name = "Subsets Performance",
      params = list(model = explainer$label),
      data = output_data
    )
  }, error = function(e) {
    stop("Failed to calculate subsets performance\n", e)
  })
  output
}

#' Internal function for calculating funnel measure
#'
#' Calculates one loss function (MSE for regression, 1 - AUC for
#' classification) on subsets of data using
#' \code{calculate_subsets_performance} and groups the values by variable
#' and subset label.
#'
#' @param explainer Explainer created using \code{DALEX::explain}
#' @param params Params from arena object
#' @return Plot data in Arena's format
get_funnel_measure <- function(explainer, params) {
  output <- NULL
  tryCatch({
    # loss function and its name depend on model type
    model_type <- explainer$model_info$type
    score_functions <- switch(
      model_type,
      regression = list(MSE = auditor::score_mse),
      classification = list("ONE_MINUS_AUC" = auditor::score_one_minus_auc),
      stop(model_type, " is not recognized as task name")
    )
    loss_name <- names(score_functions)
    scores <- calculate_subsets_performance(
      explainer,
      score_functions = score_functions,
      nbins = params$fm_nbins,
      cutoff = params$fm_cutoff,
      factor_conversion_threshold = params$fm_factor_threshold
    )
    # variable -> named list (by subset label) of loss values
    loss_values <- lapply(split(scores, scores$Variable), function(rows) {
      structure(as.list(rows[, loss_name]), names = rows$Label)
    })
    output <- list(
      plotComponent = "FunnelMeasure",
      plotType = "FunnelMeasure",
      plotCategory = "Model Performance",
      name = "Funnel Plot",
      params = list(model = explainer$label),
      data = list(
        lossValues = loss_values,
        # display name of loss function, i.e. "1 - AUC"
        lossFunction = gsub("ONE_MINUS_", "1 - ", loss_name)
      )
    )
  }, error = function(e) {
    stop("Failed to calculate funnel measure\n", e)
  })
  output
}

#' Internal function for calculating model performance metrics
#'
#' Returns the \code{measures} element of
#' \code{DALEX::model_performance} result as plot data.
#'
#' @param explainer Explainer created using \code{DALEX::explain}
#' @param params Params from arena object
#' @return Plot data in Arena's format
get_metrics <- function(explainer, params) {
  tryCatch({
    measures <- DALEX::model_performance(explainer)$measures
    list(
      plotComponent = "Metrics",
      plotType = "Metrics",
      plotCategory = "Model Performance",
      name = "Metrics",
      params = list(model = explainer$label),
      data = measures
    )
  }, error = function(e) {
    stop("Failed to calculate Metrics\n", e)
  })
}

#' Internal function for calculating receiver operating curve
#'
#' Uses \code{auditor::model_evaluation} to get points of the curve. When
#' there are more points than \code{params$roc_grid_points}, a random subset
#' of distinct points of that size is taken. For models other than
#' classifiers a message is returned instead of a plot.
#'
#' @param explainer Explainer created using \code{DALEX::explain}
#' @param params Params from arena object
#' @return Plot data in Arena's format
get_roc <- function(explainer, params) {
  output <- NULL
  tryCatch({
    output <- list(
      plotComponent = "ROC",
      plotType = "ROC",
      plotCategory = "Model Performance",
      name = "Receiver Operating Characterstic",
      params = list(model = explainer$label)
    )
    if (explainer$model_info$type != 'classification') {
      return(get_message_output(output, "info", "ROC plot is only available for classificators"))
    }
    eva <- auditor::model_evaluation(explainer)
    n_rows <- nrow(eva)
    if (n_rows > params$roc_grid_points) {
      # take random points; sample.int draws without replacement, so every
      # row is equally likely and no row is picked twice
      points <- sample.int(n_rows, size = params$roc_grid_points)
      eva <- eva[points, ]
    }
    eva <- eva[order(eva$`_fpr_`, decreasing = TRUE), ]
    output$data <- list(
      specifity = 1 - eva$`_fpr_`,
      sensivity = eva$`_tpr_`,
      cutoff = eva$`_cutoffs_`
    )
  }, error = function(e) {
    stop("Failed to calculate ROC\n", e)
  })
  output
}

#' Internal function for calculating regression error characteristic
#'
#' Uses \code{auditor::model_residual} and the internal
#' \code{make_dataframe} function of \code{auditor} to get points of the
#' curve. When there are more points than \code{params$rec_grid_points},
#' a random subset of distinct points of that size is taken, keeping
#' the original order of rows.
#'
#' @param explainer Explainer created using \code{DALEX::explain}
#' @param params Params from arena object
#' @return Plot data in Arena's format
get_rec <- function(explainer, params) {
  output <- NULL
  tryCatch({
    res <- auditor::model_residual(explainer)
    make_dataframe <- utils::getFromNamespace("make_dataframe", "auditor")
    df <- make_dataframe(res, type = "rec")
    n_rows <- nrow(df)
    if (n_rows > params$rec_grid_points) {
      # take random points; sample.int draws without replacement, so every
      # row is equally likely and no row is picked twice
      points <- sample.int(n_rows, size = params$rec_grid_points)
      df <- df[sort(points), ]
    }

    output <- list(
      plotComponent = "REC",
      plotType = "REC",
      plotCategory = "Model Performance",
      name = "Regression Error Characteristic",
      params = list(
        model = explainer$label
      ),
      data = list(
        tolerance = df$`_rec_x_`,
        quantity = df$`_rec_y_`
      )
    )
  }, error = function(e) {
    stop("Failed to calculate REC\n", e)
  })
  output
}

#' Internal function for calculating Ceteris Paribus
#'
#' Calculates a Ceteris Paribus profile of one variable for one observation
#' using \code{ingredients::ceteris_paribus}. The plot component depends on
#' whether the variable is numeric in the explainer's data.
#'
#' @param explainer Explainer created using \code{DALEX::explain}
#' @param observation One row data frame observation
#' @param variable Name of variable
#' @param params Params from arena object
#' @return Plot data in Arena's format
get_ceteris_paribus <- function(explainer, observation, variable, params) {
  output <- NULL
  tryCatch({
    cp <- ingredients::ceteris_paribus(
      explainer,
      observation,
      variables = variable,
      grid_points = params$cp_grid_points
    )
    cp <- cp[cp$`_vname_` == variable, ]
    is_num <- is.numeric(explainer$data[, variable])
    # full column name; `_yhat` would only work through partial matching
    yhat <- cp$`_yhat_`
    output <- list(
      plotComponent = if (is_num) {
        "NumericalCeterisParibus"
      } else {
        "CategoricalCeterisParibus"
      },
      plotType = "CeterisParibus",
      plotCategory = "Observation Level",
      name = "Ceteris Paribus",
      params = list(
        model = explainer$label,
        observation = rownames(observation),
        variable = variable
      ),
      data = list(
        x = cp[, variable],
        y = yhat,
        min = min(yhat),
        max = max(yhat),
        variable = variable,
        observation = as.list(attr(cp, "observations"))
      )
    )
  }, error = function(e) {
    stop("Failed to calculate ceteris paribus\n", e)
  })
  output
}

#' Internal function for calculating Break Down
#'
#' Calculates variable attributions for one observation using
#' \code{iBreakDown::local_attributions}. The first row of the result is
#' reported as intercept, the last cumulative value as prediction and the
#' rows between them as contributions of variables.
#'
#' @param explainer Explainer created using \code{DALEX::explain}
#' @param observation One row data frame observation
#' @param params Params from arena object
#' @return Plot data in Arena's format
get_break_down <- function(explainer, observation, params) {
  output <- NULL
  tryCatch({
    bd <- iBreakDown::local_attributions(explainer, observation)
    # remove rows: intercept and prediction
    # (seq_len gives an empty index when there are no rows between them,
    # where 2:(n - 1) would count downwards)
    n_rows <- length(bd$variable)
    vars_index <- seq_len(max(n_rows - 2L, 0L)) + 1L
    output <- list(
      plotComponent = "Breakdown",
      plotCategory = "Observation Level",
      plotType = "Breakdown",
      name = "Break Down",
      params = list(
        model = explainer$label,
        observation = rownames(observation)
      ),
      data = list(
        variables = bd$variable_name[vars_index],
        intercept = bd$contribution[1],
        prediction = bd$cumulative[length(bd$cumulative)],
        variables_value = bd$variable_value[vars_index],
        contribution = bd$contribution[vars_index]
      )
    )
  }, error = function(e) {
    stop("Failed to calculate break down\n", e)
  })
  output
}

#' Internal function for calculating Accumulated Dependence
#'
#' Calculates an accumulated dependence profile of one variable using
#' \code{ingredients::accumulated_dependence}. The variable is treated as
#' numerical when its column in the explainer's data is numeric and as
#' categorical otherwise.
#'
#' @param explainer Explainer created using \code{DALEX::explain}
#' @param variable Name of variable
#' @param params Params from arena object
#' @return Plot data in Arena's format
get_accumulated_dependence <- function(explainer, variable, params) {
  output <- NULL
  tryCatch({
    is_num <- is.numeric(explainer$data[, variable])
    if (is_num) {
      variable_type <- "numerical"
      component <- "LinearDependence"
    } else {
      variable_type <- "categorical"
      component <- "CategoricalDependence"
    }
    profile <- ingredients::accumulated_dependence(
      explainer,
      variables = variable,
      variable_type = variable_type,
      grid_points = params$ad_grid_points,
      N = params$ad_N
    )
    # keep only rows of requested variable
    profile <- profile[profile$`_vname_` == variable, ]
    output <- list(
      data = list(
        x = profile$`_x_`,
        y = profile$`_yhat_`,
        variable = variable,
        base = 0
      ),
      plotType = "AccumulatedDependence",
      plotCategory = "Dataset Level",
      plotComponent = component,
      name = "Accumulated Dependence",
      params = list(model = explainer$label, variable = variable)
    )
  }, error = function(e) {
    stop("Failed to calculate accumulated dependence\n", e)
  })
  output
}

#' Internal function for calculating Partial Dependence
#'
#' Calculates a partial dependence profile of one variable using
#' \code{ingredients::partial_dependence}. The variable is treated as
#' numerical when its column in the explainer's data is numeric and as
#' categorical otherwise.
#'
#' @param explainer Explainer created using \code{DALEX::explain}
#' @param variable Name of variable
#' @param params Params from arena object
#' @return Plot data in Arena's format
get_partial_dependence <- function(explainer, variable, params) {
  output <- NULL
  tryCatch({
    is_num <- is.numeric(explainer$data[, variable])
    if (is_num) {
      variable_type <- "numerical"
      component <- "LinearDependence"
    } else {
      variable_type <- "categorical"
      component <- "CategoricalDependence"
    }
    profile <- ingredients::partial_dependence(
      explainer,
      variables = variable,
      variable_type = variable_type,
      grid_points = params$pd_grid_points,
      N = params$pd_N
    )
    # keep only rows of requested variable
    profile <- profile[profile$`_vname_` == variable, ]
    output <- list(
      data = list(
        x = profile$`_x_`,
        y = profile$`_yhat_`,
        variable = variable,
        base = attr(profile, "mean_prediction")
      ),
      plotType = "PartialDependence",
      plotCategory = "Dataset Level",
      plotComponent = component,
      name = "Partial Dependence",
      params = list(model = explainer$label, variable = variable)
    )
  }, error = function(e) {
    stop("Failed to calculate partial dependence\n", e)
  })
  output
}

#' Internal function for calculating feature importance
#'
#' Calculates permutation importance with
#' \code{ingredients::feature_importance} for all columns of the explainer's
#' data that are not identical to the target. For each variable the dropout
#' loss from rows with \code{permutation == 0} is reported together with
#' minimum, quartiles and maximum of dropout loss over all rows.
#'
#' @param explainer Explainer created using \code{DALEX::explain}
#' @param params Params from arena object
#' @return Plot data in Arena's format
#' @importFrom stats quantile
get_feature_importance <- function(explainer, params) {
  output <- NULL
  tryCatch({
    is_target <- sapply(explainer$data, function(v) identical(v, explainer$y))
    vars <- names(is_target[!is_target])
    fi <- ingredients::feature_importance(
      explainer,
      variables = vars,
      B = params$fi_B,
      n_sample = params$fi_n_sample
    )

    # summary of dropout loss for each variable
    loss_by_variable <- function(fun, ...) {
      tapply(fi$dropout_loss, fi$variable, fun, ..., na.rm = TRUE)
    }
    stats <- data.frame(
      min = loss_by_variable(min),
      q1 = loss_by_variable(quantile, 0.25),
      q3 = loss_by_variable(quantile, 0.75),
      max = loss_by_variable(max)
    )

    # cbind names the new column after its expression: "rownames(stats)"
    perm0 <- merge(
      fi[fi$permutation == 0, ],
      cbind(rownames(stats), stats),
      by.x = "variable",
      by.y = "rownames(stats)"
    )
    # rm permutation column
    perm0 <- subset(perm0, select = setdiff(colnames(perm0), "permutation"))
    # leave only rows for variables, not for full model and baseline
    is_reference <- perm0$variable %in% c("_baseline_", "_full_model_")
    vars_only <- perm0[!is_reference, ]
    # the most important variables first
    ranking <- order(vars_only$dropout_loss, decreasing = TRUE)
    vars_only <- vars_only[ranking, ]

    output <- list(
      data = list(
        base = perm0$dropout_loss[perm0$variable == "_full_model_"],
        variables = as.character(vars_only$variable),
        dropout_loss = vars_only$dropout_loss,
        min = vars_only$min,
        max = vars_only$max,
        q1 = vars_only$q1,
        q3 = vars_only$q3
      ),
      plotType = "FeatureImportance",
      plotCategory = "Dataset Level",
      plotComponent = "FeatureImportance",
      name = "Variable Importance",
      params = list(model = explainer$label)
    )
  }, error = function(e) {
    stop("Failed to calculate feature importance\n", e)
  })
  output
}

#' Internal function for calculating Shapley Values
#'
#' Calculates Shapley Values for one observation using
#' \code{iBreakDown::shap}. Rows with \code{B == 0} provide the reported
#' contribution of each variable; the remaining rows are used to compute
#' minimum, quartiles and maximum of contributions.
#'
#' @param explainer Explainer created using \code{DALEX::explain}
#' @param observation One row data frame observation to calculate Shapley Values
#' @param params Params from arena object
#' @return Plot data in Arena's format
#' @importFrom stats quantile
get_shap_values <- function(explainer, observation, params) {
  output <- NULL
  # evaluate the argument before calculations start
  force(params)
  tryCatch({
    sp <- iBreakDown::shap(
      explainer,
      observation,
      B = params$shap_B
    )
    is_summary_row <- sp$B == 0
    perm0 <- sp[is_summary_row, ]
    sp <- sp[!is_summary_row, ]

    # summary of contributions for each variable
    contribution_by_variable <- function(fun, ...) {
      tapply(sp$contribution, sp$variable, fun, ..., na.rm = TRUE)
    }
    stats <- data.frame(
      min = contribution_by_variable(min),
      q1 = contribution_by_variable(quantile, 0.25),
      q3 = contribution_by_variable(quantile, 0.75),
      max = contribution_by_variable(max)
    )

    # cbind names the new column after its expression: "rownames(stats)"
    perm0 <- merge(
      perm0,
      cbind(rownames(stats), stats),
      by.x = "variable",
      by.y = "rownames(stats)"
    )
    # rm permutation column
    perm0 <- subset(perm0, select = setdiff(colnames(perm0), "B"))
    # the largest absolute contributions first
    ranking <- order(abs(perm0$contribution), decreasing = TRUE)
    perm0 <- perm0[ranking, ]

    output <- list(
      data = list(
        intercept = attr(sp, "intercept"),
        variables = as.character(perm0$variable_name),
        variables_value = perm0$variable_value,
        mean = perm0$contribution,
        min = perm0$min,
        max = perm0$max,
        q1 = perm0$q1,
        q3 = perm0$q3
      ),
      plotType = "SHAPValues",
      plotCategory = "Observation Level",
      plotComponent = "SHAPValues",
      name = "Shapley Values",
      params = list(
        model = explainer$label,
        observation = rownames(observation)
      )
    )
  }, error = function(e) {
    stop("Failed to calculate Shapley Values\n", e)
  })
  output
}

Try the arenar package in your browser

Any scripts or data that you put into this service are public.

arenar documentation built on Oct. 23, 2020, 8:03 p.m.