R/varimp.R

Defines functions varimpAUC conditional_perm cutpoints_list create_cond_list

Documented in varimpAUC

# for the current variable of interest, xname,
# create the list of variables to condition on:

create_cond_list <- function(cond, threshold, xname, input) {

   ## Select the predictors to condition on when permuting `xname`.
   ##
   ## Arguments:
   ##   cond      logical; if FALSE no conditioning is wanted and NULL is returned
   ##   threshold single number strictly between 0 and 1; a predictor is kept
   ##             when its criterion value exceeds this threshold
   ##   xname     name of the variable of interest (a column of `input`)
   ##   input     data frame holding all predictors
   ##
   ## Value:
   ##   NULL if `cond` is FALSE, otherwise a character vector (possibly of
   ##   length zero) with the names of the conditioning variables.

   stopifnot(is.logical(cond))
   if (!cond) return(NULL)

   ## validate with scalar operators and fail with a readable message
   if (!is.numeric(threshold) || length(threshold) != 1L ||
       is.na(threshold) || threshold <= 0 || threshold >= 1)
       stop(sQuote("threshold"), " must be a single number strictly between 0 and 1")

   xnames <- names(input)
   xnames <- xnames[xnames != xname]

   ## with no other predictor there is nothing to condition on; building the
   ## formula "xname ~ " below would fail to parse
   if (length(xnames) < 1L) return(character(0))

   ## fit a stump with `xname` as response and all other predictors as
   ## candidates, only to read off the per-variable criterion values
   ctrl <- ctree_control(teststat = "quad", testtype = "Univariate", stump = TRUE)
   ct <- ctree(as.formula(paste(xname, "~", paste(xnames, collapse = "+"), collapse = "")),
               data = input, controls = ctrl)

   ## NOTE(review): criterion[[2]] is presumably 1 - p-value per candidate
   ## variable, in the order of `xnames` -- confirm against the tree structure
   crit <- ct@tree$criterion[[2]]
   crit[is.na(crit)] <- 0
   xnames[crit > threshold]
}



## mincriterion = 0 so that complete tree is evaluated; 
## regulate size of considered tree here via, e.g., mincriterion = 0.95
## or when building the forest in the first place via cforest_control(mincriterion = 0.95)

varimp <- function (object, mincriterion = 0, conditional = FALSE, 
                    threshold = 0.2, nperm = 1, OOB = TRUE, pre1.0_0 = conditional)
{

    ## Permutation variable importance for a forest.
    ##
    ## Arguments:
    ##   object       fitted forest (slots @responses, @data, @ensemble,
    ##                @weights and @initweights are read)
    ##   mincriterion passed on to the prediction routine
    ##   conditional  logical; permute each variable conditionally on the
    ##                variables chosen by create_cond_list()
    ##   threshold    passed on to create_cond_list()
    ##   nperm        number of permutations per tree and variable
    ##   OOB          logical; evaluate the error on out-of-bag observations
    ##                only (TRUE) or on the whole learning sample (FALSE)
    ##   pre1.0_0     logical; if TRUE the permutation is carried out here in R
    ##                even when `conditional` is FALSE
    ##
    ## Value:
    ##   named numeric vector (one entry per predictor) with the increase in
    ##   prediction error after permutation, averaged over trees and
    ##   permutations. Survival responses are delegated to varimpsurv().

    response <- object@responses
    if (length(response@variables) == 1 && 
        inherits(response@variables[[1]], "Surv"))
        return(varimpsurv(object, mincriterion, conditional, threshold, nperm, OOB, pre1.0_0))
    ## check the response dimension before the response is used
    if (length(response@variables) != 1)
        stop("cannot compute variable importance measure for multivariate response")

    input <- object@data@get("input")
    xnames <- colnames(input)
    inp <- initVariableFrame(input, trafo = NULL)
    y <- response@variables[[1]]

    ## TRUE when the permutation is done explicitly in R (below) instead of
    ## being delegated to the prediction routine
    permute_here <- conditional || pre1.0_0

    if (permute_here) {
        if(!all(complete.cases(inp@variables)))
            stop("cannot compute variable importance measure with missing values")
    }

    ## error measure: misclassification rate for nominal and ordinal
    ## responses, mean squared error otherwise
    CLASS <- all(response@is_nominal)
    ORDERED <- all(response@is_ordinal)
    if (CLASS) {
        error <- function(x, oob) mean((levels(y)[sapply(x, which.max)] != 
            y)[oob])
    }
    else {
        if (ORDERED) {
            error <- function(x, oob) mean((sapply(x, which.max) != 
                y)[oob])
        }
        else {
            error <- function(x, oob) mean((unlist(x) - y)[oob]^2)
        }
    }

    w <- object@initweights
    if (max(abs(w - 1)) > sqrt(.Machine$double.eps))
        warning(sQuote("varimp"), " with non-unity weights might give misleading results")

    ## one row per (tree, permutation), one column per predictor;
    ## this matrix is initialized with values 0 so that a tree that does not 
    ## contain the current variable adds importance 0 to its average importance
    perror <- matrix(0, nrow = nperm*length(object@ensemble), ncol = length(xnames))
    colnames(perror) <- xnames

    ## The conditioning set of variable j depends only on `conditional`,
    ## `threshold`, `xnames[j]` and `input`, none of which change across trees
    ## or permutations, so it is computed at most once per variable
    ## (assuming the stump fit in create_cond_list() is deterministic).
    cond_vars <- vector("list", length(xnames))
    cond_known <- logical(length(xnames))

    for (b in seq_along(object@ensemble)) {
        tree <- object@ensemble[[b]]

        ## if OOB == TRUE use only oob observations, otherwise use all
        ## observations in learning sample
        oob <- if (OOB) object@weights[[b]] == 0 else rep(TRUE, length(y))
        oob_idx <- which(oob)

        ## baseline error of this tree without any permutation
        p <- .R_predict(tree, inp, mincriterion, -1L)
        eoob <- error(p, oob)

        ## only variables actually used for splitting in this tree
        for (j in unique(varIDs(tree))) {

            if (permute_here && !cond_known[j]) {
                ## assign via list() so that a NULL result keeps its slot
                cond_vars[j] <- list(create_cond_list(conditional, threshold,
                                                      xnames[j], input))
                cond_known[j] <- TRUE
            }

            for (per in seq_len(nperm)) {

                if (permute_here) {
                    tmp <- inp
                    ccl <- cond_vars[[j]]
                    if (length(ccl) < 1) {
                        ## index-based shuffle: sample(which(oob)) would
                        ## permute 1:n when a single observation n is oob
                        perm <- oob_idx[sample.int(length(oob_idx))]
                    } else {
                        perm <- conditional_perm(ccl, xnames, input, tree, oob)
                    }
                    tmp@variables[[j]][oob_idx] <- tmp@variables[[j]][perm]
                    p <- .R_predict(tree, tmp, mincriterion, -1L)
                } else {
                    ## presumably the prediction routine permutes variable j
                    ## internally when given its index -- TODO confirm
                    p <- .R_predict(tree, inp, mincriterion, as.integer(j))
                }

                ## row index runs through all (tree, permutation) pairs
                perror[(per+(b-1)*nperm), j] <- (error(p, oob) - eoob)

            } ## end of for (per in seq_len(nperm))
        } ## end of for (j in unique(varIDs(tree)))
    } ## end of for (b in seq_along(object@ensemble))

    ## return only averages over permutations and trees
    colMeans(perror)
}


varimpsurv <- function (object, mincriterion = 0, conditional = FALSE, 
                        threshold = 0.2, nperm = 1, OOB = TRUE, pre1.0_0 = conditional)
{

    ## Permutation variable importance for a forest with a censored
    ## ("Surv") response; the error measure is ipred::sbrier().
    ##
    ## Arguments: as for varimp().
    ##
    ## Value:
    ##   named numeric vector (one entry per predictor) with the increase in
    ##   error after permutation, averaged over trees and permutations.
    ##   A notice is printed to the console on every call.

    cat("\n")
    cat("Variable importance for survival forests; this feature is _experimental_\n\n")
    input <- object@data@get("input")
    xnames <- colnames(input)
    inp <- initVariableFrame(input, trafo = NULL)
    y <- object@responses@variables[[1]]
    weights <- object@initweights
    stopifnot(inherits(y, "Surv"))

    ## TRUE when the permutation is done explicitly in R (below) instead of
    ## being delegated to the node-ID routine
    permute_here <- conditional || pre1.0_0

    if (permute_here) {
        if(!all(complete.cases(inp@variables)))
            stop("cannot compute variable importance measure with missing values")
    }
    stopifnot(requireNamespace("ipred", quietly = TRUE))
    error <- function(x, oob) ipred::sbrier(y[oob,,drop = FALSE], x[oob])

    ## Prediction for every row of `newinp`: a survival fit computed from
    ## the learning observations (node membership taken from the original
    ## `inp`) that fall into the same node as the row.
    pred <- function(tree, newinp, j = -1L) {

        where <- .R_get_nodeID(tree, inp, mincriterion)
        wh <- .Call(R_get_nodeID, tree, newinp, mincriterion, as.integer(j))
        swh <- sort(unique(wh))
        RET <- vector(mode = "list", length = length(wh))
        for (i in seq_along(swh)) {
            ## weight 0 for learning observations outside node swh[i]
            w <- weights * (where == swh[i])
            RET[wh == swh[i]] <- list(mysurvfit(y, weights = w))
        }
        return(RET)
    }

    if (max(abs(weights - 1)) > sqrt(.Machine$double.eps))
        warning(sQuote("varimp"), " with non-unity weights might give misleading results")

    ## one row per (tree, permutation), one column per predictor;
    ## this matrix is initialized with values 0 so that a tree that does not 
    ## contain the current variable adds importance 0 to its average importance
    perror <- matrix(0, nrow = nperm*length(object@ensemble), ncol = length(xnames))
    colnames(perror) <- xnames

    ## The conditioning set of variable j depends only on `conditional`,
    ## `threshold`, `xnames[j]` and `input`, none of which change across trees
    ## or permutations, so it is computed at most once per variable
    ## (assuming the stump fit in create_cond_list() is deterministic).
    cond_vars <- vector("list", length(xnames))
    cond_known <- logical(length(xnames))

    for (b in seq_along(object@ensemble)) {
        tree <- object@ensemble[[b]]

        ## if OOB == TRUE use only oob observations, otherwise use all
        ## observations in learning sample
        oob <- if (OOB) object@weights[[b]] == 0 else rep(TRUE, length(y))
        oob_idx <- which(oob)

        ## baseline error of this tree without any permutation
        p <- pred(tree, inp)
        eoob <- error(p, oob)

        ## only variables actually used for splitting in this tree
        for (j in unique(varIDs(tree))) {

            if (permute_here && !cond_known[j]) {
                ## assign via list() so that a NULL result keeps its slot
                cond_vars[j] <- list(create_cond_list(conditional, threshold,
                                                      xnames[j], input))
                cond_known[j] <- TRUE
            }

            for (per in seq_len(nperm)) {

                if (permute_here) {
                    tmp <- inp
                    ccl <- cond_vars[[j]]
                    if (length(ccl) < 1) {
                        ## index-based shuffle: sample(which(oob)) would
                        ## permute 1:n when a single observation n is oob
                        perm <- oob_idx[sample.int(length(oob_idx))]
                    } else {
                        perm <- conditional_perm(ccl, xnames, input, tree, oob)
                    }
                    tmp@variables[[j]][oob_idx] <- tmp@variables[[j]][perm]
                    p <- pred(tree, tmp, -1L)
                } else {
                    ## presumably the node-ID routine permutes variable j
                    ## internally when given its index -- TODO confirm
                    p <- pred(tree, inp, as.integer(j))
                }

                ## row index runs through all (tree, permutation) pairs
                perror[(per+(b-1)*nperm), j] <- (error(p, oob) - eoob)

            } ## end of for (per in seq_len(nperm))
        } ## end of for (j in unique(varIDs(tree)))
    } ## end of for (b in seq_along(object@ensemble))

    ## return only averages over permutations and trees
    colMeans(perror)
}




# cutpoints_list() returns:
# - vector of cutpoints (length=number of cutpoints) 
#   if variable is continuous
# - vector of indicators (length=number of categories x number of cutpoints)
#   if variable is categorical (nominal or ordered)
cutpoints_list <- function(tree, variableID) {

    # Depth-first walk over the tree: the split value of every inner node
    # that splits on `variableID` is collected, node first, then its left
    # subtree (element 8), then its right subtree (element 9).
    # Returns NULL when no inner node splits on `variableID`.
    collect <- function(node) {
        # element 4 flags a terminal node: nothing to collect below it
        if (node[[4]]) return(NULL)

        split <- node[[5]]
        # element 1 of the split is the variable ID, element 3 the split value
        here <- if (split[[1]] == variableID) split[[3]] else NULL

        c(here, collect(node[[8]]), collect(node[[9]]))
    }

    collect(tree)
}


conditional_perm <- function(cond, xnames, input, tree, oob){

    ## Permutation of the out-of-bag observations within the cells of a
    ## partition defined by the conditioning variables.
    ##
    ## Arguments:
    ##   cond    character vector with the names of the conditioning variables
    ##   xnames  names of all predictors (positions correspond to columns
    ##           of `input`)
    ##   input   data frame of predictors
    ##   tree    the tree whose split points define the partition
    ##   oob     logical vector, TRUE for the observations to be permuted
    ##
    ## Value:
    ##   integer vector with one entry per TRUE element of `oob`: the row
    ##   whose value replaces the corresponding oob observation.

    ## initial partitioning => all observations in one partition
    parts <- rep(1, length(oob))

    ## develop partitioning by going over all the conditioning variables
    for (condVar in cond) {

        ## varID is variable index or column number of input (predictor matrix) 
        ## not variable name!
        varID <- which(xnames == condVar)

        ## if conditioning variable is not used for splitting in current tree
        ## proceed with next conditioning variable
        cl <- cutpoints_list(tree, varID)
        if (is.null(cl)) next

        ## process cutpoints for different types of variables
        x <- input[, varID]
        xclass <- class(x)[1]
        if (xclass == "integer") xclass <- "numeric"

        ## NOTE(review): switch() has no default branch, so `block` is NULL
        ## for any class other than numeric/integer, ordered or factor
        block <- switch(xclass, "numeric" = cut(x, breaks = c(-Inf, sort(unique(cl)), Inf)),
                        "ordered" = cut(as.numeric(x), breaks =  c(-Inf, sort(unique(cl)), Inf)),
                        "factor" = {
                            ## one row per level, one column per split on
                            ## this variable
                            CL <- matrix(as.logical(cl), nrow = nlevels(x))                            
                            rs <- rowSums(CL)
                            ## only levels sharing a row sum can have
                            ## identical rows
                            dlev <- seq_len(nrow(CL))[rs %in% rs[duplicated(rs)]]
                            ## pairs of levels treated identically by every
                            ## split: these are never separated by the tree
                            fuse <- c()
                            for (ii in dlev) {
                                for (j in dlev[dlev > ii]) {
                                    if (all(CL[ii,] == CL[j,])) fuse <- rbind(fuse, c(ii, j))
                                }
                            }
                            xlev <- seq_len(nlevels(x))
                            newl <- nlevels(x) + 1
                            block <- as.integer(x)
                            ## merge each group of fused levels into one
                            ## fresh block label
                            for (l in xlev) {
                                if (NROW(fuse) == 0) break
                                if (any(fuse[, 1] == l)) {
                                    f <- c(l, fuse[fuse[, 1] == l, 2])
                                    fuse <- fuse[!fuse[,1] %in% f, , drop = FALSE]
                                    block[block %in% f] <- newl
                                    newl <- newl + 1
                                 }
                            }
                            as.factor(block)
                         })
        ## refine the current partitioning by the blocks this variable's
        ## split points define
        parts <- interaction(parts, as.numeric(block), drop = TRUE, sep = "")
    }

    ## if none of the conditioning variables are used in the tree:
    ## unconditional permutation of the oob observations
    if (!length(levels(parts)) > 1) {
        ## index-based shuffle: sample(which(oob)) would permute 1:n when a
        ## single observation n is oob
        oob_idx <- which(oob)
        return(oob_idx[sample.int(length(oob_idx))])
    }

    ## one conditional permutation: shuffle oob observations within each
    ## cell of the partition, leave everything else in place
    perm <- seq_len(nrow(input))
    for (part in levels(parts)) {
        index <- which(parts == part & oob)
        if (length(index) > 1)
            perm[index] <- .resample(index)
    }
    perm[oob]
}

# Forwarding wrapper: all arguments are handed unchanged to
# varImp::varImpAUC() and its result is returned.
varimpAUC <- function(...) {
    varImp::varImpAUC(...)
}

Try the party package in your browser

Any scripts or data that you put into this service are public.

party documentation built on March 31, 2023, 11:56 p.m.