R/global_variable_importance.R

Defines functions global_variable_importance

Documented in global_variable_importance

#' Global Variable Importance measure based on Partial Dependence profiles.
#'
#' This function calculates a global importance measure for every numerical
#' variable: the mean absolute deviation of the variable's aggregated
#' (e.g. partial dependence) profile from that profile's own mean prediction.
#' A perfectly flat profile therefore scores zero.
#'
#' @param profiles an object generated by \code{DALEX::model_profile()} or
#'   \code{DALEX::variable_profile()}. It must contain both the
#'   \code{cp_profiles} (ceteris paribus) and \code{agr_profiles} (aggregated)
#'   components: the former is used to decide which variables are numerical,
#'   the latter to compute the measure.
#' @return A \code{data.frame} of the class \code{global_importance} with the
#'   columns \code{variable_name}, \code{measure} and \code{_label_model_},
#'   one row per numerical variable that has an aggregated profile.
#' @examples
#'
#'
#' library("DALEX")
#' data(apartments)
#'
#' library("randomForest")
#' apartments_rf_model <- randomForest(m2.price ~ construction.year + surface +
#'                                     floor + no.rooms, data = apartments)
#'
#' explainer_rf <- explain(apartments_rf_model, data = apartmentsTest[,2:5],
#'                         y = apartmentsTest$m2.price)
#'
#' profiles <- model_profile(explainer_rf)
#'
#' library("vivo")
#' global_variable_importance(profiles)
#'
#'
#' @export
#'
global_variable_importance <- function(profiles) {
  # The ceteris paribus profiles keep the original explanatory columns, which
  # is where the variable types are read from; the aggregated profiles hold
  # the predictions the measure is computed on.
  cp_profiles <- NULL
  if ("cp_profiles" %in% names(profiles)) {
    cp_profiles <- profiles[["cp_profiles"]]
  }
  agr_profiles <- profiles
  if ("agr_profiles" %in% names(profiles)) {
    agr_profiles <- profiles[["agr_profiles"]]
  }

  # Without cp_profiles the variable types cannot be determined, so such an
  # input is rejected with the same message as a wrongly-classed one.
  if (!inherits(agr_profiles, c("aggregated_profiles_explainer",
                                "partial_dependence_explainer")) ||
      is.null(cp_profiles)) {
    stop("The global_variable_importance() function requires an object created with model_profile() or variable_profile() function.")
  }

  # as.character() guards against `_vname_` being a factor, whose codes would
  # otherwise be used as column indices.
  cp_vnames <- as.character(unique(cp_profiles[["_vname_"]]))
  is_numeric <- vapply(cp_vnames,
                       function(v) is.numeric(cp_profiles[[v]]),
                       logical(1))

  if (!all(is_numeric))
    message("The measure of global variable importance is calculated only for numerical variables.")

  # Only numeric variables that also have an aggregated profile can be scored;
  # e.g. aggregated profiles computed for categorical variables only would
  # otherwise leave numeric variables without predictions.
  agr_vnames <- as.character(agr_profiles[["_vname_"]])
  vnames <- cp_vnames[is_numeric & cp_vnames %in% agr_vnames]

  yhat <- agr_profiles[["_yhat_"]]
  # Mean absolute deviation of each variable's profile from its own mean.
  # USE.NAMES = FALSE keeps data.frame() from turning names into row names.
  measure <- vapply(vnames, function(v) {
    v_yhat <- yhat[agr_vnames == v]
    mean(abs(v_yhat - mean(v_yhat)))
  }, numeric(1), USE.NAMES = FALSE)

  # NOTE(review): a single model label is assumed; grouped profiles (several
  # `_label_` values) are recycled by data.frame() as in the original code.
  label <- unique(agr_profiles[["_label_"]])
  if (length(vnames) == 0L) {
    # Zero-length label keeps an empty result a valid 0-row data.frame.
    label <- label[0]
  }

  gvivo <- data.frame(variable_name = vnames,
                      measure = measure,
                      `_label_model_` = label,
                      check.names = FALSE)
  class(gvivo) <- c("global_importance", "data.frame")
  gvivo
}
ModelOriented/vivo documentation built on Sept. 29, 2020, 10:53 p.m.