R/local_variable_importance.R

Defines functions local_variable_importance

Documented in local_variable_importance

#' Local Variable Importance measure based on Ceteris Paribus profiles.
#'
#' This function calculates a local variable importance measure in eight variants. The eight variants are obtained by combining the three logical parameters \code{absolute_deviation}, \code{point} and \code{density}.
#' The measure is calculated only for numerical variables; other profiled variables are skipped with a message.
#' Deviations are taken against the prediction stored in the \code{observations} attribute of \code{profiles}, so the profiles are expected to describe a single observation.
#'
#' @param profiles \code{data.frame} generated by \code{DALEX::predict_profile()}, \code{DALEX::individual_profile()} or \code{ingredients::ceteris_paribus()}
#' @param data \code{data.frame} with raw data to model
#' @param absolute_deviation logical parameter, if \code{absolute_deviation = TRUE} then the measure is calculated from absolute deviations, else it is calculated as the square root of the aggregated squared deviations
#' @param point logical parameter, if \code{point = TRUE} then the measure is calculated as a distance from f(x), the prediction for the observation, else the measure is calculated as a distance from the average of the variable's profile
#' @param density logical parameter, if \code{density = TRUE} then deviations are multiplied by weights based on the density of the variable in \code{data} and summed, else they are averaged without weights
#' @param grid_points maximum number of points for profile calculations, the default value is 101, the same as in \code{ingredients::ceteris_paribus()}; if you used a different value to compute \code{profiles}, set the same value here. It is used only when \code{density = TRUE}.
#' @return A \code{data.frame} of the class \code{local_importance}.
#' It's a \code{data.frame} with the calculated local variable importance measure, one row per variable, with columns \code{variable_name}, \code{measure}, \code{_label_model_} and \code{_label_method_}.
#' @examples
#'
#'
#' library("DALEX")
#' data(apartments)
#'
#' library("randomForest")
#' apartments_rf_model <- randomForest(m2.price ~ construction.year + surface +
#'                                     floor + no.rooms, data = apartments)
#'
#' explainer_rf <- explain(apartments_rf_model, data = apartmentsTest[,2:5],
#'                         y = apartmentsTest$m2.price)
#'
#' new_apartment <- data.frame(construction.year = 1998, surface = 88, floor = 2L, no.rooms = 3)
#'
#' profiles <- predict_profile(explainer_rf, new_apartment)
#'
#'
#' library("vivo")
#' local_variable_importance(profiles, apartments[,2:5],
#'                           absolute_deviation = TRUE, point = TRUE, density = TRUE)
#'
#' local_variable_importance(profiles, apartments[,2:5],
#'                           absolute_deviation = TRUE, point = TRUE, density = FALSE)
#'
#' local_variable_importance(profiles, apartments[,2:5],
#'                           absolute_deviation = TRUE, point = FALSE, density = TRUE)
#'
#'
#'
#' @export
#'


local_variable_importance <- function(profiles,
                                      data,
                                      absolute_deviation = TRUE,
                                      point = TRUE,
                                      density = TRUE,
                                      grid_points = 101) {

  # Validate inputs ----
  if (!inherits(profiles, c("ceteris_paribus_explainer", "predict_profile")))
    stop("The local_variable_importance() function requires an object created with predict_profile() or ceteris_paribus() function.")
  if (!inherits(data, "data.frame"))
    stop("The local_variable_importance() function requires a data.frame.")

  # Variable of each profile row. as.character() guards against a factor
  # `_vname_` column, whose integer codes would otherwise be used as indices.
  profile_vname <- as.character(profiles[["_vname_"]])
  profiled <- unique(profile_vname)

  # `[[` extracts each column directly, so this also works when only one
  # variable was profiled (`profiles[, vars]` would drop to a bare vector and
  # lose the variable names).
  is_numeric <- vapply(profiled, function(x) is.numeric(profiles[[x]]), logical(1))

  if (!all(is_numeric))
    message("The measure of local variable importance is calculated only for numerical variables.")

  vnames <- profiled[is_numeric]

  if (density) {
    if (!any(colnames(data) %in% profiled))
      stop("The Ceteris Paribus profiles for variables in data are missing or data include target variable.")
    # Keep numerical profiled variables that are present in `data`, in the
    # column order of `data`.
    vnames <- colnames(data)[colnames(data) %in% vnames]
  }

  # Fail with a clear message instead of a data.frame() row-count error.
  if (length(vnames) == 0)
    stop("There are no numerical variables with Ceteris Paribus profiles to calculate the measure for.")

  obs <- attr(profiles, "observations")
  yhat <- profiles[["_yhat_"]]

  if (density) {
    # The variable split is only needed for the density weights.
    variable_split <- vivo::calculate_variable_split(data, variables = colnames(data), grid_points = grid_points)
    # drop = FALSE keeps a one-column data.frame when a single variable is used.
    weight <- vivo::calculate_weight(profiles, data[, vnames, drop = FALSE], variable_split = variable_split)
  }

  # One value per variable; the eight variants differ only in the reference
  # level (point), the deviation transform (absolute_deviation) and the
  # aggregation (density).
  result <- vapply(vnames, function(v) {
    yhat_v <- yhat[profile_vname == v]
    # Reference: f(x) at the observation, or the mean of this variable's profile.
    # NOTE(review): obs[["_yhat_"]] is recycled against the whole profile, so
    # this assumes `profiles` was computed for a single observation.
    reference <- if (point) obs[["_yhat_"]] else mean(yhat_v)
    deviation <- yhat_v - reference
    # Density-weighted sum or plain mean of the transformed deviations.
    summarise_dev <- function(values) {
      if (density) sum(weight[[v]] * values) else mean(values)
    }
    if (absolute_deviation) {
      summarise_dev(abs(deviation))
    } else {
      sqrt(summarise_dev(deviation^2))
    }
  }, numeric(1), USE.NAMES = FALSE)

  lvivo <- data.frame(variable_name = vnames,
                      measure = result,
                      `_label_model_` = obs[["_label_"]],
                      `_label_method_` = paste0('absolute_deviation = ', absolute_deviation, ", point = ", point, ", density = ", density)
                      )
  # data.frame() mangles the backtick-quoted names, so set them explicitly.
  colnames(lvivo) <- c("variable_name", "measure", "_label_model_", "_label_method_")
  attr(lvivo, "observations") <- obs
  class(lvivo) <- c("local_importance", "data.frame")
  lvivo
}
ModelOriented/vivo documentation built on Sept. 29, 2020, 10:53 p.m.