# R/check.R

# Defines functions check_complete_cases check_beta_fun check_oobag_fun check_predict check_units fctr_check_levels check_new_data_fctrs check_new_data_types check_new_data_names check_pd_inputs check_orsf_inputs check_control_net check_control_cph check_var_types check_dots check_arg_is check_arg_is_integer check_arg_is_valid check_arg_lteq check_arg_gteq check_arg_lt check_arg_gt check_arg_bound check_arg_length check_arg_uni check_arg_type

#' Strict type check for a function input
#'
#' Compares `typeof(arg_value)`, and as a fallback the object's class,
#' against the accepted types. The pseudo-type 'numeric' is expanded into
#' the two storage types R reports for numbers ('double' and 'integer').
#'
#' @param arg_value the object that is to be checked
#' @param arg_name the name of the object (used for possible error message)
#' @param expected_type character vector of acceptable types or classes
#'
#' @return nothing useful; the function is called for its side effect of
#'   raising an error when `arg_value` matches none of `expected_type`.
#'
#' @noRd
check_arg_type <- function(arg_value, arg_name, expected_type){

 accepted <- expected_type

 if('numeric' %in% accepted){
  # typeof() never returns 'numeric', so swap it for the real storage types
  accepted <- c(setdiff(accepted, 'numeric'), 'double', 'integer')
 }

 actual_type <- typeof(arg_value)

 if(actual_type %in% accepted || inherits(arg_value, accepted)){
  return(invisible(NULL))
 }

 accepted_text <- paste_collapse(x = accepted, sep = ', ', last = ' or ')

 stop(
  paste0(arg_name, " should have type <", accepted_text, ">",
         " but instead has type <", actual_type, ">"),
  call. = FALSE
 )

}

#' strict checks for inputs
#'
#' Verifies that every value in `uni` belongs to `expected_uni`. Values in
#' `expected_uni` that are absent from `uni` are allowed.
#'
#' @param uni the values found in the input (e.g., its unique values)
#' @param arg_name the name of the object (used for possible error message)
#' @param expected_uni what values should `uni` be limited to?
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd
check_arg_uni <- function(uni, arg_name, expected_uni){

 uni_in_expected <- all(uni %in% expected_uni)

 if(!uni_in_expected){

  # message parts are only assembled when they are actually needed
  expected_values <- paste_collapse(x = expected_uni,
                                    sep = ', ',
                                    last = ' and ')

  invalid_values <- paste_collapse(x = setdiff(uni, expected_uni),
                                   sep = ', ',
                                   last = ' and ')

  error_msg <-
   paste0(arg_name, " should contain values of ", expected_values,
          " but has values of ", invalid_values)

  stop(error_msg, call. = FALSE)

 }

}


#' strict checks for inputs
#'
#' Verifies that `length(arg_value)` is one of the allowed lengths.
#'
#' @param arg_value the object that is to be checked
#' @param arg_name the name of the object (used for possible error message)
#' @param expected_length what length(s) may `arg_value` have? If several
#'   are given, any one of them is accepted.
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd
check_arg_length <- function(arg_value, arg_name, expected_length){

 arg_length <- length(arg_value)

 length_match <- arg_length %in% expected_length

 if (!length_match) {

  expected_lengths <- paste_collapse(x = expected_length,
                                     sep = ', ',
                                     last = ' or ')

  error_msg <-
   paste0(arg_name, " should have length <", expected_lengths, ">",
          " but instead has length <", arg_length, ">")

  stop(error_msg, call. = FALSE)

 }

}

#' strict checks for inputs
#'
#' Compares every element of `arg_value` with `bound` using the requested
#' relational operator. It raises an informative error for the first value
#' that fails the comparison. If no value fails but some comparisons are
#' missing (NA), it raises an error naming the first missing one.
#'
#' @param arg_value the object that is to be checked
#' @param arg_name the name of the object (used for possible error message)
#' @param bound what bounds to use for `arg_value`?
#' @param relational_operator one of 'gt' (>), 'lt' (<), 'gteq' (>=), or
#'   'lteq' (<=). The operator determines how bounds are checked.
#' @param append_to_msg a note to be added to the error message, or
#'   `NULL` (the default) for no note.
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd
check_arg_bound <- function(arg_value, arg_name, bound,
                            relational_operator,
                            append_to_msg = NULL){

 valid_operators <- c('gt', 'lt', 'gteq', 'lteq')

 # switch() returns NULL for an unknown operator, which would otherwise
 # surface as the unhelpful "attempt to apply non-function".
 if(length(relational_operator) != 1 ||
    !(relational_operator %in% valid_operators)){
  stop("relational_operator should be one of ",
       "'gt', 'lt', 'gteq', or 'lteq'",
       call. = FALSE)
 }

 .op <- switch(relational_operator,
               'gt' = `>`,
               'lt' = `<`,
               'gteq' = `>=`,
               'lteq' = `<=`)

 .lab <- switch(relational_operator,
                'gt' = ">",
                'lt' = "<",
                'gteq' = ">=",
                'lteq' = "<=")

 .neg <- switch(relational_operator,
                'gt' = "<=",
                'lt' = ">=",
                'gteq' = "<",
                'lteq' = ">")

 passes <- .op(arg_value, bound)

 # which() drops NA, so definite failures and missing comparisons are
 # tracked separately; `if` never receives an NA condition.
 fail_index <- which(!passes)
 missing_index <- which(is.na(passes))

 if(length(fail_index) == 0 && length(missing_index) == 0){
  return(invisible(NULL))
 }

 if(length(arg_value) == 1){

  error_msg <-
   paste0(arg_name, " = ", arg_value, " should be ", .lab, " ", bound)

 } else if(length(fail_index) > 0){

  error_msg <- paste0(arg_name, " should be ", .lab, " ", bound, " but has",
                      " at least one value that is ", .neg, " ", bound,
                      " (see ", arg_name, "[", fail_index[1], "])")

 } else {

  error_msg <- paste0(arg_name, " should be ", .lab, " ", bound, " but has",
                      " at least one missing value",
                      " (see ", arg_name, "[", missing_index[1], "])")

 }

 if(!is.null(append_to_msg)){

  error_msg <- paste(error_msg, append_to_msg)

 }

 stop(error_msg, call. = FALSE)

}

#' strict checks for inputs
#'
#' Check that an argument is strictly greater than a bound. This is a thin
#' wrapper around `check_arg_bound()` with `relational_operator = 'gt'`.
#'
#' @inheritParams check_arg_bound
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd
check_arg_gt <- function(arg_value, arg_name, bound, append_to_msg = NULL){
 check_arg_bound(arg_value = arg_value,
                 arg_name = arg_name,
                 bound = bound,
                 relational_operator = 'gt',
                 append_to_msg = append_to_msg)
}

#' strict checks for inputs
#'
#' Check that an argument is strictly less than a bound. This is a thin
#' wrapper around `check_arg_bound()` with `relational_operator = 'lt'`.
#'
#' @inheritParams check_arg_bound
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd
check_arg_lt <- function(arg_value, arg_name, bound, append_to_msg = NULL){
 check_arg_bound(arg_value = arg_value,
                 arg_name = arg_name,
                 bound = bound,
                 relational_operator = 'lt',
                 append_to_msg = append_to_msg)
}

#' strict checks for inputs
#'
#' Check that an argument is greater than or equal to a bound. This is a
#' thin wrapper around `check_arg_bound()` with
#' `relational_operator = 'gteq'`.
#'
#' @inheritParams check_arg_bound
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd
check_arg_gteq <- function(arg_value, arg_name, bound, append_to_msg = NULL){
 check_arg_bound(arg_value = arg_value,
                 arg_name = arg_name,
                 bound = bound,
                 relational_operator = 'gteq',
                 append_to_msg = append_to_msg)
}

#' strict checks for inputs
#'
#' Check that an argument is less than or equal to a bound. This is a thin
#' wrapper around `check_arg_bound()` with `relational_operator = 'lteq'`.
#'
#' @inheritParams check_arg_bound
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd
check_arg_lteq <- function(arg_value, arg_name, bound, append_to_msg = NULL){
 check_arg_bound(arg_value = arg_value,
                 arg_name = arg_name,
                 bound = bound,
                 relational_operator = 'lteq',
                 append_to_msg = append_to_msg)
}

#' strict checks for inputs
#'
#' Verifies that every value of `arg_value` is one of `valid_options`.
#' Zero-length input is treated as invalid.
#'
#' @param arg_value the object that is to be checked
#' @param arg_name the name of the object (used for possible error message)
#' @param valid_options what are the valid inputs for `arg_value`?
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd
check_arg_is_valid <- function(arg_value, arg_name, valid_options) {

 # all() keeps the `if` condition scalar when arg_value has several
 # values (a length > 1 condition is an error in R >= 4.2). The length
 # guard stops zero-length input from passing vacuously.
 valid_arg <- length(arg_value) > 0 && all(arg_value %in% valid_options)

 if (!valid_arg) {

  expected_values <- paste_collapse(x = valid_options,
                                    sep = ', ',
                                    last = ' or ')

  arg_values <- if (length(arg_value) == 0) {
   ""
  } else {
   paste_collapse(x = arg_value,
                  sep = ', ',
                  last = ' or ')
  }

  error_msg <- paste0(
   arg_name, " should be <", expected_values, ">",
   " but is instead <", arg_values, ">"
  )

  stop(error_msg, call. = FALSE)

 }

}

#' strict checks for inputs
#'
#' make sure the user has supplied an integer valued input, i.e. every
#' value is finite, whole, and within the range of R's integer type.
#'
#' @param arg_name the name of the object (used for possible error message)
#' @param arg_value the object that is to be checked
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd
check_arg_is_integer <- function(arg_name, arg_value){

 # Test integer-ness directly instead of via as.integer(x) == x. That
 # comparison is NA for NA, NaN, Inf and values outside the integer range
 # (after a coercion warning), which made `if` fail cryptically. is.finite()
 # comes first so that `FALSE & NA` collapses to FALSE and the mask never
 # contains NA.
 not_integer <- !(is.finite(arg_value) &
                   arg_value == round(arg_value) &
                   abs(arg_value) <= .Machine$integer.max)

 if(any(not_integer)){

  if(length(arg_value) == 1){
   error_msg <- paste0(arg_name, " should be an integer value",
                       " but instead has a value of ", arg_value)
  } else {

   first_offense <- which(not_integer)[1]

   error_msg <- paste0(arg_name, " should contain only integer values",
                       " but has at least one double value",
                       " (see ", arg_name, "[", first_offense, "])")

  }

  stop(error_msg, call. = FALSE)

 }

}

#' Strict class check for a function input
#'
#' Passes when `arg_value` inherits from at least one of `expected_class`;
#' otherwise raises an error listing both the expected and actual classes.
#'
#' @param arg_value the object that is to be checked
#' @param arg_name the name of the object (used for possible error message)
#' @param expected_class what class should `arg_value` inherit from?
#'
#' @return nothing useful; called for its side effect of raising an error
#'   when `arg_value` does not inherit from `expected_class`.
#'
#' @noRd
check_arg_is <- function(arg_value, arg_name, expected_class){

 if (inherits(arg_value, expected_class)) {
  return(invisible(NULL))
 }

 wanted <- paste_collapse(x = expected_class, sep = ', ', last = ' or ')
 actual <- paste_collapse(x = class(arg_value), sep = ', ', last = ' or ')

 stop(
  paste0(arg_name, " should inherit from class <", wanted, ">",
         " but instead inherits from <", actual, ">"),
  call. = FALSE
 )

}


#' check mis-typed arguments
#'
#' Every element of `.dots` is treated as unrecognized. For each named
#' element, the formal argument of `.f` with the smallest edit distance
#' (insertions and deletions cost 1, substitutions cost 2) is suggested.
#' Unnamed elements are reported by position.
#'
#' @param .dots list of arguments captured from `...` in a call to `.f`
#' @param .f the function being called
#'
#' @return an error if `.dots` is non-empty; nothing otherwise.
#' @noRd
check_dots <- function(.dots, .f){

 if(!is_empty(.dots)){

  .args <- setdiff(names(formals(.f)), '...')

  .names <- names(.dots)

  # names() is NULL when no element of `...` was named; without this the
  # loop below would not run and the error would list nothing.
  if(is.null(.names)) .names <- rep('', length(.dots))

  .msgs <- vector(mode = 'character', length = length(.names))

  for(i in seq_along(.names)){

   if(is.na(.names[i]) || .names[i] == ''){
    .msgs[i] <- paste0('  unnamed argument in position ', i,
                       ' is unrecognized')
    next
   }

   # with no formals to compare against there is nothing to suggest
   if(is_empty(.args)){
    .msgs[i] <- paste0('  ', .names[i], ' is unrecognized')
    next
   }

   .match_distances <- utils::adist(x = .names[i],
                                    y = .args,
                                    fixed = TRUE,
                                    costs = c(ins = 1,
                                              del = 1,
                                              sub = 2))

   .match_index <- which.min(.match_distances)

   .msgs[i] <- paste0('  ', .names[i],
                      ' is unrecognized - did you mean ',
                      .args[.match_index], '?')
  }

  stop("there were unrecognized arguments:\n",
       paste(.msgs, collapse = '\n'),
       call. = FALSE)

 }

}



#' Check variable types
#'
#' orsf() should only be run with certain types of variables. This function
#'   looks up the primary (i.e., first) class of each named column in
#'   `data` and raises an error listing every column whose primary class
#'   is not one of `valid_types`.
#'
#' @param data data frame with variables to be checked
#' @param .names names of variables in `data` to check
#' @param valid_types any of these types are okay.
#'
#' @return an unnamed character vector with the primary class of each
#'   variable in `.names` (same order), returned only when every class is
#'   valid.
#'
#' @noRd
check_var_types <- function(data, .names, valid_types){

 var_types <- vapply(.names,
                     function(.name) class(data[[.name]])[1],
                     FUN.VALUE = character(1),
                     USE.NAMES = FALSE)

 is_bad <- !(var_types %in% valid_types)

 if(any(is_bad)){

  offenders <- paste0(' <', .names[is_bad], '> has type <',
                      var_types[is_bad], '>', collapse = '\n')

  stop(paste0("some variables have unsupported type:\n",
              offenders, '\nsupported types are ',
              paste_collapse(valid_types, last = ' and ')),
       call. = FALSE)

 }

 var_types

}

#' Check inputs for orsf_control_cph()
#'
#' Each argument is only checked when it is non-NULL. When both `do_scale`
#' and `iter_max` are supplied, `do_scale` must be TRUE if `iter_max > 1`.
#'
#' @inheritParams orsf_control_cph
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd
#'
check_control_cph <- function(method = NULL,
                              eps = NULL,
                              iter_max = NULL,
                              do_scale = NULL){


 if(!is.null(method)){

  check_arg_type(arg_value = method,
                 arg_name = 'method',
                 expected_type = 'character')

  # method must be a single string, mirroring the checks applied to other
  # option-style inputs before their values are validated.
  check_arg_length(arg_value = method,
                   arg_name = 'method',
                   expected_length = 1)

  check_arg_is_valid(arg_value = method,
                     arg_name = 'method',
                     valid_options = c("breslow", "efron"))

 }

 if(!is.null(eps)){
  check_arg_type(arg_value = eps,
                 arg_name = 'eps',
                 expected_type = 'numeric')

  check_arg_gt(arg_value = eps,
               arg_name = 'eps',
               bound = 0)

  check_arg_length(arg_value = eps,
                   arg_name = 'eps',
                   expected_length = 1)
 }

 if(!is.null(iter_max)){
  check_arg_type(arg_value = iter_max,
                 arg_name = 'iter_max',
                 expected_type = 'numeric')

  check_arg_is_integer(arg_value = iter_max,
                       arg_name = 'iter_max')

  check_arg_gteq(arg_value = iter_max,
                 arg_name = 'iter_max',
                 bound = 1)

  check_arg_length(arg_value = iter_max,
                   arg_name = 'iter_max',
                   expected_length = 1)
 }

 if(!is.null(do_scale)){
  check_arg_type(arg_value = do_scale,
                 arg_name = 'do_scale',
                 expected_type = 'logical')

  check_arg_length(arg_value = do_scale,
                   arg_name = 'do_scale',
                   expected_length = 1)

  if(!is.null(iter_max)){

   if(!do_scale && iter_max > 1){
    stop("do_scale must be TRUE when iter_max > 1",
         call. = FALSE)
   }

  }

 }

}


#' Check inputs for orsf_control_net()
#'
#' `alpha` must be a single number in [0, 1]. `df_target`, when supplied,
#' must be a single whole number that is at least 1.
#'
#' @inheritParams orsf_control_net
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd
#'
check_control_net <- function(alpha, df_target){

 check_arg_type(arg_value = alpha,
                arg_name = 'alpha',
                expected_type = 'numeric')

 check_arg_gteq(arg_value = alpha,
                arg_name = 'alpha',
                bound = 0)

 check_arg_lteq(arg_value = alpha,
                arg_name = 'alpha',
                bound = 1)

 check_arg_length(arg_value = alpha,
                  arg_name = 'alpha',
                  expected_length = 1)

 if(!is.null(df_target)){

  check_arg_type(arg_value = df_target,
                 arg_name = 'df_target',
                 expected_type = 'numeric')

  check_arg_is_integer(arg_value = df_target,
                       arg_name = 'df_target')

  # NOTE(review): assumes df_target is a single count (>= 1), like the
  # other integer-valued control inputs; confirm against the
  # orsf_control_net() documentation.
  check_arg_gteq(arg_value = df_target,
                 arg_name = 'df_target',
                 bound = 1)

  check_arg_length(arg_value = df_target,
                   arg_name = 'df_target',
                   expected_length = 1)

 }

}

#' Check inputs for orsf()
#'
#' Every argument is optional here: an argument is only checked when it is
#' non-NULL. Checks run in the order the arguments are listed, so the first
#' problem found is the one reported.
#'
#' @inheritParams orsf
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd
#'
check_orsf_inputs <- function(data = NULL,
                              formula = NULL,
                              control = NULL,
                              weights = NULL,
                              n_tree = NULL,
                              n_split = NULL,
                              n_retry = NULL,
                              n_thread = NULL,
                              mtry = NULL,
                              sample_with_replacement = NULL,
                              sample_fraction = NULL,
                              leaf_min_events = NULL,
                              leaf_min_obs = NULL,
                              split_rule = NULL,
                              split_min_events = NULL,
                              split_min_obs = NULL,
                              split_min_stat = NULL,
                              oobag_pred_type = NULL,
                              oobag_pred_horizon = NULL,
                              oobag_eval_every = NULL,
                              importance = NULL,
                              tree_seeds = NULL,
                              attach_data = NULL,
                              verbose_progress = NULL){

 # Helpers for check patterns shared by several arguments. Each helper
 # does nothing when its input is NULL (i.e., the argument was not given).

 # count-like inputs: a single whole number that is >= `bound`
 check_count <- function(arg_value, arg_name, bound){

  if(is.null(arg_value)) return(invisible(NULL))

  check_arg_type(arg_value = arg_value,
                 arg_name = arg_name,
                 expected_type = 'numeric')

  check_arg_is_integer(arg_value = arg_value,
                       arg_name = arg_name)

  check_arg_gteq(arg_value = arg_value,
                 arg_name = arg_name,
                 bound = bound)

  check_arg_length(arg_value = arg_value,
                   arg_name = arg_name,
                   expected_length = 1)

 }

 # flag inputs: a single logical value
 check_flag <- function(arg_value, arg_name){

  if(is.null(arg_value)) return(invisible(NULL))

  check_arg_type(arg_value = arg_value,
                 arg_name = arg_name,
                 expected_type = 'logical')

  check_arg_length(arg_value = arg_value,
                   arg_name = arg_name,
                   expected_length = 1)

 }

 # option inputs: a single string taken from `valid_options`
 check_option <- function(arg_value, arg_name, valid_options){

  if(is.null(arg_value)) return(invisible(NULL))

  check_arg_type(arg_value = arg_value,
                 arg_name = arg_name,
                 expected_type = 'character')

  check_arg_length(arg_value = arg_value,
                   arg_name = arg_name,
                   expected_length = 1)

  check_arg_is_valid(arg_value = arg_value,
                     arg_name = arg_name,
                     valid_options = valid_options)

 }

 if(!is.null(data)){

  check_arg_is(arg_value = data,
               arg_name = 'data',
               expected_class = 'data.frame')

  # Minimum event numbers are checked later.
  # Also, later we check to make sure there are at least 2 columns.
  # We specify ncol > 0 here to make the error message that users will
  # receive more specific.
  if(nrow(data) == 0 || ncol(data) ==  0){
   stop("training data are empty",
        call. = FALSE)
  }

  # check for blanks first b/c the check for non-standard symbols
  # will detect blanks with >1 empty characters

  blank_names <- grepl(pattern = '^\\s*$',
                       x = names(data))

  if(any(blank_names)){

   to_list <- which(blank_names)

   s_if_plural_blank_otherwise <- if(length(to_list) > 1) "s" else ""

   last <- if(length(to_list) == 2) ' and ' else ', and '

   stop("Blank or empty names detected in training data: see column",
        s_if_plural_blank_otherwise, " ",
        paste_collapse(x = to_list, last = last),
        call. = FALSE)

  }

  ns_names <- grepl(pattern = '[^a-zA-Z0-9\\.\\_]+',
                    x = names(data))

  if(any(ns_names)){

   last <- if(sum(ns_names) == 2) ' and ' else ', and '

   stop("Non-standard names detected in training data: ",
        paste_collapse(x = names(data)[ns_names],
                       last = last),
        call. = FALSE)

  }

 }


 if(!is.null(formula)){

  check_arg_is(arg_value = formula,
               arg_name = 'formula',
               expected_class = 'formula')

  if(length(formula) != 3){
   stop("formula must be two sided, i.e. left side ~ right side",
        call. = FALSE)
  }

  # deparse() may split a long right-hand side over several lines, so the
  # pieces are pasted back together before searching for symbols.
  formula_deparsed <- paste(deparse(formula[[3]]), collapse = ' ')

  for(symbol in c("*", "^", ":", "(", ")", "[", "]", "|", "%")){

   if(grepl(symbol, formula_deparsed, fixed = TRUE)){

    stop("unrecognized symbol in formula: ", symbol,
         "\norsf recognizes '+', '-', and '.' symbols.",
         call. = FALSE)

   }

  }

 }

 if(!is.null(control)){

  check_arg_is(arg_value = control,
               arg_name = 'control',
               expected_class = 'orsf_control')

 }

 if(!is.null(weights)){

  check_arg_type(arg_value = weights,
                 arg_name = 'weights',
                 expected_type = 'numeric')

  check_arg_gteq(arg_value = weights,
                 arg_name = 'weights',
                 bound = 0)

  # the length of weights can only be validated when data were supplied
  if(!is.null(data) && length(weights) != nrow(data)){

   stop('weights should have length <', nrow(data),
        "> (the number of observations in data)",
        " but instead has length <", length(weights), ">",
        call. = FALSE)

  }

 }

 check_count(n_tree, 'n_tree', bound = 1)

 check_count(n_split, 'n_split', bound = 1)

 check_count(n_retry, 'n_retry', bound = 0)

 check_count(n_thread, 'n_thread', bound = 0)

 check_count(mtry, 'mtry', bound = 1)

 check_flag(sample_with_replacement, 'sample_with_replacement')

 if(!is.null(sample_fraction)){

  check_arg_type(arg_value = sample_fraction,
                 arg_name = 'sample_fraction',
                 expected_type = 'numeric')

  check_arg_gt(arg_value = sample_fraction,
               arg_name = 'sample_fraction',
               bound = 0)

  check_arg_lteq(arg_value = sample_fraction,
                 arg_name = 'sample_fraction',
                 bound = 1)

  check_arg_length(arg_value = sample_fraction,
                   arg_name = 'sample_fraction',
                   expected_length = 1)

 }

 check_count(leaf_min_events, 'leaf_min_events', bound = 1)

 check_count(leaf_min_obs, 'leaf_min_obs', bound = 1)

 check_option(split_rule, 'split_rule',
              valid_options = c("logrank", "cstat"))

 check_count(split_min_events, 'split_min_events', bound = 1)

 check_count(split_min_obs, 'split_min_obs', bound = 1)

 if(!is.null(split_min_stat)){

  check_arg_type(arg_value = split_min_stat,
                 arg_name = 'split_min_stat',
                 expected_type = 'numeric')

  check_arg_gteq(arg_value = split_min_stat,
                 arg_name = 'split_min_stat',
                 bound = 0)

  check_arg_length(arg_value = split_min_stat,
                   arg_name = 'split_min_stat',
                   expected_length = 1)

 }

 check_option(oobag_pred_type, 'oobag_pred_type',
              valid_options = c("none",
                                "surv",
                                "risk",
                                "chf",
                                "mort",
                                "leaf"))

 if(!is.null(oobag_pred_horizon)){

  check_arg_type(arg_value = oobag_pred_horizon,
                 arg_name = 'oobag_pred_horizon',
                 expected_type = 'numeric')

  # no length check: several horizons may be supplied, and each one is
  # checked on its own so the error message shows the offending value.
  for(i in seq_along(oobag_pred_horizon)){

   check_arg_gteq(arg_value = oobag_pred_horizon[i],
                  arg_name = 'oobag_pred_horizon',
                  bound = 0)

  }

 }

 if(!is.null(oobag_eval_every)){

  check_arg_type(arg_value = oobag_eval_every,
                 arg_name = 'oobag_eval_every',
                 expected_type = 'numeric')

  check_arg_is_integer(arg_value = oobag_eval_every,
                       arg_name = 'oobag_eval_every')

  check_arg_gteq(arg_value = oobag_eval_every,
                 arg_name = 'oobag_eval_every',
                 bound = 1)

  check_arg_lteq(arg_value = oobag_eval_every,
                 arg_name = 'oobag_eval_every',
                 bound = n_tree)

  check_arg_length(arg_value = oobag_eval_every,
                   arg_name = 'oobag_eval_every',
                   expected_length = 1)

 }

 check_option(importance, 'importance',
              valid_options = c("none",
                                "anova",
                                "negate",
                                "permute"))

 if(!is.null(tree_seeds)){

  check_arg_type(arg_value = tree_seeds,
                 arg_name = 'tree_seeds',
                 expected_type = 'numeric')

  check_arg_is_integer(arg_value = tree_seeds,
                       arg_name = 'tree_seeds')

  # a vector of seeds must supply one seed per tree; this can only be
  # compared when n_tree was supplied.
  if(!is.null(n_tree) &&
     length(tree_seeds) > 1 &&
     length(tree_seeds) != n_tree){

   stop('tree_seeds should have length <', n_tree,
        "> (the number of trees) but instead has length <",
        length(tree_seeds), ">", call. = FALSE)

  }

 }

 check_flag(attach_data, 'attach_data')

 check_flag(verbose_progress, 'verbose_progress')

}

#' Check inputs for orsf_pd()
#'
#' `object` is always checked. Every other input is checked only when it
#' is non-NULL. When `pred_spec` is given, its names must be predictors in
#' `object`. Unless `boundary_checks` is FALSE, numeric values in
#' `pred_spec` must also lie within the 10th-90th percentile bounds
#' returned by `get_numeric_bounds(object)`. The function finishes by
#' delegating to `check_predict()`.
#'
#' @inheritParams orsf_pd
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd
#'
check_pd_inputs <- function(object,
                            pred_spec = NULL,
                            expand_grid = NULL,
                            prob_values = NULL,
                            prob_labels = NULL,
                            oobag = NULL,
                            boundary_checks = NULL,
                            new_data = NULL,
                            pred_horizon = NULL,
                            pred_type = NULL,
                            na_action = NULL){

 # a TRUE/FALSE input must be logical and of length 1
 check_flag <- function(value, name){

  check_arg_type(arg_value = value,
                 arg_name = name,
                 expected_type = 'logical')

  check_arg_length(arg_value = value,
                   arg_name = name,
                   expected_length = 1)

 }

 # describes the out-of-range values on one side of a bound, e.g.
 # " (5 and 6 > 4) "; a single space stands in when there are none
 describe_out_of_range <- function(values, index, symbol, bound){

  if(is_empty(index)) return(" ")

  listed <- paste_collapse(round_magnitude(values[index]), last = ' and ')

  paste0(" (", listed, " ", symbol, " ", bound, ") ")

 }

 check_arg_is(arg_value = object,
              arg_name = 'object',
              expected_class = 'orsf_fit')

 if(!is.null(boundary_checks)) check_flag(boundary_checks, 'boundary_checks')

 if(!is.null(pred_spec)){

  if(is_empty(pred_spec)) stop("pred_spec is empty", call. = FALSE)

  spec_names <- names(pred_spec)

  if(is_empty(spec_names)) stop("pred_spec is unnamed", call. = FALSE)

  unknown_names <- spec_names[which(is.na(match(spec_names,
                                                get_names_x(object))))]

  if(!is_empty(unknown_names)){
   stop("variables in pred_spec are not recognized as predictors in object: ",
        paste_collapse(unknown_names, last = ' and '),
        call. = FALSE)
  }

  numeric_bounds <- get_numeric_bounds(object)
  numeric_names <- intersect(colnames(numeric_bounds), spec_names)

  # boundary checks are on unless explicitly turned off
  do_boundary_checks <- if(is.null(boundary_checks)) TRUE else boundary_checks

  if(!is_empty(numeric_names) && do_boundary_checks){

   for(.name in numeric_names){

    spec_values <- pred_spec[[.name]]
    upper <- numeric_bounds['90%', .name]
    lower <- numeric_bounds['10%', .name]

    too_high <- which(spec_values > upper)
    too_low <- which(spec_values < lower)

    if(!is_empty(too_high) || !is_empty(too_low)){
     stop("Some values for ",
          .name,
          " in pred_spec are above",
          describe_out_of_range(spec_values, too_high, ">", upper),
          "or below",
          describe_out_of_range(spec_values, too_low, "<", lower),
          "90th or 10th percentiles in training data.",
          " Change pred_spec or set boundary_checks = FALSE",
          " to prevent this error",
          call. = FALSE)
    }

   }

  }

 }

 if(!is.null(expand_grid)) check_flag(expand_grid, 'expand_grid')

 if(!is.null(prob_values)){

  check_arg_type(arg_value = prob_values,
                 arg_name = 'prob_values',
                 expected_type = 'numeric')

  check_arg_gteq(arg_value = prob_values,
                 arg_name = 'prob_values',
                 bound = 0)

  check_arg_lteq(arg_value = prob_values,
                 arg_name = 'prob_values',
                 bound = 1)

 }

 if(!is.null(prob_labels)){

  check_arg_type(arg_value = prob_labels,
                 arg_name = 'prob_labels',
                 expected_type = 'character')

 }

 if(!is.null(prob_values) &&
    !is.null(prob_labels) &&
    length(prob_values) != length(prob_labels)){
  stop("prob_values and prob_labels must have the same length.",
       call. = FALSE)
 }

 if(!is.null(oobag)) check_flag(oobag, 'oobag')

 check_predict(object = object,
               new_data = new_data,
               pred_horizon = pred_horizon,
               pred_type = pred_type,
               na_action = na_action,
               valid_pred_types = c("risk", "surv", "chf", "mort"))

}

#' New data have same names as reference data
#'
#' Compares the column names of `new_data` with `ref_names` in one or both
#'   directions and raises a single error that lists every mismatch.
#'
#' @param new_data data.frame to check
#' @param ref_names character vector of names from reference data
#' @param label_new what to call the new data if an error is printed
#' @param label_ref what to call the reference data if an error is printed.
#' @param check_new_in_ref T/F; make sure all new names are in reference data?
#' @param check_ref_in_new T/F; make sure all reference names are in new data?
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd

check_new_data_names <- function(new_data,
                                 ref_names,
                                 label_new,
                                 label_ref,
                                 check_new_in_ref = FALSE,
                                 check_ref_in_new = TRUE){

 new_names <- names(new_data)

 # logical flags marking names absent from the other set; a disabled
 # direction of the check stays FALSE so any() is FALSE for it.
 list_new <- FALSE
 if(check_new_in_ref) list_new <- !(new_names %in% ref_names)

 list_ref <- FALSE
 if(check_ref_in_new) list_ref <- !(ref_names %in% new_names)

 error_new <- any(list_new)
 error_ref <- any(list_ref)

 if(!error_new && !error_ref) return(invisible(NULL))

 # paste0() rather than paste() so labels are not padded with extra
 # spaces, e.g. "new_data have columns not contained in training data: x"
 out_msg <- character(0)

 if(error_new){
  out_msg <- c(out_msg, paste0(
   label_new, " have columns not contained in ", label_ref, ": ",
   paste_collapse(new_names[list_new], last = ' and ')
  ))
 }

 if(error_ref){
  out_msg <- c(out_msg, paste0(
   label_ref, " have columns not contained in ", label_new, ": ",
   paste_collapse(ref_names[list_ref], last = ' and ')
  ))
 }

 stop(paste(out_msg, collapse = '\n Also, '), call. = FALSE)

}

#' New data have same types as reference data
#'
#' If new data have an integer vector where the ref data
#'  had a factor vector, orsf_predict() will yell! Also
#'  it is good practice to make sure users are supplying
#'  consistent data types.
#'
#' The first element of `class()` for each reference variable in
#'  `new_data` is compared with the matching element of `ref_types`.
#'
#' @inheritParams check_new_data_names
#' @param ref_types the types of variables in reference data
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd
check_new_data_types <- function(new_data,
                                 ref_names,
                                 ref_types,
                                 label_new,
                                 label_ref){

 # first class of each reference variable as it appears in new_data
 var_types <- vapply(ref_names,
                     function(.name) class(new_data[[.name]])[1],
                     FUN.VALUE = character(1),
                     USE.NAMES = FALSE)

 bad_types <- which(var_types != ref_types)

 if(is_empty(bad_types)) return(invisible(NULL))

 vars_to_list <- ref_names[bad_types]
 types_to_list <- var_types[bad_types]

 meat <- paste0('<', vars_to_list, '> has type <',
                types_to_list, '>', " in ", label_new,
                "; type <", ref_types[bad_types], "> in ",
                label_ref, collapse = '\n')

 # paste0() so that labels are not surrounded by doubled spaces
 msg <- paste0("some variables in ", label_new,
               " have different type in ",
               label_ref, ":\n", meat)

 stop(msg, call. = FALSE)

}

#' Check factor variables in new data
#'
#' Factors may have new levels in the testing data, which would
#'   certainly mess up orsf_predict(). This runs fctr_check() on
#'   `new_data`, then, for every factor column recorded in `fi_ref`,
#'   compares its levels in `new_data` with the training levels via
#'   fctr_check_levels(), so the user can fix the factor levels before
#'   calling predict.
#'
#' @param new_data data that are used for predicting risk or survival
#' @param names_x names of the x variables in training data
#' @param fi_ref factor info from training data (uses `$cols` and `$lvls`)
#' @param label_new what to call the new_data if error message is printed.
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd

check_new_data_fctrs <- function(new_data,
                                 names_x,
                                 fi_ref,
                                 label_new){

 fctr_check(new_data, names_x)

 new_levels <- fctr_info(new_data, names_x)$lvls

 for(.col in fi_ref$cols){
  fctr_check_levels(ref       = fi_ref$lvls[[.col]],
                    new       = new_levels[[.col]],
                    name      = .col,
                    label_ref = "training data",
                    label_new = label_new)
 }

}


#' check levels of individual factor
#'
#' Errors if any level in `new` is absent from `ref`, listing the
#'   unseen levels in the message.
#'
#' @param ref levels of factor in reference data
#' @param new levels of factor in new data
#' @param name name of the factor variable
#' @param label_ref what to call reference data if error message is printed.
#' @param label_new what to call new data if error message is printed.
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd

fctr_check_levels <- function(ref,
                              new,
                              name,
                              label_ref,
                              label_new){

 # levels in the new data that the reference data never had
 unseen <- new[!(new %in% ref)]

 if(length(unseen) == 0) return(invisible(NULL))

 stop(paste0("variable ", name, " in ", label_new,
             " has levels not contained in ", label_ref, ": ",
             paste_collapse(unseen, last = ' and ')),
      call. = FALSE)

}


#' check units
#'
#' Errors if a variable that carried unit attributes in the training
#'   data has none in `new_data`, or if its unit label differs.
#'
#' @param new_data new data to check units in
#'
#' @param ui_train unit information in training data; a named list whose
#'   elements have a `$label` entry.
#'
#' @return nada
#'
#' @noRd

check_units <- function(new_data, ui_train) {

 ui_new <- unit_info(data = new_data, .names = names(ui_train))

 ui_missing <- setdiff(names(ui_train), names(ui_new))

 if(!is_empty(ui_missing)){

  # describe the offending variable(s); the rest of the message is
  # shared so one or many missing variables are reported consistently
  # (both refer to "new data").
  if(length(ui_missing) == 1){
   missing_label <- ui_missing
  } else {
   missing_label <- paste0(length(ui_missing),
                           " variables (",
                           paste_collapse(ui_missing, last = ' and '),
                           ")")
  }

  stop(missing_label, " had unit attributes in training data but",
       " did not have unit attributes in new data.",
       " Please ensure that variables in new data have the same",
       " units as their counterparts in the training data.",
       call. = FALSE)

 }

 for(i in names(ui_train)){

  if(ui_train[[i]]$label != ui_new[[i]]$label){

   msg <- paste("variable", i, 'has unit', ui_train[[i]]$label,
                'in the training data but has unit', ui_new[[i]]$label,
                'in new data')

   stop(msg, call. = FALSE)

  }

 }

}

#' Run prediction checks
#'
#' The intent of this function is to protect users from common
#'   inconsistencies that can occur between training data and
#'   testing data. Factor levels in training data need to be
#'   present in the corresponding factors of the testing data,
#'   and all variables used by the model need to be present in
#'   the testing data as well. In addition, the inputs of the
#'   orsf_predict() function are checked.
#'
#' Each check is skipped when the corresponding input is `NULL`. A `NULL`
#'   `boundary_checks` is treated as `TRUE` (the default).
#'
#' @inheritParams orsf_predict
#' @param valid_pred_types character vector of allowed `pred_type` values.
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd

check_predict <- function(object,
                          new_data = NULL,
                          pred_horizon = NULL,
                          pred_type = NULL,
                          na_action = NULL,
                          boundary_checks = TRUE,
                          valid_pred_types = c("risk", "surv", "chf", "mort", "leaf")){

 if(!is.null(new_data)){

  check_arg_is(arg_value = new_data,
               arg_name = 'new_data',
               expected_class = 'data.frame')

  if(nrow(new_data) == 0 || ncol(new_data) ==  0){
   stop("new data are empty",
        call. = FALSE)
  }

  ui_train <- get_unit_info(object)

  # check unit info for new data if training data had unit variables
  if(!is_empty(ui_train)) check_units(new_data, ui_train)

  check_new_data_names(new_data  = new_data,
                       ref_names = get_names_x(object),
                       label_new = "new_data",
                       label_ref = 'training data')

  check_new_data_types(new_data  = new_data,
                       ref_names = get_names_x(object),
                       ref_types = get_types_x(object),
                       label_new = "new_data",
                       label_ref = 'training data')

  check_new_data_fctrs(new_data  = new_data,
                       names_x   = get_names_x(object),
                       fi_ref    = get_fctr_info(object),
                       label_new = "new_data")

  # NaN values are caught by is.na() elsewhere, so only Inf is checked.
  for(i in get_names_x(object)){

   if(any(is.infinite(new_data[[i]]))){
    stop("Please remove infinite values from ", i, ".",
         call. = FALSE)
   }

  }

 }

 if(!is.null(pred_type)){

  check_arg_type(arg_value = pred_type,
                 arg_name = 'pred_type',
                 expected_type = 'character')

  check_arg_length(arg_value = pred_type,
                   arg_name = 'pred_type',
                   expected_length = 1)

  check_arg_is_valid(arg_value = pred_type,
                     arg_name = 'pred_type',
                     valid_options = valid_pred_types)

 }

 if(!is.null(pred_horizon)){

  if(is.null(boundary_checks)){
   # same convention as check_pd_inputs(); without this, the
   # `if(boundary_checks)` below fails with 'argument is of length zero'
   boundary_checks <- TRUE
  } else {

   check_arg_type(arg_value = boundary_checks,
                  arg_name = 'boundary_checks',
                  expected_type = 'logical')

   check_arg_length(arg_value = boundary_checks,
                    arg_name = 'boundary_checks',
                    expected_length = 1)

  }

  check_arg_type(arg_value = pred_horizon,
                 arg_name = 'pred_horizon',
                 expected_type = 'numeric')

  check_arg_gteq(arg_value = pred_horizon,
                 arg_name = 'pred_horizon',
                 bound = 0)

  max_time <- get_max_time(object)

  if(any(pred_horizon > max_time)){

   if(boundary_checks){
    stop("prediction horizon should ",
         "be <= max follow-up time ",
         "observed in training data: ",
         max_time,
         call. = FALSE)
   }

  }

 }

 if(!is.null(na_action)){

  check_arg_type(arg_value = na_action,
                 arg_name = 'na_action',
                 expected_type = 'character')

  check_arg_length(arg_value = na_action,
                   arg_name = 'na_action',
                   expected_length = 1)

  check_arg_is_valid(arg_value = na_action,
                     arg_name = 'na_action',
                     valid_options = c("fail",
                                       "pass",
                                       "omit",
                                       "impute_meanmode"))

 }

}

#' check a user-supplied out-of-bag evaluation function
#'
#' Verifies that `oobag_fun` has exactly three arguments named
#'   `y_mat`, `w_vec`, and `s_vec` (in that order), then calls it on a
#'   small synthetic case and checks that it returns a numeric value
#'   of length 1.
#'
#' @param oobag_fun the function to check.
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd

check_oobag_fun <- function(oobag_fun){

 oobag_fun_args <- names(formals(oobag_fun))

 if(length(oobag_fun_args) != 3) stop(
  "oobag_fun should have 3 input arguments but instead has ",
  length(oobag_fun_args),
  call. = FALSE
 )

 arg_names_expected <- c('y_mat', 'w_vec', 's_vec')
 arg_names_refer <- c('first', 'second', 'third')

 # report the name the user gave in the same position that failed
 for(i in seq_along(arg_names_expected)){
  if(oobag_fun_args[i] != arg_names_expected[i]) stop(
   "the ", arg_names_refer[i], " input argument of oobag_fun ",
   "should be named '", arg_names_expected[i], "' ",
   "but is instead named '", oobag_fun_args[i], "'",
   call. = FALSE
  )
 }

 test_time <- seq(from = 1, to = 5, length.out = 100)
 test_status <- rep(c(0,1), each = 50)

 .y_mat <- cbind(time = test_time, status = test_status)
 .w_vec <- rep(1, times = 100)
 .s_vec <- seq(0.9, 0.1, length.out = 100)

 # silent = FALSE so the user also sees the underlying error
 test_output <- try(oobag_fun(y_mat = .y_mat,
                              w_vec = .w_vec,
                              s_vec = .s_vec),
                    silent = FALSE)

 if(is_error(test_output)){

  stop("oobag_fun encountered an error when it was tested. ",
       "Please make sure your oobag_fun works for this case:\n\n",
       "test_time <- seq(from = 1, to = 5, length.out = 100)\n",
       "test_status <- rep(c(0,1), each = 50)\n\n",
       "y_mat <- cbind(time = test_time, status = test_status)\n",
       "w_vec <- rep(1, times = 100)\n",
       "s_vec <- seq(0.9, 0.1, length.out = 100)\n\n",
       "test_output <- oobag_fun(y_mat = y_mat, w_vec = w_vec, s_vec = s_vec)\n\n",
       "test_output should be a numeric value of length 1",
       call. = FALSE)

 }

 if(!is.numeric(test_output)) stop(
  "oobag_fun should return a numeric output but instead returns ",
  "output of type ", class(test_output)[1],
  call. = FALSE
 )

 if(length(test_output) != 1) stop(
  "oobag_fun should return output of length 1 but instead returns ",
  "output of length ", length(test_output),
  call. = FALSE
 )

}

#' check a user-supplied function for linear combination coefficients
#'
#' Verifies that `beta_fun` has exactly three arguments named
#'   `x_node`, `y_node`, and `w_node` (in that order), then calls it on a
#'   small synthetic case and checks that the result is a one-column
#'   matrix with one row per column of `x_node`.
#'
#' @param beta_fun the function to check.
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#'
#' @noRd

check_beta_fun <- function(beta_fun){

 beta_fun_args <- names(formals(beta_fun))

 if(length(beta_fun_args) != 3) stop(
  "beta_fun should have 3 input arguments but instead has ",
  length(beta_fun_args),
  call. = FALSE
 )

 arg_names_expected <- c("x_node",
                         "y_node",
                         "w_node")

 arg_names_refer <- c('first', 'second', 'third')

 for(i in seq_along(arg_names_expected)){
  if(beta_fun_args[i] != arg_names_expected[i])
   stop(
    "the ", arg_names_refer[i], " input argument of beta_fun ",
    "should be named '", arg_names_expected[i],"' ",
    "but is instead named '", beta_fun_args[i], "'",
    call. = FALSE
   )
 }

 # synthetic node: 100 rows, 3 predictors, survival outcome, weights
 .x_node <- matrix(seq(-1, 1, length.out = 300), ncol = 3)

 test_time <- seq(from = 1, to = 5, length.out = 100)
 test_status <- rep(c(0,1), each = 50)
 .y_node <- cbind(time = test_time, status = test_status)

 .w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)

 # silent = FALSE so the user also sees the underlying error
 test_output <- try(beta_fun(.x_node, .y_node, .w_node),
                    silent = FALSE)

 if(is_error(test_output)){

  stop("beta_fun encountered an error when it was tested. ",
       "Please make sure your beta_fun works for this case:\n\n",
       ".x_node <- matrix(seq(-1, 1, length.out = 300), ncol = 3)\n\n",
       "test_time <- seq(from = 1, to = 5, length.out = 100)\n",
       "test_status <- rep(c(0,1), each = 50)\n",
       ".y_node <- cbind(time = test_time, status = test_status)\n\n",
       ".w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)\n\n",
       "test_output <- beta_fun(.x_node, .y_node, .w_node)\n\n",
       "test_output should be a numeric matrix with 1 column and",
       " with nrow(test_output) = ncol(.x_node)",
       call. = FALSE)

 }

 if(!is.matrix(test_output)) stop(
  "beta_fun should return a matrix output but instead returns ",
  "output of type ", class(test_output)[1],
  call. = FALSE
 )

 if(ncol(test_output) != 1) stop(
  "beta_fun should return a matrix with 1 column but instead ",
  "returns a matrix with ", ncol(test_output), " columns.",
  call. = FALSE
 )

 if(nrow(test_output) != ncol(.x_node)) stop(
  "beta_fun should return a matrix with 1 row for each column in x_node ",
  "but instead returns a matrix with ", nrow(test_output), " rows ",
  "in a testing case where x_node has ", ncol(.x_node), " columns",
  call. = FALSE
 )

}

#' check complete cases in new data
#'
#' Errors when rows with missing values are present and `na_action` is
#'   'fail', or when no complete rows remain at all.
#'
#' @param cc (_integer vector_) the indices of complete cases
#' @param na_action the action to be taken for missing values
#'   (see orsf_predict)
#' @param n_total the total number of rows in new_data; compared with
#'   `length(cc)` to detect incomplete rows.
#'
#' @return check functions 'return' errors and the intent is
#'   to return nothing if nothing is wrong,
#'   so hopefully nothing is returned.
#' @noRd
#'

check_complete_cases <- function(cc, na_action, n_total){

 n_complete <- length(cc)
 has_incomplete_rows <- n_complete != n_total

 if(has_incomplete_rows && na_action == 'fail'){
  stop("Please remove missing values from new_data, or impute them.",
       call. = FALSE)
 }

 if(n_complete == 0){
  stop("There are no observations in new_data with complete data ",
       "for the predictors used by this orsf object.",
       call. = FALSE)
 }

}

Try the aorsf package in your browser

Any scripts or data that you put into this service are public.

aorsf documentation built on Oct. 26, 2023, 5:08 p.m.