# R/surr_rsq_ci.R

# Defines functions print.surr_rsq_ci surr_rsq_ci

# Documented in print.surr_rsq_ci surr_rsq_ci

#' A function to calculate the interval estimate of the surrogate R-squared measure
#'
#' @description This function generates a percentile bootstrap interval estimate
#' of the surrogate R-squared. The rows of the data are resampled with
#' replacement, both the reduced and the full model are refitted on each
#' bootstrap sample, and the surrogate R-squared is recomputed. The
#' `alpha/2` and `1 - alpha/2` quantiles of the bootstrap values form the interval.
#' @param object An object of class `"surr_rsq"` that is generated by the function `"surr_rsq"`.
#' It contains the following components: `surr_rsq`, `reduced_model`, `full_model`, and `data`.
#' @param alpha The significance level alpha. The confidence level is 1-alpha.
#' Must be a single number in `[0, 1]`.
#' @param B The number of bootstrap replications. Must be a single finite
#' number of at least 1; a fractional value is truncated.
#' @param asym A logical argument indicating whether to use the asymptotic
#' version of the surrogate R-squared.
#' More details are in the paper of Liu et al. (2023).
#' @param parallel A logical argument indicating whether to run the bootstrap
#' replications in parallel. The clusters need to be registered through
#'  `registerDoParallel(cl)` beforehand.
#' @param ... Additional optional arguments. Currently unused.
#'
#' @return A list of class `"surr_rsq_ci"` with the components `surr_rsq`
#' (the first component of `object`), `surr_rsq_ci` (a data frame with the
#' rows "Percentile" and "Confidence Interval" and the columns `Lower` and
#' `Upper`), `surr_rsq_BS` (the bootstrap surrogate R-squared values, `NA` for
#' replications that failed on every attempt), `reduced_model`, `full_model`,
#' and `data`.
#'
#' @importFrom progress progress_bar
#' @importFrom stats update lm nobs quantile
#' @importFrom foreach foreach %dopar%
#' @importFrom scales percent
#'
#' @examples
#' data("RedWine")
#'
#' full_formula <- as.formula(quality ~ fixed.acidity + volatile.acidity + citric.acid
#' + residual.sugar + chlorides + free.sulfur.dioxide +
#' total.sulfur.dioxide + density + pH + sulphates + alcohol)
#'
#' fullmodel <- polr(formula = full_formula,data=RedWine, method  = "probit")
#'
#' select_model <- update(fullmodel, formula. = ". ~ . - fixed.acidity -
#' citric.acid - residual.sugar - density")
#'
#' surr_rsq_select <- surr_rsq(select_model, fullmodel, data = RedWine, avg.num = 30)
#'
#' # surr_rsq_ci(surr_rsq_select, alpha = 0.05, B = 2000, parallel = FALSE) # Not run, it takes time.
#'
#' # surr_rsq_ci(surr_rsq_select, alpha = 0.05, B = 2000, parallel = TRUE) # Not run, it takes time.
#'
#' @export
#'
surr_rsq_ci <-
  function(object,
           alpha = 0.05,
           B     = 2000,
           asym = FALSE,
           parallel = FALSE,
           ...){
    # Fail fast on unusable arguments: otherwise a bad `alpha` is only detected
    # by quantile() after every (expensive) bootstrap refit has already run,
    # and `B < 1` makes a `2:B`-style index run backwards.
    if (!is.numeric(alpha) || length(alpha) != 1L || is.na(alpha) ||
        alpha < 0 || alpha > 1) {
      stop("`alpha` must be a single number between 0 and 1.", call. = FALSE)
    }
    if (!is.numeric(B) || length(B) != 1L || !is.finite(B) || B < 1) {
      stop("`B` must be a single finite number of at least 1.", call. = FALSE)
    }
    # Number of bootstrap replications actually run.
    n_boot <- floor(B)

    # Extract components from the surr_rsq object (by position):
    # `surr_rsq`, `reduced_model`, `full_model`.
    res_s <- object[[1]]
    reduced_model <- object[[2]]
    full_model <- object[[3]]

    # Check if datasets from two model objects are the same!
    data <- checkDataSame(model = reduced_model, full_model = full_model)

    n <- nrow(data)

    # One bootstrap replicate: resample rows with replacement, refit both
    # models on the same resample, and return the surrogate R-squared.
    doit <- function(data, reduced_model, full_model, asym) {
      BS_data <- data[sample.int(n, n, replace = TRUE), ]
      # Keep the reduced and the full model on the same bootstrap dataset.
      BS_reduced_model <- update(reduced_model, data = BS_data)
      BS_full_model <- update(full_model, data = BS_data)
      suppressMessages(
        surr_rsq_val <- surr_rsq(model = BS_reduced_model,
                          full_model = BS_full_model,
                          asym = asym)$surr_rsq
      )
      surr_rsq_val
    }

    # Run one replicate, redrawing the bootstrap sample when the refit errors:
    # one initial attempt plus up to five retries. If every attempt fails the
    # replicate is recorded as NA_real_ (never as the error text), so the
    # vector of bootstrap values stays numeric.
    attempt_replicate <- function(data, reduced_model, full_model, asym,
                                  max_attempts = 6L) {
      for (attempt in seq_len(max_attempts)) {
        try_out <- try(
          doit(data, reduced_model, full_model, asym), TRUE
        )
        if (!inherits(try_out, "try-error")) {
          return(try_out)
        }
      }
      NA_real_
    }

    if (parallel) { # Use parallel for bootstrapping
      # Progress callback handed to foreach through `.options.snow`.
      # NOTE(review): presumably only a doSNOW backend honors this option;
      # with other backends the bar may stay at 0% -- confirm.
      pb <- txtProgressBar(max = n_boot, style = 3)
      on.exit(close(pb), add = TRUE)
      progress <- function(n) setTxtProgressBar(pb, n)
      opts <- list(progress = progress)

      boot_vals <-
        foreach(i = seq_len(n_boot),
                .packages = c('MASS', 'stats', 'PAsso', 'SurrogateRsq'),
                .combine='c',
                .options.snow = opts
                ) %dopar% {
                  attempt_replicate(data, reduced_model, full_model, asym)
                }
    } else { # Without using parallel.
      # Add progress bar --------------------------------------------------------
      pb <- progress_bar$new(
        format = "Replication = :letter [:bar] :percent :elapsed | eta: :eta",
        total = n_boot,
        width = 80)

      boot_vals <- rep(NA_real_, n_boot)
      for (j in seq_len(n_boot)) {
        boot_vals[j] <- attempt_replicate(data, reduced_model, full_model, asym)

        # ProgressBar: report the index of the replicate just finished.
        pb$tick(tokens = list(letter = j))
      }
    }

    boot_vals <- as.numeric(boot_vals)

    # Report unusable replicates explicitly instead of dropping them silently.
    n_failed <- sum(is.na(boot_vals))
    if (n_failed > 0L) {
      warning(n_failed, " of ", n_boot,
              " bootstrap replications produced no usable surrogate R-squared",
              " and were excluded from the interval.",
              call. = FALSE)
    }

    # find the alpha/2 quantile as the lower bound
    CI_lower <- quantile(x = boot_vals, probs = c(alpha/2), na.rm = TRUE)
    CI_lower <- round(CI_lower, 3)

    # find the 1 - alpha/2 quantile as the upper bound
    CI_upper <- quantile(x = boot_vals, probs = c(1 - alpha/2), na.rm = TRUE)
    CI_upper <- round(CI_upper, 3)

    # First row: the percentile labels; second row: the interval bounds.
    rsq_ci <- data.frame(Lower = c(percent(alpha/2, 0.01), CI_lower),
                         Upper = c(percent(1 - alpha/2, 0.01), CI_upper),
                         row.names = c("Percentile", "Confidence Interval"))

    # Thanks to @indenkun for providing this revise.
    return_list <- list("surr_rsq" = res_s,
                        "surr_rsq_ci" = rsq_ci,
                        "surr_rsq_BS" = boot_vals,
                        "reduced_model" = reduced_model,
                        "full_model" = full_model,
                        "data" = data)

    # Add class so that print() dispatches to print.surr_rsq_ci
    class(return_list) <- "surr_rsq_ci"

    return_list
  }

#' @title Print surrogate R-squared confidence interval measure
#' @param x A surr_rsq_ci object for printing out results.
#'
#' @param digits A default number to specify decimal digit values. It is
#' passed on to `print.data.frame` when the interval table is printed.
#' @param ... Additional optional arguments. Currently unused.
#'
#' @name print
#' @method print surr_rsq_ci
#'
#' @return The object `x`, invisibly. The function is called for its side
#' effect of printing the formula of the reduced model followed by the
#' interval estimate of the surrogate R-squared.
#'
#' @importFrom stats formula
#'
#' @export
#' @keywords internal
print.surr_rsq_ci <- function(x, digits = max(2, getOption("digits")-2), ...) {
  rule <- "------------------------------------------------------------------------ \n"
  # format() may return several strings for a long formula; each one is
  # printed on its own line.
  model_formula <- format(formula(x$reduced_model$terms))

  cat(rule)
  cat("The surrogate R-squared of the model \n", rule,
      paste(model_formula, "\n"),
      rule,
      "the interval estimate of the surrogate R-squared is: \n", sep = "")

  print.data.frame(x$surr_rsq_ci, digits = digits)

  # Print methods return their argument invisibly so the call can be chained.
  invisible(x)
}
# XiaoruiZhu/R2Cate documentation built on March 25, 2024, 2:44 a.m.