# R/explore.R
#
# Defines functions: sample_locally, give_predictions, add_predictions
#
# Documented in: add_predictions, sample_locally

#' Generate dataset for local exploration.
#'
#' @param data Data frame from which new dataset will be simulated.
#' @param explained_instance One row data frame with the same variables
#'        as in data argument. Local exploration will be performed around this observation.
#' @param explained_var Name of a column with the variable to be predicted.
#'        It must be one of the column names of data, otherwise an error is raised.
#' @param size Number of observations in a simulated dataset.
#' @param method If "live", new observations will be created by changing one value
#'        per observation. If "permute", new observation will be created by permuting  all
#'        columns of data. If "normal", numerical features will be sampled from multivariate
#'        normal distribution specified by ... arguments mu and Sigma.
#'        Any value other than "live" or "permute" selects the normal sampler.
#' @param fixed_variables names or numeric indexes of columns which will not be changed
#'        while sampling.
#' @param seed Seed to set before sampling. If NULL, results will not be reproducible.
#' @param ... Mean and covariance matrix for normal sampling method.
#'        Forwarded only when the normal sampler is used.
#'
#' @importFrom stats sd
#'
#' @return list of class "live_explorer" consisting of
#' \item{data}{Dataset generated around the explained instance. The response
#'       column is removed from the inputs before sampling.}
#' \item{target}{Name of the response variable.}
#' \item{explained_instance}{Instance that is being explained.}
#' \item{sampling_method}{Name of used sampling method}
#' \item{fixed_variables}{Names of variables which were not sampled}
#' \item{sdeviations}{Standard deviations of numerical variables}
#'
#' @export
#'
#' @examples
#' \dontrun{
#' dataset_for_local_exploration <- sample_locally(data = wine,
#'                                                 explained_instance = wine[5, ],
#'                                                 explained_var = "quality",
#'                                                 size = 50)
#' }
#'
sample_locally <- function(data, explained_instance, explained_var, size,
                           method = "live", fixed_variables = NULL, seed = NULL, ...) {
  check_conditions(data, explained_instance, size)

  explained_var_col <- which(colnames(data) == explained_var)
  # Without this guard a misspelled name yields integer(0), and negative
  # indexing with integer(0) silently selects zero columns.
  if (length(explained_var_col) == 0L) {
    stop("Column '", explained_var, "' was not found in data")
  }

  # drop = FALSE keeps a data frame even when a single predictor remains.
  predictors <- data[, -explained_var_col, drop = FALSE]
  instance_predictors <- explained_instance[, -explained_var_col, drop = FALSE]

  if (method == "live") {
    similar <- generate_neighbourhood(predictors,
                                      instance_predictors,
                                      size,
                                      fixed_variables,
                                      seed)
  } else if (method == "permute") {
    similar <- permutation_neighbourhood(predictors,
                                         instance_predictors,
                                         size,
                                         fixed_variables,
                                         seed)
  } else {
    similar <- normal_neighbourhood(predictors,
                                    instance_predictors,
                                    size,
                                    fixed_variables,
                                    seed,
                                    ...)
  }

  # Standard deviation of every numeric predictor; vapply guarantees a
  # (named) numeric vector even when there are no numeric columns.
  is_numeric_col <- vapply(predictors, is.numeric, logical(1))
  sds <- vapply(as.list(predictors)[is_numeric_col], sd, numeric(1))

  explorer <- list(data = similar,
                   target = explained_var,
                   explained_instance = explained_instance,
                   sampling_method = method,
                   fixed_variables = fixed_variables,
                   sdeviations = sds)
  class(explorer) <- c("live_explorer", "list")
  explorer
}


#' Obtain a black box model and its predictions for a simulated dataset.
#'
#' Internal helper. When `black_box` is not a character vector it is used as
#' an already trained model and `predict_function(black_box, similar, ...)`
#' supplies the predictions. Otherwise `black_box` is passed to
#' `mlr::makeLearner` as the learner name, the learner is trained on a task
#' obtained from `create_task()`, and predictions for `similar` are taken from
#' the `response` column of the prediction's `data` element.
#'
#' @param data Training data; used only when `black_box` is a character vector.
#' @param black_box Learner name (character) or a model object.
#' @param explained_var Name of the response; passed on to `create_task()`.
#' @param similar Dataset for which predictions are computed.
#' @param predict_function Function called as
#'        `predict_function(black_box, similar, ...)`; used only when a model
#'        object was supplied.
#' @param hyperpars List passed as `par.vals` to `mlr::makeLearner`; used only
#'        when `black_box` is a character vector.
#' @param ... Extra arguments for `predict_function`; not used when the model
#'        is trained through mlr.
#'
#' @return List with elements `model` and `predictions`.
#'
#' @noRd
give_predictions <- function(data, black_box, explained_var, similar, predict_function,
                             hyperpars = list(), ...) {
  # A ready-made model: nothing to train, just predict.
  if (!is.character(black_box)) {
    return(list(model = black_box,
                predictions = predict_function(black_box, similar, ...)))
  }

  # A learner name: train it through mlr, then predict on the new data.
  task <- create_task(black_box, as.data.frame(data), explained_var)
  learner <- mlr::makeLearner(black_box, par.vals = hyperpars)
  fitted_model <- mlr::train(learner, task)
  predicted <- predict(fitted_model, newdata = as.data.frame(similar))

  list(model = mlr::getLearnerModel(fitted_model),
       predictions = predicted[["data"]][["response"]])
}


#' Add black box predictions to generated dataset
#'
#' @param to_explain List returned by sample_locally function.
#' @param black_box_model String with mlr signature of a learner or a model with predict interface.
#' @param data Original data frame used to generate new dataset.
#'        Need not be provided when a trained model is passed in
#'        black_box_model argument.
#' @param predict_fun Either a "predict" function that returns a vector of the
#'        same type as response or custom function that takes a model as a first argument,
#'        and data used to calculate predictions as a second argument
#'        and returns a vector of the same type as response.
#'        Will be used only if a model object was provided in the black_box_model argument.
#' @param hyperparams Optional list of (hyper)parameters to be passed to mlr::makeLearner.
#' @param ... Additional parameters to be passed to predict function.
#'
#' @return list of class "live_explorer" consisting of
#' \item{data}{Dataset generated by sample_locally function with response variable.}
#' \item{target}{Name of the response variable.}
#' \item{model}{Black box model which is being explained.}
#' \item{explained_instance}{Instance that is being explained.}
#' \item{sampling_method}{Name of used sampling method}
#' \item{fixed_variables}{Names of variables which were not sampled}
#' \item{sdeviations}{Standard deviations of numerical variables}
#'
#' @importFrom stats predict
#'
#' @export
#'
#' @examples
#' \dontrun{
#' # Train a model inside add_predictions call.
#' local_exploration1 <- add_predictions(dataset_for_local_exploration,
#'                                       black_box_model = "regr.svm",
#'                                       data = wine)
#' # Pass trained model to the function.
#' svm_model <- svm(quality ~., data = wine)
#' local_exploration2 <- add_predictions(dataset_for_local_exploration,
#'                                       black_box_model = svm_model)
#' }
#'
add_predictions <- function(to_explain, black_box_model, data = NULL, predict_fun = predict,
                            hyperparams = list(), ...) {
  # A learner name means the model still has to be trained, which needs data.
  # Scalar && is the correct operator inside if().
  if (is.null(data) && is.character(black_box_model)) {
    stop("Dataset for training black box model must be provided")
  }

  trained_black_box <- give_predictions(data = data,
                                        black_box = black_box_model,
                                        explained_var = to_explain$target,
                                        similar = to_explain$data,
                                        predict_function = predict_fun,
                                        hyperpars = hyperparams,
                                        ...)

  # Store the predictions in the simulated data under the response name.
  to_explain$data[[to_explain$target]] <- trained_black_box$predictions

  explorer <- list(data = to_explain$data,
                   target = to_explain$target,
                   model = trained_black_box$model,
                   explained_instance = to_explain$explained_instance,
                   sampling_method = to_explain$sampling_method,
                   fixed_variables = to_explain$fixed_variables,
                   sdeviations = to_explain$sdeviations)
  class(explorer) <- c("live_explorer", "list")
  explorer
}

# Try the live package in your browser
#
# Any scripts or data that you put into this service are public.
#
# live documentation built on Jan. 17, 2020, 9:06 a.m.