R/univarImpAUC.R

# @importFrom stats as.formula complete.cases
# @importFrom party ctree_control initVariableFrame ctree initVariableFrame party_intern

# Univariate permutation importance of one input variable of a party forest,
# using an AUC-type measure for binary responses.
#
# For each tree whose varIDs() include `vname`, the tree's predictions on its
# evaluation sample are compared with its predictions after perturbing
# `vname` (nperm times); trees that do not use `vname` contribute 0. The
# result is the mean of these nperm * ntree differences, ignoring NA entries.
#
# Arguments:
#   object        fitted forest; slots @responses, @data, @ensemble, @weights
#                 and @initweights are read.
#   vname         name of exactly one column of the forest's input data.
#   mincriterion  passed through to the tree prediction routine.
#   conditional   if TRUE, permute within the conditioning structure built by
#                 create_cond_list() / conditional_perm().
#   threshold     passed through to create_cond_list().
#   nperm         number of permutations per tree (a single number >= 1).
#   OOB           if TRUE, evaluate each tree on the observations whose case
#                 weight in that tree is 0; otherwise on all observations.
#   pre1.0_0      if TRUE, permute explicitly in R (as for `conditional`)
#                 instead of passing the variable index to the prediction
#                 routine.
#
# Returns a numeric vector of length one named `vname`.
#
# Supported responses (anything else is an error):
#   * binary nominal: measure from .univar_auc_error(); importance is
#     (value before permutation) - (value after permutation).
#   * ordered: misclassification rate; importance is
#     (error after permutation) - (error before permutation).
univarImpAUC <- function(object, vname, mincriterion = 0, conditional = FALSE,
                         threshold = 0.2, nperm = 1, OOB = TRUE,
                         pre1.0_0 = conditional) {
  response <- object@responses
  input <- object@data@get("input")
  xnames <- colnames(input)
  # Validate the scalar arguments before any per-tree work.
  if (length(vname) != 1L || !(vname %in% xnames)) {
    stop("'vname' must be the name of exactly one input variable")
  }
  if (!is.numeric(nperm) || length(nperm) != 1L || is.na(nperm) || nperm < 1) {
    stop("'nperm' must be a single number >= 1")
  }
  idv <- match(vname, xnames)
  inp <- initVariableFrame(input, trafo = NULL)
  y <- response@variables[[1]]
  if (length(response@variables) != 1) {
    stop("cannot compute variable importance measure for multivariate response")
  }
  explicit_perm <- conditional || pre1.0_0
  if (explicit_perm && !all(complete.cases(inp@variables))) {
    stop("cannot compute variable importance measure with missing values")
  }

  CLASS <- all(response@is_nominal)
  ORDERED <- all(response@is_ordinal)
  if (CLASS) {
    if (nlevels(y) > 2) {
      stop("varImpAUC() is only usable for binary classification. For multiclass classification please use the standard varImp() function.")
    }
    error <- .univar_auc_error(y)
    # Larger values mean a better fit: a drop after permutation is importance.
    direction <- -1
  } else if (ORDERED) {
    error <- function(x, oob) {
      mean((vapply(x, which.max, integer(1)) != y)[oob])
    }
    # A misclassification rate: an increase after permutation is importance,
    # so the raw difference is not negated as it is for the AUC measure.
    direction <- 1
  } else {
    stop("only calculable for classification")
  }

  w <- object@initweights
  if (max(abs(w - 1)) > sqrt(.Machine$double.eps)) {
    warning(sQuote("varimp"),
            " with non-unity weights might give misleading results")
  }

  ntree <- length(object@ensemble)
  n_obs <- length(y)
  # create_cond_list() only receives loop-invariant arguments, so it is
  # evaluated once rather than per tree and permutation (assumes it is
  # deterministic).
  ccl <- NULL
  if (explicit_perm) {
    ccl <- create_cond_list(conditional, threshold, vname, input)
  }

  # One entry per (tree, permutation); entries of trees that do not use
  # `vname` stay 0 and still count in the average.
  perror <- numeric(nperm * ntree)
  for (b in seq_len(ntree)) {
    tree <- object@ensemble[[b]]
    if (!(idv %in% varIDs(tree))) {
      next
    }
    oob <- .univar_oob_mask(object@weights[[b]], n_obs, OOB)
    oob_idx <- which(oob)
    p <- party_intern(tree, inp, mincriterion, -1L, fun = "R_predict")
    eoob <- error(p, oob)
    for (per in seq_len(nperm)) {
      if (explicit_perm) {
        if (is.null(ccl)) {
          # Index through sample.int(): sample(k) on a single index k would
          # permute 1:k instead of returning k.
          perm <- oob_idx[sample.int(length(oob_idx))]
        } else {
          perm <- conditional_perm(ccl, xnames, input, tree, oob)
        }
        tmp <- inp
        tmp@variables[[idv]][oob_idx] <- tmp@variables[[idv]][perm]
        p <- party_intern(tree, tmp, mincriterion, -1L, fun = "R_predict")
      } else {
        p <- party_intern(tree, inp, mincriterion, as.integer(idv),
                          fun = "R_predict")
      }
      perror[per + (b - 1L) * nperm] <- direction * (error(p, oob) - eoob)
    }
  }

  # NA entries (measure undefined for that tree's sample) are dropped.
  out <- mean(perror, na.rm = TRUE)
  names(out) <- vname
  out
}

# Build the evaluation function for a binary nominal response `y`.
#
# The returned function takes `x`, a list with one prediction vector per
# observation (only the first element of each is used as the score), and
# `oob`, a logical mask selecting the observations to evaluate. It returns NA
# when the selected observations all belong to one class; otherwise
#   1 - #{(i, k): y_i == levels(y)[1], y_k != levels(y)[1], s_i > s_k} / (n1 * n0)
# which is intended as the AUC of the score.
# NOTE(review): if the score is the predicted probability of levels(y)[1],
# this equals 1 - AUC (exactly so without ties), which together with the
# negation in univarImpAUC() would give informative variables negative
# importance -- confirm the sign convention on a known example.
.univar_auc_error <- function(y) {
  lev1 <- levels(y)[1L]
  function(x, oob) {
    score <- vapply(x, function(pr) pr[1L], numeric(1))[oob]
    yoob <- y[oob]
    in1 <- which(yoob == lev1)
    n1 <- length(in1)
    n0 <- length(yoob) - n1
    # The pairwise measure is undefined when only one class is present.
    if (n1 == 0L || n0 == 0L) {
      return(NA)
    }
    1 - sum(outer(score[in1], score[-in1], ">")) / (n1 * n0)
  }
}

# Logical mask of the observations used to evaluate one tree: those with case
# weight 0 in that tree when `OOB` is TRUE, otherwise all `n_obs` observations
# (the mask must have one entry per observation, not per variable).
.univar_oob_mask <- function(weights, n_obs, OOB) {
  if (OOB) weights == 0 else rep(TRUE, n_obs)
}
nicolas-robette/moreparty documentation built on April 10, 2024, 2:24 p.m.