R/RePrInDT.R

Defines functions plot.RePrInDT print.RePrInDT RePrInDT

Documented in RePrInDT

#' Repeated \code{\link{PrInDT}} for specified percentage combinations
#'
#' @description
#' By the function \code{\link{RePrInDT}}, the function \code{\link{PrInDT}} is called repeatedly according to all combinations of the percentages specified in the vectors 'plarge' and 'psmall'.\cr
#' The relationship between the two-class factor variable 'classname' and all other factor and numerical variables
#' in the data frame 'datain' is optimally modeled by means of 'N' repetitions of undersampling.\cr 
#' The optimization criterion is the balanced accuracy on the validation sample 'valdat' (default = full input sample 'datain').\cr
#' The trees generated from undersampling can be restricted by rejecting 
#' unacceptable trees which include split results specified in the character strings of the vector 'ctestv'.\cr
#' The probability threshold 'thres' for the prediction of the smaller class may be specified (default = 0.5).\cr
#' Undersampling may be stratified in two ways by the feature 'strat'.\cr
#' The parameters 'conf.level', 'minsplit', and 'minbucket' can be used to control the size of the trees.\cr
#'
#' \strong{Reference}\cr Weihs, C., Buschfeld, S. 2021c. Repeated undersampling in PrInDT (RePrInDT): Variation in undersampling and prediction, 
#' and ranking of predictors in ensembles. arXiv:2108.05129
#'
#' @usage RePrInDT(datain, classname, ctestv=NA, N, plarge, psmall, conf.level=0.95,
#'        thres=0.5, stratvers=0, strat=NA, seedl=TRUE,minsplit=NA,minbucket=NA,
#'        valdat=datain)
#'
#' @param datain Input data frame with class factor variable 'classname' and the\cr
#'    influential variables, which need to be factors or numericals (transform logicals and character variables to factors) 
#' @param classname Name of class variable (character)
#' @param ctestv Vector of character strings of forbidden split results;\cr
#'     (see function \code{\link{PrInDT}} for details.)\cr
#'     If no restrictions exist, the default = NA is used.
#' @param N Number of repetitions (integer > 0)
#' @param plarge Vector of undersampling percentages of larger class (numerical, > 0 and <= 1)
#' @param psmall Vector of undersampling percentages of smaller class (numerical, > 0 and <= 1)
#' @param conf.level (1 - significance level) in function \code{ctree} (numerical, > 0 and <= 1);\cr
#'     default = 0.95
#' @param thres Probability threshold for prediction of smaller class (numerical, >= 0 and < 1); default = 0.5
#' @param stratvers Version of stratification;\cr
#'     = 0: none (default),\cr
#'     = 1: stratification according to the percentages of the values of the factor variable 'strat',\cr
#'     > 1: stratification with minimum number 'stratvers' of observations per value of 'strat'
#' @param strat Name of one (!) stratification variable for undersampling (character);\cr
#'     default = NA (no stratification)
#' @param seedl Should the seed for random numbers be set (TRUE / FALSE)?\cr
#'     default = TRUE
#' @param minsplit Minimum number of elements in a node to be split;\cr
#'     default = 20
#' @param minbucket Minimum number of elements in a node;\cr
#'     default = 7
#' @param valdat Validation data; default = datain
#'
#' @return
#' \describe{
#' \item{treesb}{best trees for the different percentage combinations; refer to an individual tree as \code{treesb[[k]]}, k = 1, ..., length(plarge)*length(psmall)}
#' \item{acc1st}{accuracies of best trees on full sample}
#' \item{acc3en}{accuracies of ensemble of 3 best trees on full sample}
#' \item{simp_m}{mean of permutation losses for the predictors}
#' }
#'
#' @details
#' Standard output can be produced by means of \code{print(name)} or just \code{ name } as well as \code{plot(name)} where 'name' is the output object 
#' (a list of class 'RePrInDT') of the function.\cr
#' The plot function will produce a series of more than one plot. If you use R, you might want to specify \code{windows(record=TRUE)} before 
#' \code{plot(name)} to save the whole series of plots. In R-Studio this functionality is provided automatically.
#'
#' @export RePrInDT
#' @importFrom graphics barplot

#' @examples
#' datastrat <- PrInDT::data_zero
#' data <- na.omit(datastrat) # cleaned full data: no NAs
#' # interpretation restrictions (split exclusions)
#' ctestv <- rbind('ETH == {C2a, C1a}', 'MLU == {1, 3}')
#' N <- 51  # no. of repetitions
#' conf.level <- 0.99 # 1 - significance level (mincriterion) in ctree
#' psmall <- c(0.95,1)     # percentages of the small class
#' plarge <- c(0.09,0.1)  # percentages of the large class
#' outRe <- RePrInDT(data,"real",ctestv,N,plarge,psmall,conf.level) 
#' outRe
#' plot(outRe)
#'
RePrInDT <- function(datain, classname, ctestv = NA, N, plarge, psmall,
                     conf.level = 0.95, thres = 0.5, stratvers = 0, strat = NA,
                     seedl = TRUE, minsplit = NA, minbucket = NA,
                     valdat = datain) {
  ## Calls PrInDT() once for every (plarge[i], psmall[j]) combination and
  ## collects, per combination, the accuracies of the best tree (acc1st) and
  ## of the ensemble of the 3 best trees (acc3en). Predictors are ranked by
  ## their permutation loss: the drop in validation balanced accuracy of the
  ## 3 best trees when the predictor's values in the validation data are
  ## randomly permuted, summed over the 3 trees and averaged over all
  ## combinations (simp_m).
  ## Returns a list of class "RePrInDT" (treesb, acc1st, acc3en, simp_m), or
  ## NA as soon as one PrInDT() call returns NA.

  ## input check
  if (typeof(datain) != "list" || typeof(classname) != "character" ||
      !(typeof(ctestv) %in% c("logical", "character")) || N <= 0 ||
      length(plarge) == 0 || length(psmall) == 0 ||
      !all(0 < plarge & plarge <= 1) || !all(0 < psmall & psmall <= 1) ||
      !(0 < conf.level && conf.level <= 1) || !(0 <= thres && thres < 1) ||
      !(0 <= stratvers) ||
      !(typeof(strat) %in% c("logical", "character")) ||
      typeof(seedl) != "logical" ||
      !(typeof(minsplit) %in% c("logical", "integer", "double")) ||
      !(typeof(minbucket) %in% c("logical", "integer", "double")) ||
      typeof(valdat) != "list") {
    stop("irregular input")
  }
  if (!all(names(datain) %in% names(valdat))) {
    stop("validation data variables unequal input data variables")
  }
  ## tree-size controls: defaults if both are missing, otherwise derive the
  ## missing one from the given one (minsplit = 3 * minbucket)
  if (is.na(minsplit) && is.na(minbucket)) {
    minsplit <- 20
    minbucket <- 7
  } else if (is.na(minbucket)) {
    minbucket <- minsplit / 3
  } else if (is.na(minsplit)) {
    minsplit <- minbucket * 3
  }
  if (isTRUE(seedl)) {
    set.seed(7654321)  # set seed of random numbers
  }
  ## internally the class variable is always called "class"
  data <- datain
  names(data)[names(data) == classname] <- "class"
  names(valdat)[names(valdat) == classname] <- "class"
  if (!identical(levels(data$class), levels(valdat$class))) {
    stop("levels of input class variable unequal levels of validation class variable")
  }
  ## predictors = all non-class variables of the input data, taken in the
  ## column order of the validation data. Selecting by name means the class
  ## variable need not be the first column, and validation-only columns
  ## (which the trees cannot use) are ignored.
  predictors <- intersect(names(valdat), setdiff(names(data), "class"))
  ## initializations
  acc1st <- array(0, dim = c(length(plarge), length(psmall), 8))
  acc3en <- array(0, dim = c(length(plarge), length(psmall), 6))
  simp <- array(0, dim = c(length(plarge), length(psmall), length(predictors)),
                dimnames = list(NULL, NULL, predictors))
  ## preallocated so that treesb[[k]] is always the k-th tree, also when
  ## there is only one percentage combination
  treesb <- vector("list", length(plarge) * length(psmall))
  ## one random permutation per predictor, shared by all combinations
  data_per <- valdat
  for (p in predictors) {
    data_per[[p]] <- sample(valdat[[p]])
  }
  ## loss in balanced accuracy of 'tree' on 'newdata' relative to its
  ## reference balanced accuracy 'ba'; 'n_class' = validation class counts.
  ## NOTE(review): predict() uses the tree's default rule, not 'thres' —
  ## confirm this matches how PrInDT computes 'ba'.
  ba_loss <- function(tree, ba, newdata, n_class) {
    conf <- table(predict(tree, newdata = newdata), newdata$class)
    ba - (conf[1, 1] / n_class[1] + conf[2, 2] / n_class[2]) / 2
  }
  k <- 0
  for (i in seq_along(plarge)) {
    for (j in seq_along(psmall)) {
      k <- k + 1
      message("\n","sampling percentage for larger class: ",plarge[i])
      message("sampling percentage for smaller class:",psmall[j])
      ## call of PrInDT; the caller's 'seedl' is passed on
      out <- PrInDT(data, classname, ctestv, N, plarge[i], psmall[j],
                    conf.level, thres, stratvers, strat, seedl = seedl,
                    minsplit = minsplit, minbucket = minbucket,
                    valdat = valdat)
      if (any(is.na(out))) {
        return(NA)
      }
      acc1st[i, j, ] <- c(plarge[i], psmall[j], out$ba1st)
      acc3en[i, j, ] <- c(plarge[i], psmall[j], out$baen[2, 2:5])
      treesb[[k]] <- out$tree1st
      if (k == 1) {
        # counts of class level 1 and level 2 in the validation data
        n_class <- table(out$valdat$class)
      }
      ## permutation losses of the 3 best trees, one predictor at a time
      for (p in predictors) {
        data_imp <- out$valdat
        data_imp[[p]] <- data_per[[p]]
        simp[i, j, p] <- simp[i, j, p] +
          ba_loss(out$tree1st, out$ba1st[3], data_imp, n_class) +
          ba_loss(out$tree2nd, out$ba2nd[3], data_imp, n_class) +
          ba_loss(out$tree3rd, out$ba3rd[3], data_imp, n_class)
      }
    }
  }

  ## preparation of print
  lev <- levels(out$valdat$class)
  dimnames(acc1st) <- list(seq_along(plarge), seq_along(psmall),
                           c("plarge", "psmall",
                             paste0("validation ", lev[1]),
                             paste0("validation ", lev[2]),
                             "validation balanced",
                             paste0("test ", lev[1]),
                             paste0("test ", lev[2]),
                             "test balanced"))
  dimnames(acc3en) <- list(seq_along(plarge), seq_along(psmall),
                           c("plarge", "psmall",
                             paste0("validation ", lev[1]),
                             paste0("validation ", lev[2]),
                             "validation balanced", "mean test balanced"))
  ## mean permutation loss per predictor over all percentage combinations
  simp_m <- vapply(predictors, function(p) mean(simp[, , p]), numeric(1))

  result <- list(treesb = treesb, acc1st = acc1st, acc3en = acc3en,
                 simp_m = simp_m)
  class(result) <- "RePrInDT"
  result
}
#' @export
print.RePrInDT <- function(x, ...){
  ## Print method for "RePrInDT" objects. It prints, in order:
  ##   - the accuracies of the best tree per percentage combination,
  ##   - the overall best tree(s),
  ##   - the accuracies of the 3-tree ensembles,
  ##   - the mean permutation losses of the predictors.
  ## Returns 'x' invisibly, as print methods conventionally do.
  ####
  ## print accuracies of best trees (one matrix per plarge value)
  ####
  cat("\n","Accuracies of best trees for different combinations of resampling percentages: on full and test samples","\n")
  apply(x$acc1st, 1, print)
  ## print overall best trees: all combinations whose validation balanced
  ## accuracy (entry 5 of acc1st) equals the maximum
  cat("\n\n","Best trees of percentage combinations","\n")
  n_small <- dim(x$acc1st)[2]
  ba_val <- x$acc1st[, , 5]
  ## combinations are numbered row-wise (psmall index varies fastest),
  ## matching the order of x$treesb
  best <- which(as.vector(t(ba_val == max(ba_val))))
  for (k in best) {
    i <- (k - 1) %/% n_small + 1
    j <- (k - 1) %% n_small + 1
    cat("\n","tree for  ( plarge  psmall  bal.acc. ) = (",x$acc1st[i,j,c(1,2,5)],")")
    print(x$treesb[[k]])
  }
  ## print accuracies of ensembles of 3 best trees
  cat("\n","Accuracies of ensembles of 3 best trees for different combinations of resampling percentages","\n")
  apply(x$acc3en, 1, print)
  ####
  ## ranking of predictors
  ####
  cat("\n\n","Mean of permutation losses for the predictors","\n")
  print(x$simp_m)
  invisible(x)
}
#' @export
plot.RePrInDT <- function(x, ...){
  ## Plot method for "RePrInDT" objects.
  ## First, it plots every tree whose validation balanced accuracy (entry 5
  ## of acc1st) equals the maximum.
  ## Then, it draws a horizontal bar chart of the positive mean permutation
  ## losses of the predictors, normed by their maximum.
  ## Returns 'x' invisibly.
  ####
  ## plot best trees
  ####
  n_small <- dim(x$acc1st)[2]
  ba_val <- x$acc1st[, , 5]
  ## combinations are numbered row-wise (psmall index varies fastest),
  ## matching the order of x$treesb
  best <- which(as.vector(t(ba_val == max(ba_val))))
  n_best <- length(best)
  for (k in best) {
    i <- (k - 1) %/% n_small + 1
    j <- (k - 1) %% n_small + 1
    # balanced accuracy rounded for the title only; 'x' is not modified
    v <- toString(c(x$acc1st[i, j, 1:2], round(x$acc1st[i, j, 5], 4)))
    plot(x$treesb[[k]],
         main = paste0("One of the ", n_best,
                       " best trees: ( plarge,  psmall,  bal.acc. ) = (", v, ")"))
  }
  ####
  ## ranking of predictors
  ####
  cat("\n")
  pos <- x$simp_m[!is.na(x$simp_m) & x$simp_m > 0]
  # barplot() fails on an empty vector, so skip it when no loss is positive
  if (length(pos) > 0) {
    barplot(sort(pos / max(pos)), horiz = TRUE,
            main = "Normed means of permutation losses")
  }
  invisible(x)
}

Try the PrInDT package in your browser

Any scripts or data that you put into this service are public.

PrInDT documentation built on Sept. 11, 2025, 5:11 p.m.