# R/functions.R

globalVariables(c("bound_end", "bound_start", "est", "label", "lwr", "stim_end", "stim_start", "upr", "vbl", "i", "j", "vij", "zb", "smaller", "larger", "ll", "us", "olap", "ambig", "lwr_add", "upr_add"))

#' Calculate Correspondence Between Pairwise Test and CI Overlaps
#' 
#' @description `viztest()` does a grid search over `range_levels` to find the confidence level(s) such that the (non-)overlaps in 
#' confidence intervals correspond as closely as possible with the results of pairwise tests.  To the extent that 
#' a level is found that accounts for all pairwise tests, confidence bounds at this level can be added to coefficient or marginal 
#' effects plots to enable readers to reliably identify estimates that are statistically different from each other. 
#' 
#' @details The algorithm first calculates results of a set of pairwise tests. For objects with estimates and a variance-covariance matrix, 
#' normal theory tests are calculated.  Optionally, these tests can be subjected to a multiplicity adjustment.  In the case of simulation results, 
#' something akin to a p-value is calculated by identifying the probability that one estimate is larger than another.  To mimic the way we use p-values 
#' in the frequentist case, we subtract the probability of difference from 1, such that smaller values indicate more confidence in the difference.  
#' The algorithm then performs a grid search over `range_levels` at increments of `level_increment`.  For each candidate level, the 
#' confidence intervals for all parameters are calculated.  For each pair of estimates, it identifies whether the confidence intervals 
#' (or credible intervals if the input is a matrix of Bayesian simulation draws) overlap.  For each candidate level, it calculates the proportion of times where
#' differences are significant/credible and confidence/credible intervals do not overlap or differences are not significant/credible and the intervals do overlap.  
#' The main idea is to find the level(s) such that the (non-)overlaps perfectly correspond with whether the differences are significant.  
#' 
#' If such a level can be found, a visual inspection of confidence or credible intervals at that level will identify whether a pair of estimates is 
#' statistically different or not.  
#' 
#' While most of the parameters are straightforward, the `sig_diffs` argument must be specified such that the stimuli are in order from highest to lowest.  This is most 
#' easily done by using `make_diff_template()` to identify the appropriate order of the comparisons.  
#'
#' @param obj A model object (or any object) where `coef()` and `vcov()` return estimates of coefficients and sampling variability.
#' @param test_level The type I error rate of the pairwise tests.
#' @param range_levels The range of confidence levels to try.
#' @param level_increment Step size of increase between the values of `range_levels`.
#' @param adjust Multiplicity adjustment to use when calculating the p-values for normal theory pairwise tests.
#' @param cifun For simulation results, the method used to calculate the confidence/credible interval: either "quantile" (default) or "hdi" for the highest density interval. 
#' @param include_intercept Logical indicating whether the intercept should be included in the tests, defaults to `FALSE`.
#' @param include_zero Should univariate tests at zero be included, defaults to `TRUE`.
#' @param sig_diffs An optional vector of values identifying whether each pair of values is statistically different (1) or not (0).  See Details for more information on specifying this value; there is some added complexity here. 
#' @param tol Tolerance for evaluation of symmetry and positive definiteness. 
#' @param ... Other arguments, currently not implemented.
#' @export
#' 
#' @references David A. Armstrong II and William Poirier. "Decoupling Visualization and Testing when Presenting Confidence Intervals" Political Analysis <doi:10.1017/pan.2024.24>.
#' @importFrom stats coef vcov qt pt p.adjust
#' @importFrom utils combn
#' @importFrom multcomp glht adjusted
#' @importFrom dplyr left_join 
#' @returns A list (of class "viztest") with the following elements: 
#' 1. tab: a data frame with results from the grid search.  The data frame has four variables: `level` - the confidence level used in the grid search; `psame` - the proportion of (non-)overlaps that match the 
#' normal theory tests; `pdiff` - the proportion of pairwise tests that are statistically significant; `easy` - the ease with which the comparisons are made. 
#' 2. pw_test: A logical vector indicating which tests are statistically significant. 
#' 3. ci_tests: A logical vector indicating whether the confidence intervals are disjoint (`TRUE`) or overlap (`FALSE`). 
#' 4. combs: The pairwise combinations of stimuli used in the test.  Note, the stimuli are reordered from largest to smallest, so the numbers do not represent the position in the original ordering. 
#' 5. param_names: A vector of the names of the parameters reordered by size - largest to smallest. 
#' 6. L: The lower confidence bounds from the grid search. 
#' 7. U: The upper confidence bounds from the grid search. 
#' 8. est: A data frame with the variables `vbl` - the parameter name; `est` - the parameter estimate; `se` - the parameter standard error. 
#' 9. call: model call
#' @examples
#' data(mtcars)
#' mtcars$cyl <- as.factor(mtcars$cyl)
#' mtcars$hp <- scale(mtcars$hp)
#' mtcars$wt <- scale(mtcars$wt)
#' mod <- lm(qsec ~ hp + wt + cyl, data=mtcars)
#' viztest(mod)
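#' 
#' # A sketch of supplying manual test input (assumption: `sig_diffs` must
#' # follow the row order returned by make_diff_template()):
#' make_diff_template(coef(mod))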
#' 
viztest <- function(obj,
                    test_level = 0.05,
                    range_levels = c(.25, .99),
                    level_increment = 0.01,
                    adjust = c("none", "holm", "hochberg", "hommel", "bonferroni", "BH", "BY", "fdr"),
                    cifun = c("quantile", "hdi"), 
                    include_intercept = FALSE,
                    include_zero = TRUE,
                    sig_diffs = NULL,
                    tol = 1e-08,
                    ...){
  UseMethod("viztest")
}

#' @method viztest default
#' @importFrom dplyr tibble bind_rows
#' @importFrom stats sd
#' @export
viztest.default <- function(obj,
                     test_level = 0.05,
                     range_levels = c(.25, .99),
                     level_increment = 0.01,
                     adjust = c("none", "holm", "hochberg", "hommel", "bonferroni", "BH", "BY", "fdr"),
                     cifun = c("quantile", "hdi"), 
                     include_intercept = FALSE,
                     include_zero = TRUE,
                     sig_diffs = NULL,
                     tol = 1e-08,
                     ...){
  
  adj <- match.arg(adjust)
  lev_seq <- seq(range_levels[1], range_levels[2], by=level_increment)
  resdf <- Inf
  if(!is.null(obj$df.residual) & inherits(obj, "lm")){
    resdf <- obj$df.residual
  }
  bhat <- coef(obj)
  if(is.null(names(bhat)))names(bhat) <- 1:length(bhat)
  V <- vcov(obj)
  if(any(is.na(bhat))){
    na_coef <- which(is.na(bhat))
    if(inherits(V, "matrix")){
      na_var <- which(apply(V, 1, \(x)all(is.na(x))))
    }else{
      na_var <- which(is.na(V))
    }
    coef_var_agree <- try(length(na_var) == length(na_coef) && all(na_var == na_coef), silent=TRUE)
    if(!inherits(coef_var_agree, "try-error") && isTRUE(coef_var_agree)){
      bhat <- bhat[-na_coef]
      V <- V[-na_var, -na_var]
      message(paste0("Rank-deficient model detected - parameter name(s): ", paste(names(na_coef), collapse=", "), "; removing NAs from estimates/variances and proceeding.\n"))
    }else{
      stop("NAs found in estimates and/or variances, but they are not in compatible places, please solve the problem and try again.\n")
    }
  }
  if(!is_pd(V, tol=tol)){
    stop("Variance-covariance matrix is not positive definite.  Cannot proceed with viztest().\n")
  }
  if(!(length(bhat) == nrow(V) & is_sym(V, tol=tol))){
    stop("Length of coefficient vector does not match dimensions of variance-covariance matrix.  Cannot proceed with viztest().\n")
  }
  if(!include_intercept){
    w_int <- grep("ntercept", names(bhat))
    if(length(w_int) > 0){
      bhat <- bhat[-w_int]
      V <- V[-w_int, -w_int]
    }
  }
  if(include_zero){
    bhat <- c(bhat, zero=0)
    V <- cbind(rbind(V, 0), 0)
  }
  o <- order(bhat, decreasing=TRUE)
  bhat <- bhat[o]
  V <- V[o, o]

  combs <- combn(length(bhat), 2)
  if(is.null(sig_diffs)){
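    # D maps estimates onto pairwise differences: column k carries +1 for the
    # larger estimate in pair k and -1 for the smaller, so bhat %*% D gives all
    # pairwise differences and t(D) %*% V %*% D their covariance matrix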
    D <- matrix(0, nrow=length(bhat), ncol=ncol(combs))
    D[cbind(combs[1,], 1:ncol(combs))] <- 1
    D[cbind(combs[2,], 1:ncol(combs))] <- -1
    diffs <- bhat %*% D
    se_diffs <- sqrt(diag(t(D) %*% V %*% D))
    p_diff <- 2*pt(diffs/se_diffs, resdf, lower.tail=FALSE)
    p_diff <- p.adjust(p_diff, method=adj)
    s <- p_diff < test_level
  }else{
    s <- as.logical(sig_diffs)
  }
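  # lower/upper confidence bounds for every estimate (rows) at every candidate
  # level (columns)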
  L <- sapply(lev_seq, \(l)bhat - qt(1-(1-l)/2, resdf)*sqrt(diag(V)))
  U <- sapply(lev_seq, \(l)bhat + qt(1-(1-l)/2, resdf)*sqrt(diag(V)))
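  # intervals are disjoint when the lower bound of the larger estimate exceeds
  # the upper bound of the smaller one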
  s_star <- L[combs[1,], ] >= U[combs[2,], ]
  smat <- array(s, dim=dim(s_star))
  # remove tests against zero to calculate "easiness"
  if("zero" %in% rownames(L)){
    w <- which(rownames(L) == "zero")
    new_L <- L[-w, ]
    new_U <- U[-w, ]
    out <- which(combs[1,] == w | combs[2,] == w)
    new_combs <- combs[, -out, drop = FALSE]
    new_combs[which(new_combs > w, arr.ind = TRUE)] <- new_combs[which(new_combs > w, arr.ind = TRUE)] - 1
    new_s <- s[-out]
  }else{
    new_L <- L
    new_U <- U
    new_combs <- combs
    new_s <- s
  }
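  # "easiness": per level, d_sig is the smallest gap separating significant
  # pairs and d_insig the smallest overlap among insignificant pairs; easiness
  # peaks when the two are balanced, leaving the most visual slack on both sides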
  diff_sig <- new_L[new_combs[1,new_s],, drop=FALSE] - new_U[new_combs[2,new_s],, drop=FALSE]
  diff_insig <- new_U[new_combs[2,!new_s],, drop=FALSE] - new_L[new_combs[1,!new_s],, drop=FALSE]
  diff_sig[which(diff_sig <= 0, arr.ind=TRUE)] <- NA
  diff_insig[which(diff_insig <= 0, arr.ind=TRUE)] <- NA
  d_sig <- suppressWarnings(apply(diff_sig, 2, min, na.rm=TRUE))
  d_insig <- suppressWarnings(apply(diff_insig, 2, min, na.rm=TRUE))
  d_sig <- ifelse(is.finite(d_sig), d_sig, 0)
  d_insig <- ifelse(is.finite(d_insig), d_insig, 0)
  easiness <- -abs(d_sig-d_insig)
  res <- data.frame(level = lev_seq,
                    psame = apply(s_star, 2, \(x)mean(x == s)),
                    pdiff = mean(s),
                    easy = easiness)
  est_data <- tibble(vbl = names(bhat), 
                     est = bhat, 
                     se = sqrt(diag(V)))
  if(adj == "none"){
    est_data$lwr_add <- bhat - qt(1-test_level/2, resdf)*sqrt(diag(V))
    est_data$upr_add <- bhat + qt(1-test_level/2, resdf)*sqrt(diag(V))
  }else{
    if("zero" %in% names(bhat)){
      tmp_bhat <- bhat
      tmp_V <- V
      zero_ind <- which(names(bhat) == "zero")
      tmp_bhat <- bhat[-zero_ind]
      tmp_V <- V[-zero_ind, -zero_ind]
    }else {
      tmp_bhat <- bhat
      tmp_V <- V
    }
    K <- diag(length(tmp_bhat))
    g <- glht(model = NULL, linfct = K, coef.=tmp_bhat, vcov.=tmp_V)
    smry <- summary(g, test=adjusted(adj))
    cis <- confint(smry)
    dat_add <- data.frame(vbl = names(tmp_bhat), 
                          lwr_add = cis$confint[,"lwr"], 
                          upr_add = cis$confint[,"upr"])
    est_data <- left_join(est_data, dat_add, by="vbl")
  }    
  res <- list(tab = res,
              pw_test = s,
              ci_tests = s_star,
              combs = combs,
              param_names = names(bhat),
              L = L,
              U = U, 
              est = est_data, 
              call = match.call(viztest, call = sys.call()))
  class(res) <- "viztest"
  attr(res, "adjusted") <- adj
  attr(res, "test_level") <- test_level 
  return(res)
}
#' @importFrom stats quantile
#' @importFrom HDInterval hdi
#' @method viztest vtsim
#' @export
viztest.vtsim <- function(obj,
                          test_level = 0.05,
                          range_levels = c(.25, .99),
                          level_increment = 0.01,
                          adjust = c("none", "holm", "hochberg", "hommel", "bonferroni", "BH", "BY", "fdr"),
                          cifun = c("quantile", "hdi"), 
                          include_intercept = FALSE,
                          include_zero = TRUE,
                          sig_diffs = NULL,
                          tol = 1e-08, 
                          ...){
  cif <- match.arg(cifun)
  est <- obj$est
  if(include_zero){
    est <- cbind(est, zero=rep(0, nrow(est)))
  }
  cm <- colMeans(est, na.rm=TRUE)
  o <- order(cm, decreasing=TRUE)
  est <- est[,o]
  combs <- combn(ncol(est), 2)
  if(is.null(sig_diffs)){
    D <- matrix(0, nrow=ncol(est), ncol=ncol(combs))
    D[cbind(combs[1,], 1:ncol(combs))] <- 1
    D[cbind(combs[2,], 1:ncol(combs))] <- -1
    diffs <- est %*% D
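    # simulation analogue of a two-sided p-value: twice the smaller of
    # Pr(diff > 0) and Pr(diff < 0)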
    pvals <- apply(diffs, 2, \(x)mean(x > 0))
    pvals <- 2*ifelse(pvals > .5, 1-pvals, pvals)
    s <- pvals < test_level
  }else{
    s <- sig_diffs
  }
  lev_seq <- seq(range_levels[1], range_levels[2], by=level_increment)
  if(cif == "quantile"){
    L <- sapply(lev_seq, \(l)apply(est, 2, quantile, probs=((1-l)/2)))
    U <- sapply(lev_seq, \(l)apply(est, 2, quantile, probs=(1-(1-l)/2)))
    tl_L <- apply(est, 2, quantile, probs = test_level/2)
    tl_U <- apply(est, 2, quantile, probs=(1 - test_level / 2))
  }else{
    LU <- lapply(lev_seq, \(l)apply(est, 2, hdi, credMass = l))
    L <- sapply(LU, \(x)x[1,])
    U <- sapply(LU, \(x)x[2,])
    tl_LU <- apply(est, 2,  hdi, credMass = 1-test_level)
    tl_L <- tl_LU[1,]
    tl_U <- tl_LU[2,]
  }
  s_star <- L[combs[1,], , drop=FALSE] >= U[combs[2,], , drop=FALSE]
  smat <- array(s, dim=dim(s_star))
  if(include_zero){
    w <- which(apply(est, 2, \(x)all(x == 0)))
    new_L <- L[-w, ]
    new_U <- U[-w, ]
    out <- which(combs[1,] == w | combs[2,] == w)
    new_combs <- combs[, -out, drop = FALSE]
    new_combs[which(new_combs > w, arr.ind = TRUE)] <- new_combs[which(new_combs > w, arr.ind = TRUE)] - 1
    new_s <- s[-out]
  }else{
    new_L <- L
    new_U <- U
    new_combs <- combs
    new_s <- s
  }
  diff_sig <- new_L[new_combs[1,new_s],, drop=FALSE] - new_U[new_combs[2,new_s],, drop=FALSE]
  diff_insig <- new_U[new_combs[2,!new_s],, drop=FALSE] - new_L[new_combs[1,!new_s],, drop=FALSE]
  diff_sig[which(diff_sig <= 0, arr.ind=TRUE)] <- NA
  diff_insig[which(diff_insig <= 0, arr.ind=TRUE)] <- NA
  d_sig <- suppressWarnings(apply(diff_sig, 2, min, na.rm=TRUE))
  d_insig <- suppressWarnings(apply(diff_insig, 2, min, na.rm=TRUE))
  d_sig <- ifelse(is.finite(d_sig), d_sig, 0)
  d_insig <- ifelse(is.finite(d_insig), d_insig, 0)
  easiness <- -abs(d_sig-d_insig)
  res <- data.frame(level = lev_seq,
                    psame = apply(s_star, 2, \(x)mean(x == s)),
                    pdiff = mean(s),
                    easy = easiness)
  cme <- colMeans(est)
  esd <- apply(est, 2, sd)
  est_data <- tibble(vbl = names(cme), 
                     est = cme, 
                     sd = esd)
  est_data$lwr_add <- tl_L
  est_data$upr_add <- tl_U
  res <- list(tab = res,
              pw_test = s,
              ci_tests = s_star,
              combs = combs,
              param_names = colnames(est),
              L = L,
              U = U, 
              est = est_data, 
              call = match.call(viztest, call = sys.call()))
  class(res) <- "viztest"
  attr(res, "adjusted") <- "none"
  attr(res, "test_level") <- test_level 
  return(res)
}

#' Print Method for viztest Objects
#'
#' @description Prints a summary of the results from the `viztest()` function.  
#' 
#' @details The results are printed in such a way that the range of optimal levels is reported, along with two candidates for the 
#' best level to use - the middle and the easiest.  
#'
#' @param x An object of class `viztest`.
#' @param best Logical indicating whether the results should be filtered to include only the best level(s) (`TRUE`) or all levels (`FALSE`).
#' @param missed_tests Logical indicating whether the tests not represented by the optimal visual testing intervals should be displayed.
#' @param level Which level should be used as the optimal one.  If `NULL`, the easiest optimal level will be used.  Easiness is measured by the sum of the overlap
#' in confidence intervals for insignificant tests plus the distance between the lower and upper bound for tests that are significant.
#' @param ... Other arguments, currently not implemented.
#' @returns Printed results that give the level(s) that correspond most closely with the pairwise test results.  The values returned are the smallest, 
#' largest, middle and easiest.  By default this function also reports the tests that are not captured by the (non-)overlaps in confidence intervals
#' when each different level is used. 
#' 
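#' @examples
#' # A minimal usage sketch: report all candidate levels and suppress the
#' # missed-test listings.
#' mod <- lm(qsec ~ hp + wt, data = mtcars)
#' v <- viztest(mod)
#' print(v, best = FALSE, missed_tests = FALSE)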
#' @export
#' @importFrom dplyr mutate
#' @importFrom stats median
#' @method print viztest
print.viztest <- function(x, ..., best=TRUE, missed_tests=TRUE, level=NULL){
  cat("\nCorrespondents of PW Tests with CI Tests\n")
  tmp <- x$tab[which(x$tab$psame == max(x$tab$psame)), ]
  lowest <- tmp[1,]
  highest <- tmp[nrow(tmp), ]
  middle <- tmp[floor(median(1:nrow(tmp))), ]
  easiest <- tmp[which.max(tmp$easy), ]
  b_levs <- bind_rows(lowest, middle, highest, easiest) %>% 
    mutate(method = c("Lowest", "Middle", "Highest", "Easiest"))
  if(best){
    print(b_levs)
  }else{
    print(x$tab)
  }
  if(missed_tests){
    if(is.null(level)){
      ## missed tests for lowest level
      l1 <- b_levs$level[1]
      w11 <- which(round(x$tab$level, 10) == round(l1, 10))
      mt1 <- data.frame(bigger = x$param_names[x$combs[1,]],
                       smaller = x$param_names[x$combs[2,]],
                       pw_test = ifelse(x$pw_test, "Sig", "Insig"),
                       ci_olap = ifelse(x$ci_tests[,w11], "No", "Yes"))
      w21 <- which((mt1$pw_test == "Sig" & mt1$ci_olap == "Yes") |
                    (mt1$pw_test == "Insig" & mt1$ci_olap == "No"))
      
      if(length(w21) > 0){
        cat("\nMissed Tests for Lowest Level (n=", length(w21), " of ", length(x$pw_test), ")\n", sep="")
        print(mt1[w21, ])
      }
      
      ## missed tests for middle level
      l2 <- b_levs$level[2]
      w12 <- which(round(x$tab$level, 10) == round(l2, 10))
      mt2 <- data.frame(bigger = x$param_names[x$combs[1,]],
                        smaller = x$param_names[x$combs[2,]],
                        pw_test = ifelse(x$pw_test, "Sig", "Insig"),
                        ci_olap = ifelse(x$ci_tests[,w12], "No", "Yes"))
      w22 <- which((mt2$pw_test == "Sig" & mt2$ci_olap == "Yes") |
                     (mt2$pw_test == "Insig" & mt2$ci_olap == "No"))
      if(length(w22) > 0){
        cat("\nMissed Tests for Middle Level (n=", length(w22), " of ", length(x$pw_test), ")\n", sep="")
        print(mt2[w22, ])
      }
      
      ## missed tests for highest level
      l3 <- b_levs$level[3]
      w13 <- which(round(x$tab$level, 10) == round(l3, 10))
      mt3 <- data.frame(bigger = x$param_names[x$combs[1,]],
                        smaller = x$param_names[x$combs[2,]],
                        pw_test = ifelse(x$pw_test, "Sig", "Insig"),
                        ci_olap = ifelse(x$ci_tests[,w13], "No", "Yes"))
      w23 <- which((mt3$pw_test == "Sig" & mt3$ci_olap == "Yes") |
                     (mt3$pw_test == "Insig" & mt3$ci_olap == "No"))
      if(length(w23) > 0){
        cat("\nMissed Tests for Highest Level (n=", length(w23), " of ", length(x$pw_test), ")\n", sep="")
        print(mt3[w23, ])
      }
      
      ## missed tests for easiest level
      l4 <- b_levs$level[4]
      w14 <- which(round(x$tab$level, 10) == round(l4, 10))
      mt4 <- data.frame(bigger = x$param_names[x$combs[1,]],
                        smaller = x$param_names[x$combs[2,]],
                        pw_test = ifelse(x$pw_test, "Sig", "Insig"),
                        ci_olap = ifelse(x$ci_tests[,w14], "No", "Yes"))
      w24 <- which((mt4$pw_test == "Sig" & mt4$ci_olap == "Yes") |
                     (mt4$pw_test == "Insig" & mt4$ci_olap == "No"))
      
      if(length(w24) > 0){
        cat("\nMissed Tests for Easiest Level (n=", length(w24), " of ", length(x$pw_test), ")\n", sep="")
        print(mt4[w24, ])
      }
      if(length(w21) == 0){
        cat("\nAll ", length(x$pw_test), " tests properly represented for by CI overlaps.\n")
      }
    }else{
      w <- which(round(x$tab$level, 10) == round(level, 10))
      mt <- data.frame(bigger = x$param_names[x$combs[1,]],
                       smaller = x$param_names[x$combs[2,]],
                       pw_test = ifelse(x$pw_test, "Sig", "Insig"),
                       ci_olap = ifelse(x$ci_tests[,w], "No", "Yes"))
      w2 <- which((mt$pw_test == "Sig" & mt$ci_olap == "Yes") |
                    (mt$pw_test == "Insig" & mt$ci_olap == "No"))
      if(length(w2) > 0){
        cat("\nMissed Tests (n=", length(w2), " of ", length(x$pw_test), ")\n", sep="")
        print(mt[w2, ])
      }else{
        cat("\nAll ", length(x$pw_test), " tests properly represented for by CI overlaps.\n")
      }
    }
  }
}

#' Make custom visual testing data
#' 
#' @description Makes custom visual testing objects that can be used as input to the `viztest()` function.  This is useful in the case
#' where `coef()` and `vcov()` do not function as expected on objects of interest, where the user wants to intervene with some 
#' modification to the usual estimates or (more likely) variance-covariance matrix or where normal theory tests may not be 
#' as useful (e.g., in the case of simulations of non-normal values).  The examples section below shows how this could be leveraged
#' to use a heteroskedasticity-consistent covariance matrix in the test rather than the one returned by `lm()`. 
#' 
#' @param estimates A vector of estimates if type is `"est_var"`, or a number-of-simulations by 
#' number-of-parameters matrix of simulated values if type is `"sim"`.  
#' @param variances In the case of independent estimates, a vector of variances of the same length 
#' as `estimates` if type is `"est_var"`.  These will be used as the diagonal elements in a variance-covariance matrix with 
#' zero covariances.  Alternatively, if type is `"est_var"`, this could be a variance-covariance matrix, with the same number 
#' of rows and columns as there are elements in the `estimates` vector.  If type is `"sim"`, variances should be `NULL`, but 
#' will be disregarded in any event. Also, note, these should be variances of the estimates (e.g., squared standard errors) and not 
#' raw variances from the data. 
#' @param type Indicates the type of input data either estimates with variances or a variance-covariance matrix or data from
#' a simulation. 
#' @param tol Tolerance for evaluation of symmetry and positive definiteness. 
#' @param ... Other arguments passed down, currently not implemented. 
#' @returns 
#' 1. If the input is a vector of parameter estimates and a variance-covariance matrix, then a list with estimates and a variance-covariance matrix of class `"vtcustom"` is returned.  In this case, the functions `coef.vtcustom()` and `vcov.vtcustom()` are 
#' used to extract the coefficients and variance-covariance matrix in a way that will work with `viztest.default()`. 
#' 2. If the input is a matrix of simulation draws, an object of class `"vtsim"` that has a single element - the data giving the draws from the simulation is returned.  In this case, `viztest.vtsim()` does the relevant testing.  
#' @export
#' 
#' @examples
#' data(mtcars)
#' mtcars$cyl <- as.factor(mtcars$cyl)
#' mtcars$hp <- scale(mtcars$hp)
#' mtcars$wt <- scale(mtcars$wt)
#' mod <- lm(qsec ~ hp + wt + cyl, data=mtcars)
#' V <- sandwich::vcovHC(mod, "HC3")
#' vtdat <- make_vt_data(coef(mod), V)
#' viztest(vtdat, 
#'         test_level = .025, 
#'         include_intercept = FALSE, 
#'         include_zero = FALSE)
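#' 
#' # A sketch of the simulation input type (assumption: the draws mimic a
#' # posterior sample stored as a draws-by-parameters matrix).
#' sims <- matrix(rnorm(3000, mean = rep(c(0, 1, 2), each = 1000), sd = 0.5),
#'                ncol = 3, dimnames = list(NULL, c("a", "b", "c")))
#' vt_sim <- make_vt_data(sims, type = "sim")
#' viztest(vt_sim, include_zero = FALSE)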
make_vt_data <- function(estimates, variances=NULL, type=c("est_var", "sim"), tol = 1e-08, ...){
  typ <- match.arg(type)
  if(typ == "est_var"){
    if(is.null(variances)){
      stop("When specifying type 'est_var', variances must also be specified.\n")
    }
    if(!inherits(variances, "matrix") & length(estimates) != length(variances)){
      stop("Variances must either be a matrix with rows and columns equal to those in estimates or a vector of the same length as estimates.\n")
    }
    if(inherits(variances, "matrix")){
      if (nrow(variances) != length(estimates) | ncol(variances) != length(estimates)){
        stop("Variance-covariance matrix must have the same number of rows and columns as there are elements in the estimates vector.\n")
      }
    }
    if(any(is.na(estimates))){
      na_coef <- which(is.na(estimates))
      if(inherits(variances, "matrix")){
        na_var <- which(apply(variances, 1, \(x)all(is.na(x))))
      }else{
        na_var <- which(is.na(variances))
      }
      coef_var_agree <- try(length(na_var) == length(na_coef) && all(na_var == na_coef), silent=TRUE)
      if(!inherits(coef_var_agree, "try-error") && isTRUE(coef_var_agree)){
        estimates <- estimates[-na_coef]
        variances <- if(inherits(variances, "matrix")) variances[-na_var, -na_var] else variances[-na_var]
        message(paste0("Rank-deficient model detected - parameter name(s): ", paste(names(na_coef), collapse=", "), "; removing NAs from estimates/variances and proceeding.\n"))
      }else{
        stop("NAs found in estimates and/or variances, but they are not in compatible places, please solve the problem and try again.\n")
      }
    }
    if(!inherits(variances, "matrix")){
      V <- diag(variances)
    }else{
      V <- variances
    }
    if(!is_pd(V, tol=tol)){
      stop("Variance-covariance matrix is not positive definite.  Cannot proceed with viztest().\n")
    }
    out <- structure(.Data = list(coef = estimates, vcov = V), class="vtcustom")
  }else{
    if(!is.null(variances)){
      message("When type is 'sim', variances are disregarded.\n")
    }
    out <- structure(.Data = list(est = estimates), class="vtsim")
  }
  return(out)
}


#' @method coef vtcustom
coef.vtcustom <- function(object, ...){
  object$coef
}

#' @method coef eff
coef.eff <- function(object, ...){
  object$fit
}

#' @method vcov vtcustom
vcov.vtcustom <- function(object, ...){
  object$vcov
}

#' @importFrom dplyr filter
make_segs <- function(.data, vdt = .02, ...){
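  # .data arrives sorted by estimate; for each stimulus, the reference segment
  # runs along its upper bound out to the furthest stimulus whose interval
  # still overlaps it, and segments are flagged as ambiguous when a later lower
  # bound falls within vdt (as a share of the total range) of the bound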
  segs <- NULL
  for(i in 1:nrow(.data)){
    if(any(.data$lwr[i:nrow(.data)] < .data$upr[i])){
      segs <- rbind(segs, data.frame(stim_start=i, stim_end=(i-1) + max(which(.data$lwr[i:nrow(.data)] <= .data$upr[i])), bound_start=.data$upr[i], bound_end=.data$upr[i]))
    }
  }
  rg <- max(.data$upr, na.rm=TRUE) - min(.data$lwr, na.rm=TRUE)
  amb_thresh <- vdt*rg
  segs$ambiguous <- FALSE
  for(i in 1:(nrow(segs)-1)){
    if(any(abs(.data$lwr[(segs$stim_start[i]+1):nrow(.data)] - segs$bound_start[i]) < amb_thresh)){
      segs$ambiguous[i] <- TRUE
      segs$stim_end[i] <- max(which((abs(.data$lwr[(segs$stim_start[i]+1):nrow(.data)] - segs$bound_start[i]) < amb_thresh)==T))+i
    }
  } 
  segs
}


#' Plot Method for viztest Objects
#' 
#' Plots the output of viztest objects with optional reference lines 
#' @param x Object to be plotted, should be of class `viztest`
#' @param add_test_level Add the (1-test_level) confidence intervals to the plot.  These are drawn from the `lwr_add` and `upr_add` bounds that `viztest()` stores in the `est` element of its output.  
#' @param ref_lines Reference lines to be plotted - one of "all", "ambiguous", "none".  This could also be a vector of stimulus names to plot - they should be the same as the names of the estimates in `x$est`. See details for explanation. 
#' @param viz_diff_thresh Threshold for identifying visual difficulty, see details. 
#' @param make_plot Logical indicating whether the plot should be constructed or the data returned. 
#' @param level Level at which to plot the estimates.  Accepts both numeric entries or one of "ce", "max", "min", "median" - defaults to "ce", the cognitively easiest level.  
#' @param trans A function to transform the estimates and their confidence intervals like `plogis`.
#' @param est_point_args A list of arguments to be passed to `geom_point()` that plots the point estimates. 
#' @param opt_ci_args A list of arguments to be passed to `geom_linerange()` to plot the optimal visual testing intervals. 
#' @param test_ci_args A list of arguments to be passed to `geom_linerange()` to plot the (1-test level) confidence intervals.
#' @param ref_line_args A list of arguments to be passed to `geom_segment()` to plot the reference lines.
#' @param scale_linewidth_args A list of arguments to be passed to `scale_linewidth_manual()` to change the thickness of the confidence intervals.
#' @param scale_color_args A list of arguments to be passed to `scale_color_manual()` to change the default color of the different confidence intervals when `add_test_level=TRUE`.
#' @param overall_theme A theme function that will be passed to the `ggplot` call before `theme()`. Default is `theme_bw`.
#' @param theme_arg A list of arguments to be passed to `theme()` to modify the theme of the plot.
#' @param remove_caption Logical indicating whether caption should be removed.  By default, it is printed to alert the user. 
#' @param ... Other arguments passed down.  Currently not implemented.
#' @details The `ref_lines` argument identifies what reference lines will be plotted in the figure.  For any particular stimulus, the reference lines run along the upper bound of the stimulus from the stimulus location to the most distant stimulus with overlapping confidence intervals.  
#' When `ref_lines = "all"`, all lines are plotted, though in displays with many stimuli, this can make for a messy graph.  When `"ref_lines = ambiguous"` is specified, then only the ones that help discriminate in cases where the result might be visually difficult to discern are plotted. 
#' A comparison is determined to be visually difficult if the upper bound of the stimulus in question is within `viz_diff_thresh` times the difference between the smallest lower bound and the largest upper bound.  If `ref_lines = "non"`, then none of the reference lines are plotted. 
#' Alternatively, you can specify the names of stimuli whose reference lines will be plotted.  These should be the same as the names in the data.  The `viztest()` function returns an object `est`, which contains the data that are used as input to this function.  The variable `vbl` in 
#' The `est` data frame contains the stimulus names. 
#' @returns By default, a ggplot is returned.  If `make_plot = FALSE`, the data for the plot are returned, but the plot is not constructed.  If the data are returned, the following variables are in the dataset: 
#' * `vbl` - The name of the parameter. 
#' * `est` - The parameter estimate
#' * `se` - The standard error of the estimate
#' * `lwr`, `upr` - The inferential confidence bounds being used
#' * `lwr_add`, `upr_add` - The (1-test_level) confidence bounds computed by `viztest()`. 
#' * `label` - Factor giving the parameter names
#' * `stim_start`, `stim_end` - y-axis bounds of the reference line
#' * `bound_start`, `bound_end` - x-axis values for reference lines
#' * `ambiguous` - Logical vector indicating whether the comparison is considered "ambiguous". 
#' @method plot viztest
#' @importFrom dplyr left_join arrange `%>%` join_by
#' @importFrom ggplot2 ggplot geom_pointrange geom_segment aes labs geom_point geom_linerange unit theme_bw scale_color_manual scale_linewidth_manual theme margin
#' @importFrom ggtext element_textbox_simple
#' @examples
#' data(mtcars)
#' mod2 <- lm(mpg ~ as.factor(cyl) + vs + am + as.factor(gear), data = mtcars)
#' v <- viztest(mod2)
#' plot(v, ref_lines="ambiguous") + ggplot2::theme_classic()
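#' 
#' # A sketch: draw only the optimal visual-testing intervals, omitting the
#' # standard confidence intervals.
#' plot(v, add_test_level = FALSE)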
#' 
#' @export
plot.viztest <- function(x, 
                         ..., 
                         add_test_level = TRUE, 
                         ref_lines="none", 
                         viz_diff_thresh = .02, 
                         make_plot=TRUE, 
                         level=c("ce","max","min","median"),
                         trans=I, 
                         est_point_args = list(color="black", size=2), 
                         opt_ci_args = list(),
                         test_ci_args = list(),
                         ref_line_args = list(color="gray75", linetype=3),
                         scale_linewidth_args = list(values=c(3.5, .5)),
                         scale_color_args = list(values = c("gray75", "black")),
                         overall_theme = theme_bw,
                         theme_arg = list(legend.position="top", 
                                          plot.caption = element_textbox_simple(width = unit(1, "npc"),  # Wraps to plot width
                                                                                halign = 0,
                                                                                margin = margin(1, 0, 0, 0,"lines"))),# Prevents overlap of caption and x-axis title
                         remove_caption=FALSE){
  inp <- x$est
  tmp <- x$tab[which(x$tab$psame == max(x$tab$psame)), ]
  if(!is.numeric(level)){
    lvl <- match.arg(level)
    level <- switch(lvl,
                    "ce" = tmp[which(tmp$easy == max(tmp$easy)), ]$level,
                    "max" = tmp[which(tmp$level == max(tmp$level)), ]$level,
                    "min" = tmp[which(tmp$level == min(tmp$level)), ]$level,
                    "median" = tmp[which(round(tmp$level,2) == round(median(tmp$level),2)), ]$level)
  }
  w <- which(round(level, 10) == round(x$tab$level, 10))
  if(length(w) == 0)stop("level must be one of the values in x$tab$level or one of ce, max, min, or median.\n")
  if(!(level %in% tmp$level))warning("chosen level outside the range of levels that maximally represent the pairwise tests. Visual tests may not be faithful to pairwise test results!")
  inp$lwr <- x$L[,w]
  inp$upr <- x$U[,w]
  inp <- inp %>% arrange(est)
  inp <- inp %>% filter(vbl != "zero")
  segs <- make_segs(inp, vdt=viz_diff_thresh)
  segs$vbl <- rownames(segs)
  inp$label <- factor(1:nrow(inp), labels=inp$vbl)
  inp <- left_join(inp, segs, by=join_by(vbl))
  inp[,c("est","lwr","upr","lwr_add","upr_add","bound_start","bound_end")] <- apply(inp[,c("est","lwr","upr","lwr_add","upr_add","bound_start","bound_end")],2,trans)
  if(any(inp$vbl == "zero"))inp <- inp[-which(inp$vbl == "zero"), ]
  if(!make_plot){
    res <- inp
  }else{
    if(add_test_level){
      opt_ci_args$mapping <- aes(x=est, xmin=lwr, xmax=upr, y=label, 
                                 colour=sprintf('Optimal Visual Intervals (%.1f%%)', level*100), 
                                 linewidth=sprintf('Optimal Visual Intervals (%.1f%%)', level*100))
      test_ci_args$mapping <- aes(xmin = lwr_add, xmax=upr_add, y=label, 
                                  colour=sprintf('Standard Confidence Intervals (%.1f%%)', (1-attr(x, "test_level"))*100), 
                                  linewidth=sprintf('Standard Confidence Intervals (%.1f%%)', (1-attr(x, "test_level"))*100))
      lab_args <- list(x="Estimate", y= "Parameter", colour="", linewidth="")
    }else{
      opt_ci_args$mapping <- aes(x=est, xmin=lwr, xmax=upr, y=label)
      lab_args <- list(x=sprintf("Estimate and Optimal Visual CI (%.1f%%)", 
                               level*100), 
                       y = "Parameter")
    }
    est_point_args$mapping <- aes(x=est, y=label)
    ref_line_args$mapping <- aes(x=bound_start, xend=bound_end, y=stim_start, yend=stim_end)
    g <- ggplot(inp) + 
      do.call(geom_linerange, opt_ci_args) +
      do.call(scale_linewidth_manual,scale_linewidth_args) +
      do.call(scale_color_manual,scale_color_args)+
      overall_theme() +
      do.call(theme,theme_arg)
    if(add_test_level & all(c("lwr_add", "upr_add") %in% colnames(inp))){
      g <- g + do.call(geom_linerange, test_ci_args)
    }
    g <- g + do.call(geom_point, est_point_args)
    if("all" %in% ref_lines){
      g <- g + do.call(geom_segment, ref_line_args)
    }  
    if( "ambiguous" %in% ref_lines){
      ref_line_args$data <- inp[which(inp$ambiguous), ]
      g <- g + do.call(geom_segment, ref_line_args)
    }
    if(!any(c("ambiguous", "all", "none") %in% ref_lines)){
      incl <- which(ref_lines %in% inp$vbl)
      if(length(incl) == 0)stop("ref_lines should either be one of (all, ambiguous, or none) or a vector of names consistent with x$est$vbl.\n")
      ref_line_args$data <- inp[incl, ]
      g <- g + do.call(geom_segment, ref_line_args)
    }
    if(!remove_caption)lab_args$caption <- "Confidence intervals have been adjusted to permit visual testing as per Armstrong and Poirier (2025) [doi:10.1017/pan.2024.24]"
    g <- g + do.call(labs, lab_args)
    res <- g
  }
  return(res)
}


#' Make Template for Pairwise Significance Input
#' 
#' Provides a template for producing a binary vector indicating whether each pair of 
#' estimates has a significant difference. 
#' 
#' @param estimates A vector of point estimates (ideally, a named vector). 
#' @param include_zero Logical indicating whether tests against zero should be included. 
#' @param include_intercept Logical indicating whether the intercept should be included. 
#' @param ... Other arguments passed down, currently not implemented. 
#' 
#' @details The `viztest()` function uses a normal difference of means test to identify
#' whether there is a significant difference or not.  While this test could be done 
#' with adjustments for multiplicity or robust standard errors of all different kinds, 
#' there may be times when the user would prefer to identify the significant differences 
#' manually.  The `viztest()` function internally reorders the estimates from largest to smallest
#' so this function does that and then prints the pairs that will correspond with the 
#' visual testing grid search being done by `viztest()`.  
#' 
#' Please note that the `include_zero` and `include_intercept` arguments should be set the same
#' here as they are in your call to `viztest()`.  If they are not, `viztest()` will stop because
#' the results from the comparison of confidence intervals will have different dimensions than the 
#' differences that are manually provided. 
#' 
#' @returns A two-column data frame containing the names of the larger and smaller parameters in the appropriate order. This can be 
#' used to identify the appropriate order in which to specify the `sig_diffs` argument to `viztest()`. 
#' @examples
#' make_diff_template(estimates = c(e1 = 2, e2 = 1, e3 = 3))
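#' # zero enters the ordering when include_zero = TRUE (the default):
#' make_diff_template(estimates = c(e1 = 2, e2 = -1, e3 = 3))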
#' @export
make_diff_template <- function(estimates, include_zero=TRUE, include_intercept=FALSE, ...){
  if(is.null(names(estimates))){
    names(estimates) <- paste0("est_", 1:length(estimates))
  }
  if(include_zero)estimates <- c(estimates, zero=0)
  if(!include_intercept)estimates <- estimates[!grepl("ntercept", names(estimates))]
  est_num <- seq_along(estimates)
  o <- order(estimates, decreasing = TRUE)
  estimates <- estimates[o]
  est_num <- est_num[o]
  combs <- t(combn(names(estimates), 2))
  colnames(combs) <- c("Larger", "Smaller")
  as.data.frame(combs)
}

#' Calculate z-score for Confidence Interval Overlap
#' 
#' Calculates the z-score required such that confidence intervals do not overlap under the null hypothesis with a specified probability. 
#' @param b A vector of estimates
#' @param v The variance-covariance matrix for `b`. 
#' @param alpha The desired probability at which the confidence intervals do not overlap under the null hypothesis.  
#' @param df Degrees of freedom for the t-distribution, defaults to `Inf` indicating a normal distribution. 
#' @param ... Other arguments passed down, currently not implemented.
#' 
#' @importFrom stats cov2cor qnorm pnorm pt qt sd
#' @importFrom dplyr summarise group_by row_number mutate select
#' @importFrom tidyr pivot_longer
#' 
#' @returns A list with two elements:
#' `ave_z`: A data frame with one row for each estimate in `b` and the following variables: 
#' * `vij`: observation number 
#' * `s_zb`: standard deviation of the z-scores across all pairs of intervals containing that estimate. 
#' * `min_zb`, `max_zb`: The minimum and maximum z-scores for the pairs of intervals containing that estimate.
#' * `zb`: The mean z-score for the pairs of intervals containing that estimate.
#' * `ci`: The confidence level corresponding to `zb`. 
#' `all_z`: A data frame with one row for each pair of estimates in `b` and the following variables:
#' * `i`, `j`: The indices of the two estimates in the pair.
#' * `s_i`, `s_j`: The standard errors of the two estimates in the pair.
#' * `theta`: The ratio of the standard errors of the two estimates.
#' * `rho`: The correlation between the two estimates.
#' * `zb`: The z-score for the pair of estimates.
#' * `ci` : The confidence level corresponding to `zb`.
#' * `olap_ave`: The probability that the two intervals do not overlap under the null hypothesis. 
#' * `olap_84`: The probability that two 84% confidence intervals for the estimates in the pair would not overlap under the null hypothesis. 
#' @references 
#' Harvey Goldstein and Michael J.R. Healy.  (1995) "The Graphical Presentation of A Collection of Means." Journal of the Royal Statistical Society, Series A 158(1): 175-177 <doi:10.2307/2983411>. 
#' David Afshartous and Richard A. Preston.  (2010) "Confidence Intervals for Dependent Data: Equating Non-overlap with Statistical Significance." Computational Statistics and Data Analysis 54: 2296-2305 <doi:10.1016/j.csda.2010.04.011>
#' @export
#' @examples
#' data(mtcars)
#' mod <- lm(mpg ~ wt + hp + disp + vs, data=mtcars)
#' gen_z(coef(mod), vcov(mod))
#' 
gen_z <- function(b, v, alpha=.05, df = Inf, ...){
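  # Afshartous & Preston (2010): two (1 - alpha) intervals whose estimates have
  # standard-error ratio theta and correlation rho fail to overlap under the
  # null with probability alpha when each interval uses the critical value
  # t(alpha/2, df) / f2(theta, rho), where
  # f2(theta, rho) = theta/sqrt(theta^2 + 1 - 2*rho*theta) +
  #                  (1/theta)/sqrt(1 + theta^(-2) - 2*rho/theta)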
  f2.theta.rho <- function(theta, rho){ theta/sqrt(theta^2 + 1 - 2*rho*theta) + (1/theta)/sqrt(1 + theta^(-2) - 2*rho/theta) }
  f2.alpha.theta.rho.t.beta <- function(alpha, theta, df, rho){ qt(alpha/2, df, lower.tail=FALSE)/f2.theta.rho(theta, rho) }
  se <- sqrt(diag(v))
  rho <- cov2cor(v)
  eg <- expand.grid(i=1:length(b), j=1:length(b)) 
  eg <- eg[which(eg[,1] < eg[,2]), ]
  eg$s_i <- se[eg[,1]]
  eg$s_j <- se[eg[,2]]
  eg$theta = eg$s_i/eg$s_j
  eg$rho <- rho[cbind(eg[,1], eg[,2])]
  eg$zb <- f2.alpha.theta.rho.t.beta(alpha, eg$theta, df, eg$rho)
  eg$ci <- 1-2*pt(eg$zb, df, lower.tail=FALSE)
  eg$olap_ave = 2*pnorm(mean(eg$zb)*f2.theta.rho(eg$theta, eg$rho), lower.tail=FALSE)
  eg$olap_84 = 2*pnorm(qnorm(.92)*f2.theta.rho(eg$theta, eg$rho), lower.tail=FALSE)
  
  tmp <- eg %>% select(i, j, zb) %>% 
    mutate(obs = row_number()) %>% 
    pivot_longer(c("i", "j"), names_to = "ij", values_to = "vij") %>% 
    group_by(vij) %>% 
    summarise(s_zb = sd(zb), 
              min_zb = min(zb), 
              max_zb = max(zb), 
              zb = mean(zb)
    ) %>% 
    mutate(ci = 1-2*pt(zb, df, lower.tail=FALSE))
  
  return(list(ave_z = tmp, all_z = eg))
}


#' Internal function to check for symmetry
#' @noRd
is_sym <- function(x, tol=1e-08){
  rnd <- abs(floor(log10(abs(tol))))
  x <- round(x,rnd)
  sqr <- nrow(x) == ncol(x) 
  if(sqr){
    lt_ut <- all(x[lower.tri(x)] == t(x)[lower.tri(x)])
  }else{
    lt_ut <- FALSE
  }
  lt_ut
}

#' Internal function to check for positive definiteness
#' @noRd
is_pd <- function(x, tol=1e-08){
  if(is_sym(x, tol = tol) & is.numeric(x)){
    ev <- eigen(x, only.values = TRUE)$values
    ev <- ifelse(abs(ev) < tol, 0, ev)
    !any(ev < 0)
  }else{
    FALSE
  }
}


#' Get Letters for Multiple Comparisons
#' 
#' Gets the letter matrix for a compact letter display.  This can be passed to the `letter_plot()` function from the `psre` package to produce plots of 
#' confidence intervals with a letter display.  
#' 
#' @param x An object that can be one of the following classes: an object of class `glht` produced by `glht()` from the `multcomp` package, 
#' an object of class `emmGrid` produced by the `emmeans` package, or a list with elements `est` - a vector of estimates and `var` - a variance-covariance matrix for the estimates.
#' @param ... Additional arguments passed down either to `cld` if the object is an `emmGrid` class or `summary.glht` if the object is a `glht` class.  
#' If `x` is a list of estimates and variances, it will be converted to an `emmGrid` object internally and the `emmGrid` method dispatched; 
#' `...` will be passed to `cld` in that case. Additional arguments for `summary.glht()` include `test` and a host of others.  See the help file for `summary.glht` for examples. 
#' Additional arguments for `cld` include `adjust` for multiplicity corrections and others.  See `?emmeans:::cld.emmGrid` for options and details. 
#' 
#' @returns A logical index indicating which estimates are in which letter group. 
#' @export
#'
#' @importFrom multcomp glht cld
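#' @examples
#' # A minimal sketch (assumes the emmeans and multcompView packages are
#' # installed): a compact letter display for the feed means from a plain
#' # list of estimates and their covariance matrix.
#' mod <- lm(weight ~ feed - 1, data = chickwts)
#' get_letters(list(est = coef(mod), var = vcov(mod)))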
get_letters <- function(x=NULL, ...){
  UseMethod("get_letters")
}

#' @importFrom emmeans emmobj
#' @method get_letters default
#' @export
get_letters.default <- function(x, ...){
  if(!all(c("est", "var") %in% names(x))){
    stop("x should be a list with elements 'est' - a vector of estimates and 'var' - a variance-covariance matrix for the estimates.\n")
  }
  est <- x$est
  nms <- names(est)
  if(is.null(nms)){
    nms <- paste0("b_", seq_along(est))
    names(est) <- nms
  }
  obj <- emmobj(est, x$var, levels=nms)
  get_letters(obj, ...)
}

#' @method get_letters glht 
#' @export
get_letters.glht <- function(x, ...){
  args <- list(...)
  args$object <- x
  s <- do.call(summary, args)
  cl <- cld(s)
  cl$mcletters$LetterMatrix
}

#' @method get_letters emmGrid
#' @export
get_letters.emmGrid <- function(x, ...){
  args <- list(...)
  args$Letters <- letters
  args$object <- x
  cl <- do.call(cld, args)
  grps <- cl$.group
  unletters <- unique(c(unlist(strsplit(grps, ""))))
  if(any(unletters == " ")){
    unletters <- unletters[-which(unletters == " ")]
  }
  unletters <- sort(unletters)
  out <- matrix(NA, nrow = nrow(cl), ncol=length(unletters))
  colnames(out) <- unletters
  rownames(out) <- cl[[1]]
  for(l in seq_along(unletters)){
    out[,l] <- grepl(unletters[l], grps)
  }
  out
}



#' Make Annotations for Significance Brackets
#' 
#' Makes a list of annotations for significance brackets produced by the `geom_signif()` function from the `ggsignif` package.  The annotations are added for 
#' pairs of estimates whose confidence intervals overlap, but the estimates are nonetheless significantly different from each other.  
#' 
#' @param obj An object of class `viztest` produced by the `viztest()` function.
#' @param type Indicates whether annotations are produced for overlapping intervals that are significantly different from each other or not. The `"auto"` option will
#' find the type that produces the fewest annotations. If `type="discrepancies"`, annotations will be made for pairs of estimates whose test results do not correspond with the (non-)overlaps in the confidence intervals. 
#' @param tol Tolerance for determining whether intervals are close enough to be considered ambiguous.  This also plots significance flags for intervals that 
#' do not overlap, but the distance between them is smaller than the tolerance.  The default is zero, but increasing the value will potentially produce more significance flags. 
#' @param nudge A vector of the same length as the number of brackets.  This will nudge the y-position of the bracket by the indicated amount.  This will be difficult to 
#' specify ahead of time, but can be specified to clean up a plot after an initial run. 
#' @param ... Other arguments, currently ignored. 
#' 
#' @importFrom stats confint
#' @export
#' @examples
#' data(chickwts)
#' chick_mod <- lm(weight ~ feed, data=chickwts)
#' library(marginaleffects)
#' chick_preds <- avg_predictions(chick_mod, variables="feed")
#' b <- coef(chick_preds)
#' names(b) <- chick_preds$feed
#' v <- vcov(chick_preds)
#' chick_vt_data <- make_vt_data(b, v)
#' chick_vt <- viztest(chick_vt_data, test_level = 0.0001, include_zero=FALSE)
#' chick_vt
#' 
#' make_annotations(chick_vt, type="discrepancies")
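#' 
#' # A sketch of the intended use (assumes a plot with the stimuli on the
#' # x-axis and the ggsignif package installed):
#' # ann <- make_annotations(chick_vt)
#' # p + ggsignif::geom_signif(annotations = ann$annotations,
#' #                           y_position = ann$y_position,
#' #                           xmin = ann$xmin, xmax = ann$xmax)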
make_annotations <- function(obj, type = c("auto", "significant", "insignificant", "discrepancies"),  tol = 0, nudge=NULL, ...){
  typ <- match.arg(type)
  if(typ == "auto"){
    ns <- make_annotations(obj, type="insignificant", tol=tol, nudge=nudge, ...)
    s <- make_annotations(obj, type="significant", tol=tol, nudge=nudge, ...)
    if(length(s$annotations) > length(ns$annotations)){
      return(ns)
    }else{
      return(s)
    }    
  }
  lwr <- obj$est$lwr_add
  names(lwr) <- obj$est$vbl
  upr <- obj$est$upr_add
  names(upr) <- obj$est$vbl
  rng <- max(upr) - min(lwr)
  if(typ != "discrepancies"){
    tmp_combs <- data.frame(
      smaller = obj$param_names[obj$combs[2,]], 
      larger = obj$param_names[obj$combs[1,]] 
    ) %>% 
      mutate(us = upr[smaller], 
             ll = lwr[larger],
             ul = upr[larger], 
             olap = ll < us) 
    if(typ == "significant"){
    tmp_combs <- tmp_combs %>%
      mutate(s = obj$pw_test, 
             ambig = ll > us & (ll - us) <= tol) %>% 
      filter(s & (olap | ambig))
    }else{
      tmp_combs <- tmp_combs %>%
        mutate(s = !obj$pw_test) %>% 
        filter(s & olap)
    }
    if(is.null(nudge))nudge <- rep(0, nrow(tmp_combs))
    if(length(nudge) != nrow(tmp_combs)){
      warning("nudge must be NULL or same length as number of annotations, currently ignored")
      nudge <- rep(0, nrow(tmp_combs))
    } 
    annot_flag <- ifelse(typ == "significant", "*", "NS")
    list(
      annotations = rep(annot_flag, nrow(tmp_combs)), 
      y_position = tmp_combs$ul + .05*rng + nudge, 
      xmin = tmp_combs$smaller, 
      xmax = tmp_combs$larger
    )
  }else{
    tab <- obj$tab
    tab$row <- 1:nrow(tab)
    use_levs <- tab[which(tab$psame == max(tab$psame)), ]
    use_row <- use_levs[which.max(use_levs$easy), "row"]
    tests <- obj$pw_test
    olaps <- obj$ci_tests[,use_row]
    discrep <- which(olaps != tests)
    if(length(discrep) > 0){
      dcombs <- obj$combs[, discrep, drop=FALSE]
      tmp_combs <- data.frame(
        smaller = obj$param_names[obj$combs[2,discrep]], 
        larger = obj$param_names[obj$combs[1,discrep]] 
      ) %>% 
        mutate(ul = upr[larger]) 
      if(is.null(nudge))nudge <- rep(0, nrow(tmp_combs))
      if(length(nudge) != nrow(tmp_combs)){
        warning("nudge must be NULL or same length as number of annotations, nudge being ignored")
        nudge <- rep(0, nrow(tmp_combs))
      } 
      
      list(
        annotations = ifelse(tests[discrep], "*", "NS"), 
        y_position = tmp_combs$ul + .05*rng + nudge, 
        xmin = tmp_combs$smaller, 
        xmax = tmp_combs$larger
      )
    }else{
      stop("No discrepancies between pairwise tests and (non-)overlaps in inferential confidence intervals.\n")
    }
  }
  
}
