#' Local Variable Importance measure based on Ceteris Paribus profiles.
#'
#' This function calculates a local variable importance measure in eight variants. The eight variants are obtained by combining the possible values of three logical parameters: \code{absolute_deviation}, \code{point} and \code{density}.
#'
#' @param profiles \code{data.frame} generated by \code{DALEX::predict_profile()}, \code{DALEX::individual_profile()} or \code{ingredients::ceteris_paribus()}
#' @param data \code{data.frame} with raw data to model
#' @param absolute_deviation logical parameter, if \code{absolute_deviation = TRUE} then the measure is calculated as an absolute deviation, otherwise it is calculated as the square root of the average of squared deviations
#' @param point logical parameter, if \code{point = TRUE} then measure is calculated as a distance from f(x), else measure is calculated as a distance from average profiles
#' @param density logical parameter, if \code{density = TRUE} then measure is weighted based on the density of variable, else is not weighted
#' @param grid_points maximum number of points for profile calculations, the default value is 101, the same as in \code{ingredients::ceteris_paribus()}; if you used a different value there, you should also change it here
#' @return A \code{data.frame} of the class \code{local_importance}.
#' It's a \code{data.frame} with calculated local variable importance measure.
#' @examples
#'
#'
#' library("DALEX")
#' data(apartments)
#'
#' library("randomForest")
#' apartments_rf_model <- randomForest(m2.price ~ construction.year + surface +
#' floor + no.rooms, data = apartments)
#'
#' explainer_rf <- explain(apartments_rf_model, data = apartmentsTest[,2:5],
#' y = apartmentsTest$m2.price)
#'
#' new_apartment <- data.frame(construction.year = 1998, surface = 88, floor = 2L, no.rooms = 3)
#'
#' profiles <- predict_profile(explainer_rf, new_apartment)
#'
#'
#' library("vivo")
#' local_variable_importance(profiles, apartments[,2:5],
#' absolute_deviation = TRUE, point = TRUE, density = TRUE)
#'
#' local_variable_importance(profiles, apartments[,2:5],
#' absolute_deviation = TRUE, point = TRUE, density = FALSE)
#'
#' local_variable_importance(profiles, apartments[,2:5],
#' absolute_deviation = TRUE, point = FALSE, density = TRUE)
#'
#'
#'
#' @export
#'
local_variable_importance <- function(profiles,
                                      data,
                                      absolute_deviation = TRUE,
                                      point = TRUE,
                                      density = TRUE,
                                      grid_points = 101) {
  # Computes, for every numeric variable with a Ceteris Paribus profile, how
  # far the profile predictions (`_yhat_`) stray from a reference value, and
  # returns one row per variable.

  # Input validation ----
  if (!inherits(profiles, c("ceteris_paribus_explainer", "predict_profile"))) {
    stop("The local_variable_importance() function requires an object created with predict_profile() or ceteris_paribus() function.")
  }
  if (!inherits(data, "data.frame")) {
    stop("The local_variable_importance() function requires a data.frame.")
  }
  obs <- attr(profiles, "observations")
  if (is.null(obs)) {
    stop("The profiles object has no 'observations' attribute.")
  }

  # Variable selection ----
  # Take the names straight from `_vname_` instead of subsetting `profiles`
  # by column: a one-column subset drops to a bare vector and loses its name,
  # which silently discarded the only variable. as.character() keeps this
  # working when `_vname_` is stored as a factor.
  profile_vars <- unique(as.character(profiles[["_vname_"]]))
  is_numeric <- vapply(
    profile_vars,
    function(v) is.numeric(profiles[[v]]),
    logical(1)
  )
  if (!all(is_numeric)) {
    message("The measure of local variable importance is calculated only for numerical variables.")
  }
  vnames <- profile_vars[is_numeric]

  if (density) {
    data_vars <- colnames(data)
    if (!any(data_vars %in% profile_vars)) {
      stop("The Ceteris Paribus profiles for variables in data are missing or data include target variable.")
    }
    # Keep only variables present in `data`, in the column order of `data`.
    vnames <- data_vars[data_vars %in% vnames]
  }
  vnames <- unique(vnames)
  if (length(vnames) == 0) {
    stop("There are no numerical variables with Ceteris Paribus profiles to calculate the measure for.")
  }

  # Density weights ----
  # The splits are only consumed by calculate_weight(), so they are computed
  # only when weighting is requested.
  weight <- NULL
  if (density) {
    variable_split <- vivo::calculate_variable_split(
      data,
      variables = colnames(data),
      grid_points = grid_points
    )
    # drop = FALSE keeps a data.frame even when a single variable is selected.
    # NOTE(review): the result is used below as a per-variable list of weights
    # aligned with the profile rows of that variable - confirm against
    # vivo::calculate_weight().
    weight <- vivo::calculate_weight(
      profiles,
      data[, vnames, drop = FALSE],
      variable_split = variable_split
    )
  }

  # Measure ----
  # NOTE(review): the arithmetic below assumes `profiles` describes a single
  # observation; with several rows in `obs` the reference is recycled.
  obs_yhat <- obs[["_yhat_"]]
  all_yhat <- profiles[["_yhat_"]]
  all_vname <- profiles[["_vname_"]]

  result <- vapply(vnames, function(v) {
    yhat <- all_yhat[all_vname == v]
    # Reference: prediction for the observation itself, or the profile mean.
    reference <- if (point) obs_yhat else mean(yhat)
    deviation <- yhat - reference
    spread <- if (absolute_deviation) abs(deviation) else deviation^2
    # Weighted sum when density weighting is on, plain average otherwise.
    aggregated <- if (density) sum(weight[[v]] * spread) else mean(spread)
    if (absolute_deviation) aggregated else sqrt(aggregated)
  }, numeric(1), USE.NAMES = FALSE)

  # Result assembly ----
  # check.names = FALSE preserves the underscore-prefixed column names.
  lvivo <- data.frame(
    variable_name = vnames,
    measure = result,
    `_label_model_` = obs[["_label_"]],
    `_label_method_` = paste0('absolute_deviation = ', absolute_deviation, ", point = ", point, ", density = ", density),
    check.names = FALSE
  )
  attr(lvivo, "observations") <- obs
  class(lvivo) <- c("local_importance", "data.frame")
  lvivo
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.