R/calculate_weight.R

Defines functions calculate_weight

Documented in calculate_weight

#' Calculate empirical density and weights based on variable split.
#'
#' This function calculates an empirical density of the raw data over the intervals defined by the variable split used for Ceteris Paribus profiles (the observed data range plus the split points). It then assigns a weight to every value generated by \code{DALEX::predict_profile()}, \code{DALEX::individual_profile()} or \code{ingredients::ceteris_paribus()}: the fraction of raw data rows that fall in the same interval as that value.
#'
#' @param profiles \code{data.frame} generated by  \code{DALEX::predict_profile()}, \code{DALEX::individual_profile()} or \code{ingredients::ceteris_paribus()}
#' @param data \code{data.frame} with raw data to model
#' @param variable_split named list of split points per variable, generated by \code{vivo::calculate_variable_split()}
#'
#' @return A named list with one numeric vector per column of \code{data}. Each vector holds one weight per row of \code{profiles} describing that variable, in the same order as those rows. A weight is the share of \code{data} rows falling in the same interval; profile values outside every interval get \code{NA}.
#'
#' @examples
#'
#' library("DALEX", warn.conflicts = FALSE, quietly = TRUE)
#' data(apartments)
#'
#' split <- vivo::calculate_variable_split(apartments,
#'                         variables = colnames(apartments),
#'                         grid_points = 101)
#'
#' library("randomForest", warn.conflicts = FALSE, quietly = TRUE)
#' apartments_rf_model <- randomForest(m2.price ~ construction.year + surface +
#'                                     floor + no.rooms, data = apartments)
#'
#' explainer_rf <- explain(apartments_rf_model, data = apartmentsTest[,2:5],
#'                         y = apartmentsTest$m2.price)
#'
#' new_apartment <- data.frame(construction.year = 1998, surface = 88, floor = 2L, no.rooms = 3)
#'
#' profiles <- predict_profile(explainer_rf, new_apartment)
#'
#' library("vivo")
#' calculate_weight(profiles, data = apartments[, 2:5], variable_split = split)
#'
#'
#' @export
#'
calculate_weight <- function(profiles,
                             data,
                             variable_split){
  if (!inherits(profiles, c("ceteris_paribus_explainer", "predict_profile")))
    stop("The calculate_weight() function requires an object created with predict_profile() or ceteris_paribus() function.")
  if (!inherits(variable_split, "list"))
    stop("The calculate_weight() function requires an object created with calculate_variable_split() function.")
  if (!inherits(data, "data.frame"))
    stop("The calculate_weight() function requires a data.frame.")

  # Interval breaks for one variable: the observed data range plus the split
  # points. cut() sorts the breaks itself; na.rm keeps the range endpoints
  # when the raw data contains missing values.
  variable_breaks <- function(values, split_points) {
    unique(c(min(values, na.rm = TRUE),
             split_points,
             max(values, na.rm = TRUE)))
  }

  variables <- unique(colnames(data))
  weight <- lapply(variables, function(variable) {
    # [[ ]] yields a plain vector for both data.frames and tibbles.
    data_values <- data[[variable]]
    breaks <- variable_breaks(data_values, variable_split[[variable]])
    # Share of all data rows in each interval. table() of a factor keeps
    # empty levels, so element k corresponds to level k of the cut factor.
    data_bins <- cut(data_values, breaks, include.lowest = TRUE)
    interval_share <- as.vector(table(data_bins)) / nrow(data)
    profile_values <- profiles[[variable]][profiles$`_vname_` == variable]
    profile_bins <- cut(profile_values, breaks, include.lowest = TRUE)
    # Look weights up by level index rather than via merge(): merge() does not
    # preserve the row order of the profiles, which misaligns weights when the
    # profile covers several observations or has values outside the breaks.
    # Values outside every interval have an NA bin and thus an NA weight.
    interval_share[as.integer(profile_bins)]
  })
  names(weight) <- variables
  weight
}
ModelOriented/vivo documentation built on Sept. 29, 2020, 10:53 p.m.