CNN_varIMP_permute: CNN variable importance using permutation

View source: R/CNN_varIMP_permute.R

CNN_varIMP_permute    R Documentation

CNN variable importance using permutation

Description

Computes permutation variable importance for a convolutional neural network (CNN). The importance of a variable is measured by the change in a model performance score (e.g., Mean Decrease Accuracy (MDA), or the change in RMSE) when that variable's values are randomly shuffled, repeated nsim times.
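
The shuffling step can be illustrated on its own. The sketch below mirrors the internal helper permute_columns() from the function definition further down: only the chosen column is reordered, so the variable keeps exactly the same values (and marginal distribution) while its pairing with the response, and with the other columns, is broken. The small matrix is a toy input for illustration.

```r
# Shuffle the rows of the selected column(s) only; all other columns are untouched.
# Mirrors the internal helper permute_columns() used by CNN_varIMP_permute().
permute_columns <- function(x, columns = NULL) {
  if (is.null(columns)) {
    stop("No columns specified for permutation.")
  }
  x[, columns] <- x[sample(nrow(x)), columns]
  x
}

set.seed(1)
# Toy 4 x 3 matrix with named columns
m <- matrix(1:12, nrow = 4, dimnames = list(NULL, c("a", "b", "c")))
m_perm <- permute_columns(m, columns = "b")
m_perm
```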

Usage

CNN_varIMP_permute(optmodel, feature_names = NULL, train_y = NULL, train_x = NULL, type = c("difference", "ratio"), nsim = 1, sample_size = NULL, sample_frac = NULL, verbose = FALSE, progress = "none", parallel = FALSE, paropts = NULL, ...)

Arguments

optmodel

The fitted (optimal) keras CNN model used to estimate variable importance.

feature_names

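The names of the variables (columns of train_x) to permute. They are also used as the row names of the returned CNN_SNPsIMP table.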

train_y

The response (dependent) variable used in the regression.

train_x

The independent (predictor) variables, as a matrix whose columns can be selected by feature_names.

type

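Type of comparison between the baseline and permuted scores, either "difference" or "ratio". In the function definition shown below, the reported decrease is always computed as a difference.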

nsim

Number of permutations (random shuffles) per variable.

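sample_size

Optional integer. If supplied, this many observations (rows) are drawn at random, without replacement, from train_x and train_y before each variable is permuted.

sample_frac

Optional fraction of observations to sample. Not used in the function definition shown below.

verbose

Logical. If TRUE (and parallel is FALSE), a message is printed as each variable is processed.

progress

Type of progress bar. Not used in the function definition shown below, which always calls plyr::llply with .progress = "none".

parallel

Logical. Passed to the .parallel argument of plyr::llply, so that variables are processed in parallel using a registered foreach backend.

paropts

List of additional options passed to the .paropts argument of plyr::llply.

...

Additional arguments; not used in the function definition shown below.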
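
The type argument selects how a permuted score is compared with the baseline score. The snippet below is a minimal sketch of the two options for an error metric such as RMSE, where a larger error after permutation means the model relies more on the variable. compare_scores() is a hypothetical helper, and the orientation (permuted compared against baseline) is an assumption made for illustration; the ratio form is closely related to the model reliance measure of Fisher et al. (2018), the ratio of the loss under permutation to the original loss.

```r
# Hypothetical helper: compare a permuted error score with the baseline error.
# ASSUMPTION: "smaller is better" metric (e.g. RMSE), so larger results = more important.
compare_scores <- function(permuted, baseline, type = c("difference", "ratio")) {
  type <- match.arg(type)
  # Pick the comparison operator, as CNN_varIMP_permute() does with `%compare%`
  `%compare%` <- if (type == "difference") `-` else `/`
  permuted %compare% baseline
}

compare_scores(permuted = 3, baseline = 2, type = "difference")  # 1
compare_scores(permuted = 3, baseline = 2, type = "ratio")       # 1.5
```

A difference is on the scale of the metric itself, while a ratio is unit-free, which can make importances easier to compare across response variables with different scales.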

Details

In this implementation, the supplied optimal model (optmodel) is scored on the original, unpermuted training data, and these metrics (RMSE, R-squared, and MAE from caret::postResample) serve as the baseline. Each variable is then permuted in turn, and the same metrics are recomputed by predicting with the optimal model on the training set in which that variable has been shuffled. Permuting a variable breaks its relationship with the target, so the resulting drop in model score indicates how much the model depends on that variable.
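
The procedure above can be sketched end to end. In this minimal sketch an lm() fit stands in for the keras CNN (an assumption made so the example runs without keras or TensorFlow), the metrics are computed by hand in the spirit of caret::postResample(), and the simulated data are made up for illustration: x1 drives the response, x2 has a small effect, and x3 is pure noise.

```r
# ASSUMPTION: lm() stands in for the keras CNN so the sketch runs anywhere.
set.seed(42)
n <- 200
train_x <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
train_y <- 3 * train_x$x1 + 0.5 * train_x$x2 + rnorm(n, sd = 0.5)  # x3 is noise
fit <- lm(y ~ ., data = cbind(train_x, y = train_y))

# RMSE, R-squared (squared correlation) and MAE, like caret::postResample()
score <- function(pred, obs) {
  c(RMSE = sqrt(mean((pred - obs)^2)),
    Rsquared = cor(pred, obs)^2,
    MAE = mean(abs(pred - obs)))
}

# 1. Baseline: score the fitted model on the unpermuted training data
baseline <- score(predict(fit, newdata = train_x), train_y)

# 2. Permute one variable at a time, nsim times, re-scoring with the same model
nsim <- 5
permuted <- t(sapply(names(train_x), function(v) {
  runs <- replicate(nsim, {
    x_perm <- train_x
    x_perm[[v]] <- sample(x_perm[[v]])
    score(predict(fit, newdata = x_perm), train_y)
  })
  rowMeans(runs)  # average each metric over the nsim shuffles
}))

# 3. Decrease = baseline minus permuted, one row per variable
decrease <- sweep(permuted, 2, baseline, FUN = function(p, b) b - p)
round(decrease, 3)
```

Because RMSE and MAE are errors, an important variable shows a negative decrease under this convention (the error rises when it is shuffled), whereas R-squared shows a positive one.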

Value

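A list with three elements: CNN_Decrease_acc, a matrix of the baseline metrics minus the permuted metrics (one row per variable); CNN_SNPsIMP, the metrics obtained after permuting each variable; and baseline, the metrics of the model on the unpermuted data.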
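
In the function definition shown below, CNN_Decrease_acc is assembled with cbind() from unnamed vectors, so it carries neither row nor column names; its columns follow the order of the caret::postResample() output (RMSE, Rsquared, MAE). The hypothetical helper rank_importance() below labels and ranks the variables. The toy list only imitates the returned structure, and its numbers are made up for illustration.

```r
# Hypothetical helper: label the unnamed decrease matrix and rank the variables.
rank_importance <- function(res, feature_names, by = "Rsquared") {
  dec <- as.data.frame(res$CNN_Decrease_acc)
  colnames(dec) <- colnames(res$baseline)  # RMSE, Rsquared, MAE
  dec$Variable <- feature_names
  # A larger drop in R-squared means the model relies more on the variable
  dec[order(dec[[by]], decreasing = TRUE), c("Variable", colnames(res$baseline))]
}

# Toy object imitating the returned list (made-up numbers)
toy <- list(
  CNN_Decrease_acc = cbind(c(-0.02, -0.40, 0.01),   # RMSE
                           c(0.01, 0.30, -0.01),    # Rsquared
                           c(-0.01, -0.30, 0.00)),  # MAE
  baseline = data.frame(RMSE = 0.5, Rsquared = 0.9, MAE = 0.4)
)
ranked <- rank_importance(toy, feature_names = c("SNP1", "SNP2", "SNP3"))
ranked
```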

References

Fisher, Aaron, Cynthia Rudin, and Francesca Dominici. “Model Class Reliance: Variable importance measures for any machine learning model class, from the ‘Rashomon’ perspective.” http://arxiv.org/abs/1801.01489 (2018).

Examples


## The function is currently defined as
function (optmodel, feature_names = NULL, train_y = NULL, train_x = NULL, 
    type = c("difference", "ratio"), nsim = 1, sample_size = NULL, 
    sample_frac = NULL, verbose = FALSE, progress = "none", parallel = FALSE, 
    paropts = NULL, ...) 
{
    # Baseline: add an extra dimension to the predictors for the CNN input,
    # then score the model on the unpermuted data
    x_cnn = kerasR::expand_dims(train_x, axis = 2)
    baseline <- as.data.frame(t(caret::postResample(pred = keras::predict_on_batch(optmodel, 
        x_cnn), obs = train_y)))
    (rm(x_cnn))
    type <- match.arg(type)
    # Comparison operator selected by `type`; it is not used below, where the
    # decrease is always computed by subtraction
    `%compare%` <- if (type == "difference") {
        `-`
    }
    else {
        `/`
    }
    # Shuffle the rows of the selected column(s) only
    permute_columns <- function(x, columns = NULL) {
        if (is.null(columns)) {
            stop("No columns specified for permutation.")
        }
        x[, columns] <- x[sample(nrow(x)), columns]
        x
    }
    # Defined but not called in this version
    sort_importance_scores <- function(x, decreasing) {
        x[order(x$Importance, decreasing = decreasing), ]
    }
    # For each variable (repeated nsim times): optionally subsample the rows,
    # permute the variable, and re-score the model
    CNN_varIMP <- replicate(nsim, (plyr::llply(feature_names, 
        .progress = "none", .parallel = parallel, .paropts = paropts, # `progress` is ignored
        .fun = function(x) {
            if (verbose && !parallel) {
                message("Computing variable importance for ", 
                  x, "...")
            }
            if (!is.null(sample_size)) {
                ids <- sample(length(train_y), size = sample_size, 
                  replace = FALSE)
                train_x <- train_x[ids, ]
                train_y <- train_y[ids]
            }
            train_x_permuted <- permute_columns(train_x, columns = x)
            x_tensor = kerasR::expand_dims(train_x_permuted, 
                axis = 2)
            permuted <- as.data.frame(t(caret::postResample(pred = keras::predict_on_batch(optmodel, 
                x_tensor), obs = train_y)))
        })))
    # One row of permuted metrics per variable; with nsim > 1 there are
    # length(feature_names) * nsim rows, which no longer match feature_names
    CNN_SNPsIMP = do.call(rbind, CNN_varIMP)
    rownames(CNN_SNPsIMP) = feature_names
    # Decrease = baseline minus permuted, for each of the three metrics
    decrease_acc = lapply(1:3, function(i) baseline[, i] - CNN_SNPsIMP[, 
        i])
    CNN_Decrease_acc = do.call(cbind, decrease_acc)
    return(list(CNN_Decrease_acc = CNN_Decrease_acc, CNN_SNPsIMP = CNN_SNPsIMP, 
        baseline = baseline))
}
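
In this definition, replicate(nsim, ...) followed by do.call(rbind, ...) produces length(feature_names) * nsim rows, so the rownames() assignment appears to fail whenever nsim > 1. One possible workaround, sketched here as an assumption rather than package behaviour, is to call the function repeatedly with nsim = 1 and average the CNN_Decrease_acc matrices. average_decrease() and run_once are hypothetical names; in practice run_once would be a zero-argument closure wrapping a call such as CNN_varIMP_permute(optmodel, feature_names, train_y, train_x, nsim = 1). Here a random stub stands in for that call.

```r
# Hypothetical wrapper: repeat a single-permutation run and average the results.
average_decrease <- function(run_once, nsim = 10) {
  mats <- lapply(seq_len(nsim), function(i) run_once()$CNN_Decrease_acc)
  Reduce(`+`, mats) / nsim  # element-wise mean over the nsim runs
}

# Stub standing in for CNN_varIMP_permute(..., nsim = 1): 3 variables x 3 metrics
set.seed(7)
stub <- function() {
  list(CNN_Decrease_acc = cbind(rnorm(3, -1), rnorm(3, 0.5), rnorm(3, -0.8)))
}
avg <- average_decrease(stub, nsim = 20)
dim(avg)  # 3 3
```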

xinghuq/DeepGenomeScan documentation built on Sept. 20, 2022, 8:46 a.m.