R/helpers.r

Defines functions guess_distribution convert_strata warnNoVariation convertY check_var_type check_sanity get_var_names check_offset check_weights check_interaction_depth checkMissing check_if_natural_number check_cv_parameters check_if_gbm_var_container check_if_gbm_train_params check_if_gbm_fit check_if_gbm_data check_if_gbm_dist assertInherits

# Series of internal functions used 
# to check inputs and convert parameters

assertInherits <- function(object, class.name) {
  # Stop with a uniform message unless `object` inherits from `class.name`.
  # Returns NULL invisibly when the check passes.
  is_instance <- isTRUE(inherits(object, class.name))
  if (is_instance) {
    return(invisible(NULL))
  }
  stop("Function requires a ", class.name, " object.")
}

##### Check gbm objects #####
check_if_gbm_dist <- function(distribution_obj) {
  # Stop unless distribution_obj carries a class named "<Name>GBMDist" for
  # one of the names returned by available_distributions().
  accepted_classes <- paste0(available_distributions(), "GBMDist")
  if (inherits(distribution_obj, accepted_classes)) {
    return(invisible(NULL))
  }
  stop("Function requires a specific GBMDist object.")
}

check_if_gbm_data <- function(data_obj) {
  # Stop unless data_obj inherits from "GBMData".
  assertInherits(object = data_obj, class.name = "GBMData")
}

check_if_gbm_fit <- function(fit_obj) {
  # Stop unless fit_obj inherits from "GBMFit".
  assertInherits(object = fit_obj, class.name = "GBMFit")
}

check_if_gbm_train_params <- function(params_obj) {
  # Stop unless params_obj inherits from "GBMTrainParams".
  assertInherits(object = params_obj, class.name = "GBMTrainParams")
}

check_if_gbm_var_container <- function(var_obj) {
  # Stop unless var_obj inherits from "GBMVarCont".
  assertInherits(object = var_obj, class.name = "GBMVarCont")
}

#### Check function inputs ####
check_cv_parameters <- function(cv_folds, cv_class_stratify, fold_id, train_params) {
  # Validate the cross-validation settings.
  #   cv_folds          - number of folds; must be a positive whole number
  #   cv_class_stratify - must be logical
  #   fold_id           - NULL, or fold assignments aligned element-wise with
  #                       train_params$id
  #   train_params      - a GBMTrainParams object whose `id` element groups
  #                       rows into observations
  # Returns NULL invisibly; stops on the first failed check.
  check_if_natural_number(cv_folds, "cv_folds")
  if (!is.logical(cv_class_stratify)) stop("cv_class_stratify must be a logical")
  check_if_gbm_train_params(train_params)

  # Check fold_id does not split an observation's rows across folds
  if (!is.null(fold_id)) {
    ids <- train_params$id
    # Each distinct id is checked once. which() keeps rows whose id is NA out
    # of every other id's group: `ids == id` is NA there, and indexing with NA
    # would add an NA fold that looks like a split.
    for (id in unique(ids)) {
      if (length(unique(fold_id[which(ids == id)])) > 1) {
        stop("Observations are split across multiple folds")
      }
    }
  }
  invisible(NULL)
}

check_if_natural_number <- function(value, name) {
  # Stop unless `value` is a single finite whole number >= 1.
  #   value - the value of the parameter to check
  #   name  - the parameter's name, used in the error message
  # Returns NULL invisibly.
  #
  # Type and length are tested first so that each later test sees a single
  # number; `||` with a longer operand is itself an error in R >= 4.3, and
  # non-numeric input would otherwise fail inside abs().
  tolerance <- .Machine$double.eps^0.5
  is_natural <- is.numeric(value) &&
    length(value) == 1L &&
    is.finite(value) &&
    abs(value - round(value)) < tolerance &&
    # round() rather than `value > 0`, so tiny positives like 1e-10 that are
    # "whole" within tolerance are not accepted as natural numbers.
    round(value) >= 1
  if (!is_natural) {
    stop("The parameter  ", name, " must be a positive whole number")
  }
  invisible(NULL)
}

checkMissing <- function(x, y){
  # Stop if the predictors contain NaN, the response contains missing values,
  # or any predictor column is entirely missing. Returns NULL invisibly.
  #   x - predictors (data frame or matrix)
  #   y - response
  nms <- get_var_names(x)

  # Work column by column. apply() would first coerce a data frame to a
  # matrix, which is character as soon as one column is a factor/character;
  # NaN in a numeric column then no longer registers with is.nan().
  columns <- if (is.data.frame(x)) {
    as.list(x)
  } else {
    lapply(seq_len(ncol(x)), function(j) x[, j])
  }

  #### Check for NaNs in x and NAs in response
  has_nan <- vapply(columns, function(z) is.numeric(z) && any(is.nan(z)), logical(1))
  if (any(has_nan)) {
    stop("Use NA for missing values. NaN found in predictor variables:",
         paste(nms[has_nan], collapse = ","))
  }

  if (any(is.na(y))) stop("Missing values are not allowed in the response")

  all_miss <- vapply(columns, function(z) all(is.na(z)), logical(1))
  if (any(all_miss)) {
    stop("variable(s) ", paste(which(all_miss), collapse = ', '), ": ",
         paste(nms[which(all_miss)], collapse = ', '),
         " contain only missing values.")
  }

  invisible(NULL)
}

check_interaction_depth <- function(id) {
  # Check for disallowed interaction.depth: values below 1 or above 49 stop.
  # Returns id invisibly when it is acceptable.
  too_shallow <- id < 1
  if (too_shallow) {
    stop("interaction_depth must be at least 1.")
  }
  too_deep <- id > 49
  if (too_deep) {
    stop("interaction_depth must be less than 50. You should also ask yourself why you want such large interaction terms. A value between 1 and 5 should be sufficient for most applications.")
  }
  invisible(id)
}

check_weights <- function(w, n){
  # Logical checks on weights.
  #   w - observation weights; NULL or zero-length means "unweighted"
  #   n - number of observations, used only for the default
  # Returns w unchanged, or rep(1, n) when no weights were supplied.
  if (length(w) == 0) {
    return(rep(1, n))
  }
  # Test the type directly: comparing w with as.double(w) lets character
  # weights such as "2" through, because "2" == 2 after coercion.
  if (!is.numeric(w)) {
    stop("weights must be doubles")
  }
  # Without this, NA would reach `if (any(w < 0))` and fail with an opaque
  # "missing value where TRUE/FALSE needed" error.
  if (anyNA(w)) {
    stop("weights can not contain NA's")
  }
  if (any(w < 0)) {
    stop("negative weights not allowed")
  }
  w
}

check_offset <- function(o, y, dist){
  # Validate (or default) the offset.
  #   o    - offset; NULL means "no offset" and yields one zero per y value
  #   y    - response, whose length the offset must match
  #   dist - distribution object; for "CoxPH" a length mismatch is tolerated
  # Returns the (possibly defaulted) offset.
  if (is.null(o)) {
    return(rep(0, length(y)))
  }
  wrong_length <- length(o) != length(y)
  if (wrong_length && distribution_name(dist) != "CoxPH") {
    stop("The length of offset does not equal the length of y.")
  }
  if (!is.numeric(o)) {
    stop("offset must be numeric")
  }
  if (anyNA(o)) {
    stop("offset can not contain NA's")
  }
  o
}

get_var_names <- function(x){
  # Variable names for a predictor set:
  #   matrix     -> colnames(x) (NULL when the matrix has none)
  #   data frame -> names(x)
  #   otherwise  -> generated "X1", "X2", ... one per column
  if (is.matrix(x)) {
    colnames(x)
  } else if (is.data.frame(x)) {
    names(x)
  } else {
    # seq_len() so zero columns give character(0) instead of c("X1", "X0").
    paste0("X", seq_len(ncol(x)))
  }
}

check_sanity <- function(x, y){
  # Stop unless x has one row per response value. A response with dimensions
  # (e.g. a matrix) is measured by its row count, otherwise by its length.
  n_y <- if (is.null(dim(y))) length(y) else nrow(y)
  if (nrow(x) == n_y) {
    return(invisible(NULL))
  }
  stop("The number of rows in x does not equal the length of y.")
}

check_var_type <- function(x, y){
  # Stop if any predictor is a factor with more than 1024 levels, or is not
  # numeric, ordered or factor. Returns NULL invisibly.
  #   x - predictors; vapply() walks its elements, so one element per variable
  #       is expected (NOTE(review): presumably always a data frame here)
  #   y - unused; kept so existing callers keep working
  nms <- get_var_names(x)

  # Excessive factors
  n_levels <- vapply(x, nlevels, 0L)
  excess <- which(n_levels > 1024)
  if (length(excess) > 0) {
    # Level counts are collapsed like the indices and names; previously
    # several counts were glued together with no separator.
    stop("gbm does not currently handle categorical variables with more than 1024 levels. Variable ",
         paste(excess, collapse = ', '), ": ",
         paste(nms[excess], collapse = ', '), " has ",
         paste(n_levels[excess], collapse = ', '), " levels.")
  }

  # Not an acceptable class
  bad_class <- which(vapply(x, function(col) {
    !(is.ordered(col) || is.factor(col) || is.numeric(col))
  }, TRUE))
  if (length(bad_class) > 0) {
    stop("variable ", paste(bad_class, collapse = ', '), ": ",
         paste(nms[bad_class], collapse = ', '),
         " is not of type - numeric, ordered or factor.")
  }

  invisible(NULL)
}

#### Data conversion functions ####
convertY <- function(y){
  # A two-level factor becomes a numeric 0/1 indicator of its second level;
  # any other response is returned unchanged.
  is_binary_factor <- is.factor(y) && nlevels(y) == 2
  if (!is_binary_factor) {
    return(y)
  }
  second_level <- levels(y)[2]
  as.numeric(y == second_level)
}

## miscellaneous small internal helper functions.

## these are not exported and not formally documented in the manual

## Warn if a variable does not vary
##
## @param x a variable to check (anything range() accepts)
## @param ind an index to use in a potential warning
## @param name a name to include in a potential warning
## @return NULL, invisibly

warnNoVariation <- function(x, ind, name) {
  # For a numeric x, range() warns and returns c(Inf, -Inf) when nothing is
  # left after dropping NAs; those warnings are muffled here.
  lims <- suppressWarnings(range(x, na.rm = TRUE))
  lowest <- lims[[1]]
  highest <- lims[[2]]

  # ">=" is deliberate: it is TRUE for a constant variable and also for the
  # all-NA case, where lowest is Inf and highest is -Inf.
  if (lowest >= highest) {
    warning("variable ", ind, ": ", name, " has no variation.")
  }

  invisible(NULL)
}

convert_strata <- function(strata) {
  # Normalise and validate a strata specification.
  # Factors are replaced by their integer level codes. If the first element
  # is NA, strata is returned without further checks; otherwise every element
  # must be a whole number in an atomic, non-character vector.
  # Returns the (possibly converted) strata.
  if (is.factor(strata)) {
    strata <- as.integer(strata)
  }

  # If it isn't default then check
  if (!is.na(strata[1])) {
    # The whole-number test does not also accept `strata == as.factor(strata)`:
    # that compares each value with its own label and is TRUE for any numeric
    # input (e.g. 1.5). isTRUE() turns an NA anywhere into a clean error.
    if (!is.atomic(strata) || !is.vector(strata) || is.character(strata) ||
        any(is.infinite(strata)) || any(is.nan(strata)) ||
        !isTRUE(all(strata == as.integer(strata)))) {
      stop("strata must be an atomic vector of factors or integers")
    }
  }

  strata
}

guess_distribution <- function(response) {
  # Guess the distribution when the caller did not provide one:
  #   "Surv" response              -> CoxPH
  #   exactly two distinct values  -> Bernoulli
  #   anything else                -> Gaussian
  # Emits a message naming the guess and returns list(name = <guess>).
  # The class test comes first so a Surv response is never classified by its
  # count of unique values; unique() on a Surv object need not count
  # observations, and two distinct values would otherwise mean Bernoulli.
  name <- if (inherits(response, "Surv")) {
    "CoxPH"
  } else if (length(unique(response)) == 2) {
    "Bernoulli"
  } else {
    "Gaussian"
  }
  message("Distribution not specified, assuming ", name, " ...")
  list(name = name)
}
gbm-developers/gbm3 documentation built on April 28, 2024, 10:04 p.m.